"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"
training script
"""
import argparse
import os
import time
import sys
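# Make the repo-local packages (models, loss, utils, dataset) importable
# regardless of the working directory the script is launched from.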
sys.path.append(os.path.dirname(__file__))
def parse_arguments():
parser = argparse.ArgumentParser(description="LiftFeat training script.")
parser.add_argument('--name',type=str,default='LiftFeat',help='set process name')
# MegaDepth dataset setting
parser.add_argument('--use_megadepth',action='store_true')
parser.add_argument('--megadepth_root_path', type=str,
default='/home/yepeng_liu/code_python/dataset/MegaDepth/phoenix/S6/zl548',
help='Path to the MegaDepth dataset root directory.')
parser.add_argument('--megadepth_batch_size', type=int, default=6)
# COCO20k dataset setting
parser.add_argument('--use_coco',action='store_true')
parser.add_argument('--coco_root_path', type=str, default='/home/yepeng_liu/code_python/dataset/coco_20k',
help='Path to the COCO20k dataset root directory.')
parser.add_argument('--coco_batch_size',type=int,default=4)
parser.add_argument('--ckpt_save_path', type=str, default='/home/yepeng_liu/code_python/LiftFeat/trained_weights/test',
help='Path to save the checkpoints.')
parser.add_argument('--n_steps', type=int, default=160_000,
help='Number of training steps. Default is 160000.')
parser.add_argument('--lr', type=float, default=3e-4,
help='Learning rate. Default is 0.0003.')
parser.add_argument('--gamma_steplr', type=float, default=0.5,
help='Gamma value for StepLR scheduler. Default is 0.5.')
parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))),
default=(800, 608), help='Training resolution as width,height. Default is (800, 608).')
parser.add_argument('--device_num', type=str, default='0',
help='Device number to use for training. Default is "0".')
parser.add_argument('--dry_run', action='store_true',
help='If set, perform a dry run training with a mini-batch for sanity check.')
parser.add_argument('--save_ckpt_every', type=int, default=500,
help='Save checkpoints every N steps. Default is 500.')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num
return args
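
# Parse CLI arguments before importing torch so that CUDA_VISIBLE_DEVICES
# (set inside parse_arguments) takes effect before CUDA is initialized.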
args = parse_arguments()
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
import numpy as np
import tqdm
import glob
from models.model import LiftFeatSPModel
from loss.loss import LiftFeatLoss
from utils.config import featureboost_config
from models.interpolator import InterpolateSparse2d
from utils.depth_anything_wrapper import DepthAnythingExtractor
from utils.alike_wrapper import ALikeExtractor
from dataset import megadepth_wrapper
from dataset import coco_wrapper
from dataset.megadepth import MegaDepthDataset
from dataset.coco_augmentor import COCOAugmentor
import setproctitle
class Trainer():
def __init__(self, megadepth_root_path,use_megadepth,megadepth_batch_size,
coco_root_path,use_coco,coco_batch_size,
ckpt_save_path,
model_name = 'LiftFeat',
n_steps = 160_000, lr= 3e-4, gamma_steplr=0.5,
training_res = (800, 608), device_num="0", dry_run = False,
save_ckpt_every = 500):
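        """Build the LiftFeat model, loss, teacher extractors (Depth Anything for
        surface normals, ALIKE for keypoints), optimizer/scheduler, and datasets."""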
        print(f'MegaDepth: {use_megadepth}-{megadepth_batch_size}')
        print(f'COCO20k: {use_coco}-{coco_batch_size}')
        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# training model
self.net = LiftFeatSPModel(featureboost_config, use_kenc=False, use_normal=True, use_cross=True).to(self.dev)
self.loss_fn=LiftFeatLoss(self.dev,lam_descs=1,lam_kpts=2,lam_heatmap=1)
# depth-anything model
self.depth_net=DepthAnythingExtractor('vits',self.dev,256)
# alike model
self.alike_net=ALikeExtractor('alike-t',self.dev)
#Setup optimizer
self.steps = n_steps
        self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()), lr=lr)
self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10_000, gamma=gamma_steplr)
##################### COCO INIT ##########################
self.use_coco=use_coco
self.coco_batch_size=coco_batch_size
if self.use_coco:
self.augmentor=COCOAugmentor(
img_dir=coco_root_path,
device=self.dev,load_dataset=True,
batch_size=self.coco_batch_size,
out_resolution=training_res,
warp_resolution=training_res,
sides_crop=0.1,
max_num_imgs=3000,
num_test_imgs=5,
photometric=True,
geometric=True,
reload_step=4000
)
##################### COCO END #######################
##################### MEGADEPTH INIT ##########################
self.use_megadepth=use_megadepth
self.megadepth_batch_size=megadepth_batch_size
if self.use_megadepth:
TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices"
TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1"
TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
            npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')
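            # Each .npz file indexes one MegaDepth scene; build a per-scene dataset
            # and concatenate them into a single training set.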
megadepth_dataset = torch.utils.data.ConcatDataset( [MegaDepthDataset(root_dir = TRAINVAL_DATA_SOURCE,
npz_path = path) for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")] )
self.megadepth_dataloader = DataLoader(megadepth_dataset, batch_size=megadepth_batch_size, shuffle=True)
self.megadepth_data_iter = iter(self.megadepth_dataloader)
##################### MEGADEPTH INIT END #######################
os.makedirs(ckpt_save_path, exist_ok=True)
os.makedirs(ckpt_save_path + '/logdir', exist_ok=True)
self.dry_run = dry_run
self.save_ckpt_every = save_ckpt_every
self.ckpt_save_path = ckpt_save_path
self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S"))
self.model_name = model_name
def generate_train_data(self):
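        """Assemble one mixed training batch.

        Returns grayscale image tensors for both views, the raw MegaDepth numpy
        images (consumed later by the Depth Anything extractor), and the coarse
        (1/8-resolution) ground-truth correspondences.
        """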
imgs1_t,imgs2_t=[],[]
imgs1_np,imgs2_np=[],[]
# norms0,norms1=[],[]
positives_coarse=[]
if self.use_coco:
coco_imgs1, coco_imgs2, H1, H2 = coco_wrapper.make_batch(self.augmentor, 0.1)
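            # Coarse correspondences are supervised on the 1/8-resolution feature grid.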
h_coarse, w_coarse = coco_imgs1[0].shape[-2] // 8, coco_imgs1[0].shape[-1] // 8
_ , positives_coco_coarse = coco_wrapper.get_corresponding_pts(coco_imgs1, coco_imgs2, H1, H2, self.augmentor, h_coarse, w_coarse)
coco_imgs1=coco_imgs1.mean(1,keepdim=True);coco_imgs2=coco_imgs2.mean(1,keepdim=True)
imgs1_t.append(coco_imgs1);imgs2_t.append(coco_imgs2)
positives_coarse += positives_coco_coarse
if self.use_megadepth:
try:
megadepth_data=next(self.megadepth_data_iter)
except StopIteration:
                print('Reached end of MegaDepth dataset, restarting iterator')
self.megadepth_data_iter=iter(self.megadepth_dataloader)
megadepth_data=next(self.megadepth_data_iter)
if megadepth_data is not None:
for k in megadepth_data.keys():
if isinstance(megadepth_data[k],torch.Tensor):
megadepth_data[k]=megadepth_data[k].to(self.dev)
megadepth_imgs1_t,megadepth_imgs2_t=megadepth_data['image0'],megadepth_data['image1']
megadepth_imgs1_t=megadepth_imgs1_t.mean(1,keepdim=True);megadepth_imgs2_t=megadepth_imgs2_t.mean(1,keepdim=True)
imgs1_t.append(megadepth_imgs1_t);imgs2_t.append(megadepth_imgs2_t)
megadepth_imgs1_np,megadepth_imgs2_np=megadepth_data['image0_np'],megadepth_data['image1_np']
for np_idx in range(megadepth_imgs1_np.shape[0]):
img1_np,img2_np=megadepth_imgs1_np[np_idx].squeeze(0).cpu().numpy(),megadepth_imgs2_np[np_idx].squeeze(0).cpu().numpy()
imgs1_np.append(img1_np);imgs2_np.append(img2_np)
positives_megadepth_coarse=megadepth_wrapper.spvs_coarse(megadepth_data,8)
positives_coarse += positives_megadepth_coarse
with torch.no_grad():
imgs1_t=torch.cat(imgs1_t,dim=0)
imgs2_t=torch.cat(imgs2_t,dim=0)
return imgs1_t,imgs2_t,imgs1_np,imgs2_np,positives_coarse
def train(self):
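        """Run the training loop: sample batches, forward both views, compute the
        combined loss, back-propagate, and periodically save checkpoints."""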
self.net.train()
with tqdm.tqdm(total=self.steps) as pbar:
for i in range(self.steps):
imgs1_t,imgs2_t,imgs1_np,imgs2_np,positives_coarse=self.generate_train_data()
#Check if batch is corrupted with too few correspondences
is_corrupted = False
for p in positives_coarse:
if len(p) < 30:
is_corrupted = True
if is_corrupted:
continue
#Forward pass
feats1,kpts1,normals1 = self.net.forward1(imgs1_t)
feats2,kpts2,normals2 = self.net.forward1(imgs2_t)
coordinates,fb_coordinates=[],[]
alike_kpts1,alike_kpts2=[],[]
DA_normals1,DA_normals2=[],[]
fb_feats1,fb_feats2=[],[]
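                # Per-sample pass: base fine-matcher coordinate predictions, feature-boosted
                # descriptors from forward2, and ALIKE keypoints used as keypoint supervision targets.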
for b in range(feats1.shape[0]):
feat1=feats1[b].permute(1,2,0).reshape(-1,feats1.shape[1])
feat2=feats2[b].permute(1,2,0).reshape(-1,feats2.shape[1])
coordinate=self.net.fine_matcher(torch.cat([feat1,feat2],dim=-1))
coordinates.append(coordinate)
fb_feat1=self.net.forward2(feats1[b].unsqueeze(0),kpts1[b].unsqueeze(0),normals1[b].unsqueeze(0))
fb_feat2=self.net.forward2(feats2[b].unsqueeze(0),kpts2[b].unsqueeze(0),normals2[b].unsqueeze(0))
fb_coordinate=self.net.fine_matcher(torch.cat([fb_feat1,fb_feat2],dim=-1))
fb_coordinates.append(fb_coordinate)
fb_feats1.append(fb_feat1.unsqueeze(0));fb_feats2.append(fb_feat2.unsqueeze(0))
img1,img2=imgs1_t[b],imgs2_t[b]
img1=img1.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255
img2=img2.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255
alike_kpt1=torch.tensor(self.alike_net.extract_alike_kpts(img1),device=self.dev)
alike_kpt2=torch.tensor(self.alike_net.extract_alike_kpts(img2),device=self.dev)
alike_kpts1.append(alike_kpt1);alike_kpts2.append(alike_kpt2)
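                # Depth Anything provides depth and surface normals for the raw MegaDepth
                # images; the normals supervise the normal-estimation branch.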
for b in range(len(imgs1_np)):
megadepth_depth1,megadepth_norm1=self.depth_net.extract(imgs1_np[b])
megadepth_depth2,megadepth_norm2=self.depth_net.extract(imgs2_np[b])
DA_normals1.append(megadepth_norm1);DA_normals2.append(megadepth_norm2)
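                # Stack the per-sample outputs back into dense (B, C, H, W) feature-map layout.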
fb_feats1=torch.cat(fb_feats1,dim=0)
fb_feats2=torch.cat(fb_feats2,dim=0)
fb_feats1=fb_feats1.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
fb_feats2=fb_feats2.reshape(feats2.shape[0],feats2.shape[2],feats2.shape[3],-1).permute(0,3,1,2)
coordinates=torch.cat(coordinates,dim=0)
coordinates=coordinates.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
fb_coordinates=torch.cat(fb_coordinates,dim=0)
fb_coordinates=fb_coordinates.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
loss_items = []
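                # The loss module returns descriptor, fine-coordinate, keypoint and normal
                # terms for both the raw and feature-boosted (fb) branches.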
loss_info=self.loss_fn(
feats1,fb_feats1,kpts1,normals1,
feats2,fb_feats2,kpts2,normals2,
positives_coarse,
coordinates,fb_coordinates,
alike_kpts1,alike_kpts2,
DA_normals1,DA_normals2,
self.megadepth_batch_size,self.coco_batch_size)
loss_descs,acc_coarse=loss_info['loss_descs'],loss_info['acc_coarse']
loss_coordinates,acc_coordinates=loss_info['loss_coordinates'],loss_info['acc_coordinates']
loss_fb_descs,acc_fb_coarse=loss_info['loss_fb_descs'],loss_info['acc_fb_coarse']
loss_fb_coordinates,acc_fb_coordinates=loss_info['loss_fb_coordinates'],loss_info['acc_fb_coordinates']
loss_kpts,acc_kpt=loss_info['loss_kpts'],loss_info['acc_kpt']
loss_normals=loss_info['loss_normals']
                # Raw-branch descriptor/coordinate losses are monitored but not added to the total loss.
loss_items.append(loss_fb_descs.unsqueeze(0))
loss_items.append(loss_fb_coordinates.unsqueeze(0))
loss_items.append(loss_kpts.unsqueeze(0))
loss_items.append(loss_normals.unsqueeze(0))
loss = torch.cat(loss_items, -1).mean()
# Compute Backward Pass
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
self.opt.step()
self.opt.zero_grad()
self.scheduler.step()
if (i+1) % self.save_ckpt_every == 0:
print('saving iter ', i+1)
torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i+1}.pth')
                pbar.set_description(
                    'Loss: {:.4f} loss_descs: {:.3f} acc_coarse: {:.3f} '
                    'loss_coordinates: {:.3f} acc_coordinates: {:.3f} '
                    'loss_fb_descs: {:.3f} acc_fb_coarse: {:.3f} '
                    'loss_fb_coordinates: {:.3f} acc_fb_coordinates: {:.3f} '
                    'loss_kpts: {:.3f} acc_kpts: {:.3f} '
                    'loss_normals: {:.3f}'.format(
                        loss.item(),
                        loss_descs.item(), acc_coarse,
                        loss_coordinates.item(), acc_coordinates,
                        loss_fb_descs.item(), acc_fb_coarse,
                        loss_fb_coordinates.item(), acc_fb_coordinates,
                        loss_kpts.item(), acc_kpt,
                        loss_normals.item()))
pbar.update(1)
# Log metrics
self.writer.add_scalar('Loss/total', loss.item(), i)
self.writer.add_scalar('Accuracy/acc_coarse', acc_coarse, i)
self.writer.add_scalar('Accuracy/acc_coordinates', acc_coordinates, i)
self.writer.add_scalar('Accuracy/acc_fb_coarse', acc_fb_coarse, i)
self.writer.add_scalar('Accuracy/acc_fb_coordinates', acc_fb_coordinates, i)
self.writer.add_scalar('Loss/descs', loss_descs.item(), i)
self.writer.add_scalar('Loss/coordinates', loss_coordinates.item(), i)
self.writer.add_scalar('Loss/fb_descs', loss_fb_descs.item(), i)
self.writer.add_scalar('Loss/fb_coordinates', loss_fb_coordinates.item(), i)
self.writer.add_scalar('Loss/kpts', loss_kpts.item(), i)
self.writer.add_scalar('Loss/normals', loss_normals.item(), i)
if __name__ == '__main__':
setproctitle.setproctitle(args.name)
trainer = Trainer(
megadepth_root_path=args.megadepth_root_path,
use_megadepth=args.use_megadepth,
megadepth_batch_size=args.megadepth_batch_size,
coco_root_path=args.coco_root_path,
use_coco=args.use_coco,
coco_batch_size=args.coco_batch_size,
ckpt_save_path=args.ckpt_save_path,
n_steps=args.n_steps,
lr=args.lr,
gamma_steplr=args.gamma_steplr,
training_res=args.training_res,
device_num=args.device_num,
dry_run=args.dry_run,
save_ckpt_every=args.save_ckpt_every
)
#The most fun part
trainer.train()