"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching" training script
"""

import argparse
import os
import time
import sys

sys.path.append(os.path.dirname(__file__))


def parse_arguments():
    """Parse the command-line options of the training script.

    Side effect: sets ``CUDA_VISIBLE_DEVICES`` from ``--device_num``. This
    function is called at module level before ``torch`` is imported below,
    so the device selection is in place before torch is loaded.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="LiftFeat training script.")
    parser.add_argument('--name', type=str, default='LiftFeat', help='set process name')

    # MegaDepth dataset setting
    parser.add_argument('--use_megadepth', action='store_true')
    parser.add_argument('--megadepth_root_path', type=str,
                        default='/home/yepeng_liu/code_python/dataset/MegaDepth/phoenix/S6/zl548',
                        help='Path to the MegaDepth dataset root directory.')
    parser.add_argument('--megadepth_batch_size', type=int, default=6)

    # COCO20k dataset setting
    parser.add_argument('--use_coco', action='store_true')
    parser.add_argument('--coco_root_path', type=str,
                        default='/home/yepeng_liu/code_python/dataset/coco_20k',
                        help='Path to the COCO20k dataset root directory.')
    parser.add_argument('--coco_batch_size', type=int, default=4)

    parser.add_argument('--ckpt_save_path', type=str,
                        default='/home/yepeng_liu/code_python/LiftFeat/trained_weights/test',
                        help='Path to save the checkpoints.')
    parser.add_argument('--n_steps', type=int, default=160_000,
                        help='Number of training steps. Default is 160000.')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate. Default is 0.0003.')
    parser.add_argument('--gamma_steplr', type=float, default=0.5,
                        help='Gamma value for StepLR scheduler. Default is 0.5.')
    parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))),
                        default=(800, 608),
                        help='Training resolution as width,height. Default is (800, 608).')
    parser.add_argument('--device_num', type=str, default='0',
                        help='Device number to use for training. Default is "0".')
    parser.add_argument('--dry_run', action='store_true',
                        help='If set, perform a dry run training with a mini-batch for sanity check.')
    parser.add_argument('--save_ckpt_every', type=int, default=500,
                        help='Save checkpoints every N steps. Default is 500.')

    args = parser.parse_args()

    # Must happen before torch touches CUDA so only the chosen device is visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num

    return args


args = parse_arguments()

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

import numpy as np
import tqdm
import glob

from models.model import LiftFeatSPModel
from loss.loss import LiftFeatLoss
from utils.config import featureboost_config
from models.interpolator import InterpolateSparse2d
from utils.depth_anything_wrapper import DepthAnythingExtractor
from utils.alike_wrapper import ALikeExtractor

from dataset import megadepth_wrapper
from dataset import coco_wrapper
from dataset.megadepth import MegaDepthDataset
from dataset.coco_augmentor import COCOAugmentor

import setproctitle


