from collections import defaultdict
import math
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from ..backbones import build_backbone
from .hafm import HAFMencoder
from .losses import *


class ScaleLSD(nn.Module):
    # Inference-time defaults; configure() overrides them at the class level
    # from the command-line options registered in cli().
    num_junctions_inference = 512
    junction_threshold_hm = 0.008

    def __init__(self, gray_scale=False, use_layer_scale=False, enable_attention_hooks=False):
        super(ScaleLSD, self).__init__()

        self.distance_threshold = 5.0
        self.hafm_encoder = HAFMencoder(dis_th=self.distance_threshold)

        # self.backbone = build_backbone(gray_scale=gray_scale, use_layer_scale=use_layer_scale)
        self.backbone = build_backbone(gray_scale=gray_scale,
                                       use_layer_scale=use_layer_scale,
                                       enable_attention_hooks=enable_attention_hooks)

        # junction-to-endpoint matching threshold (pixels on the feature grid)
        self.j2l_threshold = 10
        self.num_residuals = 0

        self.loss = nn.CrossEntropyLoss(reduction='none')
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')

        self.stride = self.backbone.stride
        self.train_step = 0

    @classmethod
    def configure(cls, opts):
        try:
            cls.num_junctions_inference = opts.num_junctions
            cls.junction_threshold_hm = opts.junction_hm
        except:
            pass

    @classmethod
    def cli(cls, parser):
        try:
            parser.add_argument('-nj', '--num-junctions', default=512, type=int,
                                help='number of junctions')
            parser.add_argument('-jh', '--junction-hm', default=0.008, type=float,
                                help='junction threshold heatmap')
        except:
            pass

    def hafm_decoding(self, md_maps, dis_maps, residual_maps, scale=5.0, flatten=True, return_points=False):
        """Decode the HAFM angle/distance fields into one candidate line segment per pixel."""
        device = md_maps.device
        scale = self.distance_threshold  # the `scale` argument is overridden by the encoder's distance threshold

        batch_size, _, height, width = md_maps.shape

        _y = torch.arange(0, height, device=device).float()
        _x = torch.arange(0, width, device=device).float()

        y0, x0 = torch.meshgrid(_y, _x, indexing='ij')
        y0 = y0[None, None]
        x0 = x0[None, None]

        sign_pad = torch.arange(-self.num_residuals, self.num_residuals + 1,
                                device=device, dtype=torch.float32).reshape(1, -1, 1, 1)

        if residual_maps is not None:
            residual = residual_maps * sign_pad
            distance_fields = dis_maps + residual
        else:
            distance_fields = dis_maps
        distance_fields = distance_fields.clamp(min=0, max=1.0)

        # channel 0: main direction; channels 1-2: angles towards the two endpoints
        md_un = (md_maps[:, :1] - 0.5) * np.pi * 2
        st_un = md_maps[:, 1:2] * np.pi / 2.0
        ed_un = -md_maps[:, 2:3] * np.pi / 2.0

        cs_md = md_un.cos()
        ss_md = md_un.sin()

        y_st = torch.tan(st_un)
        y_ed = torch.tan(ed_un)

        x_st_rotated = (cs_md - ss_md * y_st) * distance_fields * scale
        y_st_rotated = (ss_md + cs_md * y_st) * distance_fields * scale

        x_ed_rotated = (cs_md - ss_md * y_ed) * distance_fields * scale
        y_ed_rotated = (ss_md + cs_md * y_ed) * distance_fields * scale

        x_st_final = (x_st_rotated + x0).clamp(min=0, max=width - 1)
        y_st_final = (y_st_rotated + y0).clamp(min=0, max=height - 1)

        x_ed_final = (x_ed_rotated + x0).clamp(min=0, max=width - 1)
        y_ed_final = (y_ed_rotated + y0).clamp(min=0, max=height - 1)

        lines = torch.stack((x_st_final, y_st_final, x_ed_final, y_ed_final), dim=-1)
        if flatten:
            lines = lines.reshape(batch_size, -1, 4)

        if return_points:
            points = torch.stack((x0, y0), dim=-1)
            points = points.repeat((batch_size, 2 * self.num_residuals + 1, 1, 1, 1))
            if flatten:
                points = points.reshape(batch_size, -1, 2)
            return lines, points

        return lines

    @staticmethod
    def non_maximum_suppression(a, kernel_size=3):
        # keep only the local maxima of the heatmap within a kernel_size window
        ap = F.max_pool2d(a, kernel_size, stride=1, padding=kernel_size // 2)
        mask = (a == ap).float().clamp(min=0.0)
        return a * mask

    @staticmethod
    def get_junctions(jloc, joff, topk=300, th=0):
        # top-k junction locations from the heatmap, refined with sub-pixel offsets
        height, width = jloc.size(1), jloc.size(2)
        jloc = jloc.reshape(-1)
        joff = joff.reshape(2, -1)

        scores, index = torch.topk(jloc, k=topk)
        # y = (index // width).float() + torch.gather(joff[1], 0, index) + 0.5
        y = torch.div(index, width, rounding_mode='trunc').float() + torch.gather(joff[1], 0, index) + 0.5
        x = (index % width).float() + torch.gather(joff[0], 0, index) + 0.5

        junctions = torch.stack((x, y)).t()

        if th > 0:
            return junctions[scores > th], scores[scores > th]
        else:
            return junctions, scores

    def wireframe_matcher(self, juncs_pred, lines_pred, hat_points, is_train=False):
        # snap the endpoints of the decoded line segments to the nearest predicted junctions
        cost1 = torch.sum((lines_pred[:, :2] - juncs_pred[:, None]) ** 2, dim=-1)
        cost2 = torch.sum((lines_pred[:, 2:] - juncs_pred[:, None]) ** 2, dim=-1)

        dis1, idx_junc_to_end1 = cost1.min(dim=0)
        dis2, idx_junc_to_end2 = cost2.min(dim=0)

        length = torch.sum((lines_pred[:, :2] - lines_pred[:, 2:]) ** 2, dim=-1)

        idx_junc_to_end_min = torch.min(idx_junc_to_end1, idx_junc_to_end2)
        idx_junc_to_end_max = torch.max(idx_junc_to_end1, idx_junc_to_end2)

        iskeep = idx_junc_to_end_min < idx_junc_to_end_max  ## not the same junction
        if self.j2l_threshold > 0:
            # dis1/dis2 are squared endpoint-to-junction distances, so compare against the squared threshold
            iskeep *= (dis1 < self.j2l_threshold * self.j2l_threshold) * \
                      (dis2 < self.j2l_threshold * self.j2l_threshold)

    # NOTE: a span of the original source is missing here. It contained the
    # remainder of wireframe_matcher and the training-time forward/loss methods;
    # only this tail fragment of the auxiliary-loss aggregation survives:
    #
    #         if ... > 0:
    #             for auxput in auxputs:
    #                 loss_dict = self.compute_loss(auxput, targets, mask, loss_dict)
    #         for key in extra_info.keys():
    #             extra_info[key] = extra_info[key] / batch_size
    #         return loss_dict, extra_info

    @torch.no_grad()
    def forward_test(self, images, annotations=None, merge=False):
        device = images.device
        batch_size, _, height, width = images.shape

        outputs, features, aux = self.forward_backbone(images)

        if "use_lsd" not in annotations.keys():
            annotations["use_lsd"] = True

        # use lsd for theta prediction
        if annotations['use_lsd']:
            ws = images.shape[3] // self.stride
            hs = images.shape[2] // self.stride

            lsd = cv2.createLineSegmentDetector(0)
            md_lsd_batch = []
            dis_lsd_batch = []
            for i in range(batch_size):
                image = np.array(images[i, 0].cpu().numpy() * 255, dtype=np.uint8)
                lsd_lines = lsd.detect(image)[0].reshape(-1, 4)
                # transform lsd lines to lsd-hat-field
                md_lsd, dis_lsd, _ = self.hafm_encoder.lines2hafm(
                    torch.from_numpy(lsd_lines).to(images.device) / self.stride, hs, ws)
                md_lsd_batch.append(md_lsd)
                dis_lsd_batch.append(dis_lsd)

            md_pred = torch.stack(md_lsd_batch, dim=0)
            dis_pred = torch.stack(dis_lsd_batch, dim=0)

            # for junctions/endpoints extraction
            md_pred[:, 1:3] = outputs[:, 1:3].sigmoid()
            # dist
            dis_pred = outputs[:, 3:4].sigmoid()
            jloc_pred = outputs[:, 5:7].softmax(1)[:, 1:]
            joff_pred = outputs[:, 7:9].sigmoid() - 0.5
        else:
            md_pred = outputs[:, :3].sigmoid()
            dis_pred = outputs[:, 3:4].sigmoid()
            res_pred = outputs[:, 4:5].sigmoid()
            jloc_pred = outputs[:, 5:7].softmax(1)[:, 1:]
            jloc_logits = outputs[:, 5:7].softmax(1)
            joff_pred = outputs[:, 7:9].sigmoid() - 0.5

        lines_pred_batch, hat_points_batch = self.hafm_decoding(md_pred, dis_pred, None,
                                                                flatten=True, return_points=True)

        output_list = []
        graph_pred = torch.zeros((batch_size, self.num_junctions_inference, self.num_junctions_inference),
                                 device=images.device)

        for i in range(batch_size):
            if annotations['use_nms']:
                jloc_pred_nms = self.non_maximum_suppression(jloc_pred[i])
            else:
                jloc_pred_nms = self.non_maximum_suppression(jloc_pred[i], kernel_size=1)

            topK = min(self.num_junctions_inference,
                       int((jloc_pred_nms > self.junction_threshold_hm).float().sum().item()))
            juncs_pred, juncs_score = self.get_junctions(jloc_pred_nms, joff_pred[i],
                                                         topk=topK, th=self.junction_threshold_hm)

            lines_adjusted, indices_adj, supports, hat_points, counts = self.wireframe_matcher(
                juncs_pred, lines_pred_batch[i], hat_points_batch[i])

            # rescale junctions from the feature-map grid back to image coordinates
            jscales = torch.tensor([annotations['width'] / md_pred.size(3),
                                    annotations['height'] / md_pred.size(2)],
                                   device=images.device)
            junctions = juncs_pred * jscales
            supports = [_ * self.stride for _ in supports]

            num_junctions = junctions.shape[0]
            graph_pred[i, indices_adj[:, 0], indices_adj[:, 1]] += counts.float()
            graph_pred[i, indices_adj[:, 1], indices_adj[:, 0]] += counts.float()

            graph_i = graph_pred[i, :num_junctions, :num_junctions]
            edges = graph_i.triu().nonzero()
            lines = junctions[edges].reshape(-1, 4)
            scores = graph_pred[i, edges[:, 0], edges[:, 1]]

            output_list.append(
                {
                    'lines_pred': lines,
                    'lines_score': scores,
                    'juncs_pred': junctions,
                    'lines_support': supports,
                    'juncs_score': juncs_score,
                    'graph': graph_i,
                    'width': annotations['width'],
                    'height': annotations['height'],
                }
            )

        return output_list, {}
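

# Minimal inference sketch (not part of the original module; illustrative only).
# It assumes the package can be run with its relative imports resolved (e.g. via
# `python -m`), that a grayscale backbone is used, that trained weights are loaded
# elsewhere, and that a local file 'example.png' exists. The annotation keys mirror
# what forward_test() reads above ('width', 'height', 'use_lsd', 'use_nms').
if __name__ == '__main__':
    model = ScaleLSD(gray_scale=True).eval()   # gray_scale=True is an assumption

    img = cv2.imread('example.png', cv2.IMREAD_GRAYSCALE)        # HxW uint8
    tensor = torch.from_numpy(img).float()[None, None] / 255.0   # 1x1xHxW in [0, 1]

    annotations = {
        'width': img.shape[1],
        'height': img.shape[0],
        'use_lsd': False,   # True routes the line geometry through OpenCV's LSD detector
        'use_nms': True,    # apply NMS on the junction heatmap before top-k selection
    }

    outputs, _ = model.forward_test(tensor, annotations)
    print(outputs[0]['lines_pred'].shape, outputs[0]['juncs_pred'].shape)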