| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import torch, math | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def ciou(bboxes1, bboxes2): | 
					
					
						
						| 
							 | 
						    bboxes1 = torch.sigmoid(bboxes1) | 
					
					
						
						| 
							 | 
						    bboxes2 = torch.sigmoid(bboxes2) | 
					
					
						
						| 
							 | 
						    rows = bboxes1.shape[0] | 
					
					
						
						| 
							 | 
						    cols = bboxes2.shape[0] | 
					
					
						
						| 
							 | 
						    cious = torch.zeros((rows, cols)) | 
					
					
						
						| 
							 | 
						    if rows * cols == 0: | 
					
					
						
						| 
							 | 
						        return cious | 
					
					
						
						| 
							 | 
						    exchange = False | 
					
					
						
						| 
							 | 
						    if bboxes1.shape[0] > bboxes2.shape[0]: | 
					
					
						
						| 
							 | 
						        bboxes1, bboxes2 = bboxes2, bboxes1 | 
					
					
						
						| 
							 | 
						        cious = torch.zeros((cols, rows)) | 
					
					
						
						| 
							 | 
						        exchange = True | 
					
					
						
						| 
							 | 
						    w1 = torch.exp(bboxes1[:, 2]) | 
					
					
						
						| 
							 | 
						    h1 = torch.exp(bboxes1[:, 3]) | 
					
					
						
						| 
							 | 
						    w2 = torch.exp(bboxes2[:, 2]) | 
					
					
						
						| 
							 | 
						    h2 = torch.exp(bboxes2[:, 3]) | 
					
					
						
						| 
							 | 
						    area1 = w1 * h1 | 
					
					
						
						| 
							 | 
						    area2 = w2 * h2 | 
					
					
						
						| 
							 | 
						    center_x1 = bboxes1[:, 0] | 
					
					
						
						| 
							 | 
						    center_y1 = bboxes1[:, 1] | 
					
					
						
						| 
							 | 
						    center_x2 = bboxes2[:, 0] | 
					
					
						
						| 
							 | 
						    center_y2 = bboxes2[:, 1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    inter_l = torch.max(center_x1 - w1 / 2,center_x2 - w2 / 2) | 
					
					
						
						| 
							 | 
						    inter_r = torch.min(center_x1 + w1 / 2,center_x2 + w2 / 2) | 
					
					
						
						| 
							 | 
						    inter_t = torch.max(center_y1 - h1 / 2,center_y2 - h2 / 2) | 
					
					
						
						| 
							 | 
						    inter_b = torch.min(center_y1 + h1 / 2,center_y2 + h2 / 2) | 
					
					
						
						| 
							 | 
						    inter_area = torch.clamp((inter_r - inter_l),min=0) * torch.clamp((inter_b - inter_t),min=0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    c_l = torch.min(center_x1 - w1 / 2,center_x2 - w2 / 2) | 
					
					
						
						| 
							 | 
						    c_r = torch.max(center_x1 + w1 / 2,center_x2 + w2 / 2) | 
					
					
						
						| 
							 | 
						    c_t = torch.min(center_y1 - h1 / 2,center_y2 - h2 / 2) | 
					
					
						
						| 
							 | 
						    c_b = torch.max(center_y1 + h1 / 2,center_y2 + h2 / 2) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    inter_diag = (center_x2 - center_x1)**2 + (center_y2 - center_y1)**2 | 
					
					
						
						| 
							 | 
						    c_diag = torch.clamp((c_r - c_l),min=0)**2 + torch.clamp((c_b - c_t),min=0)**2 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    union = area1+area2-inter_area | 
					
					
						
						| 
							 | 
						    u = (inter_diag) / c_diag | 
					
					
						
						| 
							 | 
						    iou = inter_area / union | 
					
					
						
						| 
							 | 
						    v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(w2 / h2) - torch.atan(w1 / h1)), 2) | 
					
					
						
						| 
							 | 
						    with torch.no_grad(): | 
					
					
						
						| 
							 | 
						        S = (iou>0.5).float() | 
					
					
						
						| 
							 | 
						        alpha= S*v/(1-iou+v) | 
					
					
						
						| 
							 | 
						    cious = iou - u - alpha * v | 
					
					
						
						| 
							 | 
						    cious = torch.clamp(cious,min=-1.0,max = 1.0) | 
					
					
						
						| 
							 | 
						    if exchange: | 
					
					
						
						| 
							 | 
						        cious = cious.T | 
					
					
						
						| 
							 | 
						    return 1-cious | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def diou(bboxes1, bboxes2): | 
					
					
						
						| 
							 | 
						    bboxes1 = torch.sigmoid(bboxes1) | 
					
					
						
						| 
							 | 
						    bboxes2 = torch.sigmoid(bboxes2) | 
					
					
						
						| 
							 | 
						    rows = bboxes1.shape[0] | 
					
					
						
						| 
							 | 
						    cols = bboxes2.shape[0] | 
					
					
						
						| 
							 | 
						    cious = torch.zeros((rows, cols)) | 
					
					
						
						| 
							 | 
						    if rows * cols == 0: | 
					
					
						
						| 
							 | 
						        return cious | 
					
					
						
						| 
							 | 
						    exchange = False | 
					
					
						
						| 
							 | 
						    if bboxes1.shape[0] > bboxes2.shape[0]: | 
					
					
						
						| 
							 | 
						        bboxes1, bboxes2 = bboxes2, bboxes1 | 
					
					
						
						| 
							 | 
						        cious = torch.zeros((cols, rows)) | 
					
					
						
						| 
							 | 
						        exchange = True | 
					
					
						
						| 
							 | 
						    w1 = torch.exp(bboxes1[:, 2]) | 
					
					
						
						| 
							 | 
						    h1 = torch.exp(bboxes1[:, 3]) | 
					
					
						
						| 
							 | 
						    w2 = torch.exp(bboxes2[:, 2]) | 
					
					
						
						| 
							 | 
						    h2 = torch.exp(bboxes2[:, 3]) | 
					
					
						
						| 
							 | 
						    area1 = w1 * h1 | 
					
					
						
						| 
							 | 
						    area2 = w2 * h2 | 
					
					
						
						| 
							 | 
						    center_x1 = bboxes1[:, 0] | 
					
					
						
						| 
							 | 
						    center_y1 = bboxes1[:, 1] | 
					
					
						
						| 
							 | 
						    center_x2 = bboxes2[:, 0] | 
					
					
						
						| 
							 | 
						    center_y2 = bboxes2[:, 1] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    inter_l = torch.max(center_x1 - w1 / 2,center_x2 - w2 / 2) | 
					
					
						
						| 
							 | 
						    inter_r = torch.min(center_x1 + w1 / 2,center_x2 + w2 / 2) | 
					
					
						
						| 
							 | 
						    inter_t = torch.max(center_y1 - h1 / 2,center_y2 - h2 / 2) | 
					
					
						
						| 
							 | 
						    inter_b = torch.min(center_y1 + h1 / 2,center_y2 + h2 / 2) | 
					
					
						
						| 
							 | 
						    inter_area = torch.clamp((inter_r - inter_l),min=0) * torch.clamp((inter_b - inter_t),min=0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    c_l = torch.min(center_x1 - w1 / 2,center_x2 - w2 / 2) | 
					
					
						
						| 
							 | 
						    c_r = torch.max(center_x1 + w1 / 2,center_x2 + w2 / 2) | 
					
					
						
						| 
							 | 
						    c_t = torch.min(center_y1 - h1 / 2,center_y2 - h2 / 2) | 
					
					
						
						| 
							 | 
						    c_b = torch.max(center_y1 + h1 / 2,center_y2 + h2 / 2) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    inter_diag = (center_x2 - center_x1)**2 + (center_y2 - center_y1)**2 | 
					
					
						
						| 
							 | 
						    c_diag = torch.clamp((c_r - c_l),min=0)**2 + torch.clamp((c_b - c_t),min=0)**2 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    union = area1+area2-inter_area | 
					
					
						
						| 
							 | 
						    u = (inter_diag) / c_diag | 
					
					
						
						| 
							 | 
						    iou = inter_area / union | 
					
					
						
						| 
							 | 
						    dious = iou - u | 
					
					
						
						| 
							 | 
						    dious = torch.clamp(dious,min=-1.0,max = 1.0) | 
					
					
						
						| 
							 | 
						    if exchange: | 
					
					
						
						| 
							 | 
						        dious = dious.T | 
					
					
						
						| 
							 | 
						    return 1-dious | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						    x = torch.rand(10, 4) | 
					
					
						
						| 
							 | 
						    y = torch.rand(10,4) | 
					
					
						
						| 
							 | 
						    import ipdb;ipdb.set_trace() | 
					
					
						
						| 
							 | 
						    cxy = ciou(x, y) | 
					
					
						
						| 
							 | 
						    dxy = diou(x, y) | 
					
					
						
						| 
							 | 
						    print(cxy.shape, dxy.shape) | 
					
					
						
						| 
							 | 
						
 |