class Trainer():
    """Train ``LiftFeatSPModel`` on COCO20k and/or MegaDepth image pairs.

    COCO pairs come from ``COCOAugmentor``; MegaDepth pairs come from
    ``MegaDepthDataset`` index files. Each step concatenates the enabled
    sources into a single batch, with COCO samples first. The step then runs
    the model's ``forward1`` / ``forward2`` / ``fine_matcher`` passes and
    optimizes ``LiftFeatLoss`` with Adam + StepLR.

    Checkpoints and TensorBoard logs are written under ``ckpt_save_path``.
    """

    # A batch in which any pair has fewer coarse correspondences than this is
    # considered corrupted and skipped.
    MIN_COARSE_POSITIVES = 30

    def __init__(self, megadepth_root_path, use_megadepth, megadepth_batch_size,
                 coco_root_path, use_coco, coco_batch_size,
                 ckpt_save_path,
                 model_name='LiftFeat',
                 n_steps=160_000, lr=3e-4, gamma_steplr=0.5,
                 training_res=(800, 608),
                 device_num="0",
                 dry_run=False,
                 save_ckpt_every=500):
        """Build the networks, optimizer, data sources and log writer.

        Args:
            megadepth_root_path: MegaDepth root. It must contain
                ``train_data/megadepth_indices/scene_info_0.1_0.7/*.npz`` and
                ``MegaDepth_v1``.
            use_megadepth: Whether to draw MegaDepth pairs each step.
            megadepth_batch_size: MegaDepth pairs per step.
            coco_root_path: Directory of COCO20k images.
            use_coco: Whether to draw COCO augmented pairs each step.
            coco_batch_size: COCO pairs per step.
            ckpt_save_path: Output directory for checkpoints and ``logdir``.
            model_name: Prefix for checkpoint and log names.
            n_steps: Number of optimization steps.
            lr: Adam learning rate.
            gamma_steplr: StepLR decay factor, applied every 10k steps.
            training_res: (width, height) used for COCO augmentation output.
            device_num: Kept for interface compatibility only. Device
                selection happens via ``CUDA_VISIBLE_DEVICES`` in
                ``parse_arguments``.
            dry_run: If True, the first valid batch is reused for every step.
                This is an overfitting sanity check.
            save_ckpt_every: Save weights every this many steps.

        Raises:
            ValueError: If neither dataset is enabled.
            FileNotFoundError: If MegaDepth is enabled but no index ``.npz``
                files are found.
        """
        print(f'MegaDepth: {use_megadepth}-{megadepth_batch_size}')
        print(f'COCO20k: {use_coco}-{coco_batch_size}')

        # Fail before loading any model. Otherwise generate_train_data would
        # end up calling torch.cat on an empty list.
        if not (use_megadepth or use_coco):
            raise ValueError('At least one training dataset must be enabled '
                             '(--use_megadepth and/or --use_coco).')

        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # training model
        self.net = LiftFeatSPModel(featureboost_config, use_kenc=False, use_normal=True,
                                   use_cross=True).to(self.dev)
        self.loss_fn = LiftFeatLoss(self.dev, lam_descs=1, lam_kpts=2, lam_heatmap=1)
        # depth-anything model (provides normals for the MegaDepth samples)
        self.depth_net = DepthAnythingExtractor('vits', self.dev, 256)
        # alike model (provides keypoints for every sample in the batch)
        self.alike_net = ALikeExtractor('alike-t', self.dev)

        # Setup optimizer
        self.steps = n_steps
        self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10_000, gamma=gamma_steplr)

        ##################### COCO INIT ##########################
        self.use_coco = use_coco
        self.coco_batch_size = coco_batch_size
        if self.use_coco:
            self.augmentor = COCOAugmentor(
                img_dir=coco_root_path,
                device=self.dev, load_dataset=True,
                batch_size=self.coco_batch_size,
                out_resolution=training_res,
                warp_resolution=training_res,
                sides_crop=0.1,
                max_num_imgs=3000,
                num_test_imgs=5,
                photometric=True,
                geometric=True,
                reload_step=4000,
            )
        ##################### COCO END #######################

        ##################### MEGADEPTH INIT ##########################
        self.use_megadepth = use_megadepth
        self.megadepth_batch_size = megadepth_batch_size
        if self.use_megadepth:
            train_base_path = os.path.join(megadepth_root_path, 'train_data', 'megadepth_indices')
            trainval_data_source = os.path.join(megadepth_root_path, 'MegaDepth_v1')
            train_npz_root = os.path.join(train_base_path, 'scene_info_0.1_0.7')
            npz_paths = sorted(glob.glob(os.path.join(train_npz_root, '*.npz')))
            if not npz_paths:
                raise FileNotFoundError(f'No MegaDepth scene index (*.npz) found under {train_npz_root}')
            megadepth_dataset = torch.utils.data.ConcatDataset(
                [MegaDepthDataset(root_dir=trainval_data_source, npz_path=path)
                 for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")]
            )
            # drop_last keeps every batch at exactly megadepth_batch_size,
            # which is the size the loss function is told each step.
            self.megadepth_dataloader = DataLoader(megadepth_dataset, batch_size=megadepth_batch_size,
                                                   shuffle=True, drop_last=True)
            self.megadepth_data_iter = iter(self.megadepth_dataloader)
        ##################### MEGADEPTH INIT END #######################

        os.makedirs(ckpt_save_path, exist_ok=True)
        os.makedirs(os.path.join(ckpt_save_path, 'logdir'), exist_ok=True)

        self.dry_run = dry_run
        self.save_ckpt_every = save_ckpt_every
        self.ckpt_save_path = ckpt_save_path
        self.writer = SummaryWriter(
            os.path.join(ckpt_save_path, 'logdir', f'{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S")))
        self.model_name = model_name

    def generate_train_data(self):
        """Assemble one training batch from the enabled datasets.

        COCO samples (if enabled) come first, followed by MegaDepth samples
        (if enabled). Images are reduced to one channel by averaging over
        dim 1. The MegaDepth iterator is restarted when it is exhausted.

        Returns:
            Tuple ``(imgs1_t, imgs2_t, imgs1_np, imgs2_np, positives_coarse)``:
                imgs1_t, imgs2_t: Single-channel image tensors, concatenated
                    along dim 0.
                imgs1_np, imgs2_np: Lists of numpy images for the MegaDepth
                    samples only. They are empty when MegaDepth is disabled.
                positives_coarse: List of per-pair correspondences at 1/8
                    resolution.
        """
        imgs1_t, imgs2_t = [], []
        imgs1_np, imgs2_np = [], []
        positives_coarse = []

        if self.use_coco:
            coco_imgs1, coco_imgs2, H1, H2 = coco_wrapper.make_batch(self.augmentor, 0.1)
            h_coarse, w_coarse = coco_imgs1[0].shape[-2] // 8, coco_imgs1[0].shape[-1] // 8
            _, positives_coco_coarse = coco_wrapper.get_corresponding_pts(
                coco_imgs1, coco_imgs2, H1, H2, self.augmentor, h_coarse, w_coarse)
            imgs1_t.append(coco_imgs1.mean(1, keepdim=True))
            imgs2_t.append(coco_imgs2.mean(1, keepdim=True))
            positives_coarse += positives_coco_coarse

        if self.use_megadepth:
            try:
                megadepth_data = next(self.megadepth_data_iter)
            except StopIteration:
                print('End of MD DATASET')
                self.megadepth_data_iter = iter(self.megadepth_dataloader)
                megadepth_data = next(self.megadepth_data_iter)

            if megadepth_data is not None:
                for k in megadepth_data.keys():
                    if isinstance(megadepth_data[k], torch.Tensor):
                        megadepth_data[k] = megadepth_data[k].to(self.dev)

                imgs1_t.append(megadepth_data['image0'].mean(1, keepdim=True))
                imgs2_t.append(megadepth_data['image1'].mean(1, keepdim=True))

                megadepth_imgs1_np, megadepth_imgs2_np = megadepth_data['image0_np'], megadepth_data['image1_np']
                for np_idx in range(megadepth_imgs1_np.shape[0]):
                    imgs1_np.append(megadepth_imgs1_np[np_idx].squeeze(0).cpu().numpy())
                    imgs2_np.append(megadepth_imgs2_np[np_idx].squeeze(0).cpu().numpy())

                positives_coarse += megadepth_wrapper.spvs_coarse(megadepth_data, 8)

        with torch.no_grad():
            imgs1_t = torch.cat(imgs1_t, dim=0)
            imgs2_t = torch.cat(imgs2_t, dim=0)

        return imgs1_t, imgs2_t, imgs1_np, imgs2_np, positives_coarse

    def _save_checkpoint_if_due(self, step):
        """Save the model weights when ``step + 1`` is a multiple of ``save_ckpt_every``."""
        if (step + 1) % self.save_ckpt_every == 0:
            print('saving iter ', step + 1)
            torch.save(self.net.state_dict(),
                       os.path.join(self.ckpt_save_path, f'{self.model_name}_{step + 1}.pth'))

    def train(self):
        """Run ``self.steps`` optimization steps.

        Batches with too few coarse correspondences are skipped. A skipped
        batch still advances the progress bar and the checkpoint schedule.

        When ``dry_run`` is set, the first valid batch is reused for every
        subsequent step.
        """
        self.net.train()

        dry_run_batch = None  # batch reused on every step when dry_run is set

        with tqdm.tqdm(total=self.steps) as pbar:
            for i in range(self.steps):
                batch = dry_run_batch if dry_run_batch is not None else self.generate_train_data()
                imgs1_t, imgs2_t, imgs1_np, imgs2_np, positives_coarse = batch

                # Check if batch is corrupted with too few correspondences
                if any(len(p) < self.MIN_COARSE_POSITIVES for p in positives_coarse):
                    pbar.update(1)
                    self._save_checkpoint_if_due(i)
                    continue

                if self.dry_run:
                    dry_run_batch = batch

                # Forward pass
                feats1, kpts1, normals1 = self.net.forward1(imgs1_t)
                feats2, kpts2, normals2 = self.net.forward1(imgs2_t)

                coordinates, fb_coordinates = [], []
                alike_kpts1, alike_kpts2 = [], []
                DA_normals1, DA_normals2 = [], []
                fb_feats1, fb_feats2 = [], []

                for b in range(feats1.shape[0]):
                    # Flatten the per-image feature map into (pixels, channels).
                    feat1 = feats1[b].permute(1, 2, 0).reshape(-1, feats1.shape[1])
                    feat2 = feats2[b].permute(1, 2, 0).reshape(-1, feats2.shape[1])

                    coordinates.append(self.net.fine_matcher(torch.cat([feat1, feat2], dim=-1)))

                    fb_feat1 = self.net.forward2(feats1[b].unsqueeze(0), kpts1[b].unsqueeze(0),
                                                 normals1[b].unsqueeze(0))
                    fb_feat2 = self.net.forward2(feats2[b].unsqueeze(0), kpts2[b].unsqueeze(0),
                                                 normals2[b].unsqueeze(0))

                    fb_coordinates.append(self.net.fine_matcher(torch.cat([fb_feat1, fb_feat2], dim=-1)))

                    fb_feats1.append(fb_feat1.unsqueeze(0))
                    fb_feats2.append(fb_feat2.unsqueeze(0))

                    # ALIKE runs on a 3-channel, 0-255 numpy copy of the single-channel input.
                    img1 = imgs1_t[b].permute(1, 2, 0).expand(-1, -1, 3).cpu().numpy() * 255
                    img2 = imgs2_t[b].permute(1, 2, 0).expand(-1, -1, 3).cpu().numpy() * 255
                    alike_kpts1.append(torch.tensor(self.alike_net.extract_alike_kpts(img1), device=self.dev))
                    alike_kpts2.append(torch.tensor(self.alike_net.extract_alike_kpts(img2), device=self.dev))

                # Depth-Anything normals are only computed for the MegaDepth
                # samples (imgs*_np is empty when MegaDepth is disabled).
                for b in range(len(imgs1_np)):
                    _, megadepth_norm1 = self.depth_net.extract(imgs1_np[b])
                    _, megadepth_norm2 = self.depth_net.extract(imgs2_np[b])
                    DA_normals1.append(megadepth_norm1)
                    DA_normals2.append(megadepth_norm2)

                # Restore per-pixel outputs to (B, C, H, W) using the feature-map size.
                fb_feats1 = torch.cat(fb_feats1, dim=0)
                fb_feats2 = torch.cat(fb_feats2, dim=0)
                fb_feats1 = fb_feats1.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)
                fb_feats2 = fb_feats2.reshape(feats2.shape[0], feats2.shape[2], feats2.shape[3], -1).permute(0, 3, 1, 2)

                coordinates = torch.cat(coordinates, dim=0)
                coordinates = coordinates.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)

                fb_coordinates = torch.cat(fb_coordinates, dim=0)
                fb_coordinates = fb_coordinates.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)

                # NOTE(review): both batch sizes are passed even when a dataset is
                # disabled; confirm how LiftFeatLoss uses them in that case.
                loss_info = self.loss_fn(
                    feats1, fb_feats1, kpts1, normals1,
                    feats2, fb_feats2, kpts2, normals2,
                    positives_coarse,
                    coordinates, fb_coordinates,
                    alike_kpts1, alike_kpts2,
                    DA_normals1, DA_normals2,
                    self.megadepth_batch_size, self.coco_batch_size)

                loss_descs, acc_coarse = loss_info['loss_descs'], loss_info['acc_coarse']
                loss_coordinates, acc_coordinates = loss_info['loss_coordinates'], loss_info['acc_coordinates']
                loss_fb_descs, acc_fb_coarse = loss_info['loss_fb_descs'], loss_info['acc_fb_coarse']
                loss_fb_coordinates, acc_fb_coordinates = loss_info['loss_fb_coordinates'], loss_info['acc_fb_coordinates']
                loss_kpts, acc_kpt = loss_info['loss_kpts'], loss_info['acc_kpt']
                loss_normals = loss_info['loss_normals']

                # loss_descs / loss_coordinates are only logged. The optimized
                # objective is the mean of the terms below.
                loss = torch.cat([
                    loss_fb_descs.unsqueeze(0),
                    loss_fb_coordinates.unsqueeze(0),
                    loss_kpts.unsqueeze(0),
                    loss_normals.unsqueeze(0),
                ], -1).mean()

                # Compute Backward Pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
                self.opt.step()
                self.opt.zero_grad()
                self.scheduler.step()

                self._save_checkpoint_if_due(i)

                pbar.set_description(
                    'Loss: {:.4f} '
                    'loss_descs: {:.3f} acc_coarse: {:.3f} '
                    'loss_coordinates: {:.3f} acc_coordinates: {:.3f} '
                    'loss_fb_descs: {:.3f} acc_fb_coarse: {:.3f} '
                    'loss_fb_coordinates: {:.3f} acc_fb_coordinates: {:.3f} '
                    'loss_kpts: {:.3f} acc_kpts: {:.3f} '
                    'loss_normals: {:.3f}'.format(
                        loss.item(),
                        loss_descs.item(), acc_coarse,
                        loss_coordinates.item(), acc_coordinates,
                        loss_fb_descs.item(), acc_fb_coarse,
                        loss_fb_coordinates.item(), acc_fb_coordinates,
                        loss_kpts.item(), acc_kpt,
                        loss_normals.item())
                )
                pbar.update(1)

                # Log metrics
                self.writer.add_scalar('Loss/total', loss.item(), i)
                self.writer.add_scalar('Accuracy/acc_coarse', acc_coarse, i)
                self.writer.add_scalar('Accuracy/acc_coordinates', acc_coordinates, i)
                self.writer.add_scalar('Accuracy/acc_fb_coarse', acc_fb_coarse, i)
                self.writer.add_scalar('Accuracy/acc_fb_coordinates', acc_fb_coordinates, i)
                self.writer.add_scalar('Loss/descs', loss_descs.item(), i)
                self.writer.add_scalar('Loss/coordinates', loss_coordinates.item(), i)
                self.writer.add_scalar('Loss/fb_descs', loss_fb_descs.item(), i)
                self.writer.add_scalar('Loss/fb_coordinates', loss_fb_coordinates.item(), i)
                self.writer.add_scalar('Loss/kpts', loss_kpts.item(), i)
                self.writer.add_scalar('Loss/normals', loss_normals.item(), i)

        # Make sure the last events reach disk before the process exits.
        self.writer.flush()


if __name__ == '__main__':

    setproctitle.setproctitle(args.name)

    trainer = Trainer(
        megadepth_root_path=args.megadepth_root_path,
        use_megadepth=args.use_megadepth,
        megadepth_batch_size=args.megadepth_batch_size,
        coco_root_path=args.coco_root_path,
        use_coco=args.use_coco,
        coco_batch_size=args.coco_batch_size,
        ckpt_save_path=args.ckpt_save_path,
        n_steps=args.n_steps,
        lr=args.lr,
        gamma_steplr=args.gamma_steplr,
        training_res=args.training_res,
        device_num=args.device_num,
        dry_run=args.dry_run,
        save_ckpt_every=args.save_ckpt_every,
    )

    # The most fun part
    trainer.train()