Spaces:
Build error
Build error
# -------------------------------------------------------- | |
# DenseFusion 6D Object Pose Estimation by Iterative Dense Fusion | |
# Licensed under The MIT License [see LICENSE for details] | |
# Written by Chen | |
# -------------------------------------------------------- | |
import _init_paths | |
import argparse | |
import os | |
import random | |
import time | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.optim as optim | |
import torch.utils.data | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
from torch.autograd import Variable | |
from datasets.ycb.dataset import PoseDataset as PoseDataset_ycb | |
from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod | |
from lib.network import PoseNet, PoseRefineNet | |
from lib.loss import Loss | |
from lib.loss_refiner import Loss_refine | |
from lib.utils import setup_logger | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--dataset', type=str, default = 'ycb', help='ycb or linemod') | |
parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir (''YCB_Video_Dataset'' or ''Linemod_preprocessed'')') | |
parser.add_argument('--batch_size', type=int, default = 8, help='batch size') | |
parser.add_argument('--workers', type=int, default = 10, help='number of data loading workers') | |
parser.add_argument('--lr', default=0.0001, help='learning rate') | |
parser.add_argument('--lr_rate', default=0.3, help='learning rate decay rate') | |
parser.add_argument('--w', default=0.015, help='learning rate') | |
parser.add_argument('--w_rate', default=0.3, help='learning rate decay rate') | |
parser.add_argument('--decay_margin', default=0.016, help='margin to decay lr & w') | |
parser.add_argument('--refine_margin', default=0.013, help='margin to start the training of iterative refinement') | |
parser.add_argument('--noise_trans', default=0.03, help='range of the random noise of translation added to the training data') | |
parser.add_argument('--iteration', type=int, default = 2, help='number of refinement iterations') | |
parser.add_argument('--nepoch', type=int, default=500, help='max number of epochs to train') | |
parser.add_argument('--resume_posenet', type=str, default = '', help='resume PoseNet model') | |
parser.add_argument('--resume_refinenet', type=str, default = '', help='resume PoseRefineNet model') | |
parser.add_argument('--start_epoch', type=int, default = 1, help='which epoch to start') | |
opt = parser.parse_args() | |
def main(): | |
opt.manualSeed = random.randint(1, 10000) | |
random.seed(opt.manualSeed) | |
torch.manual_seed(opt.manualSeed) | |
if opt.dataset == 'ycb': | |
opt.num_objects = 21 #number of object classes in the dataset | |
opt.num_points = 1000 #number of points on the input pointcloud | |
opt.outf = 'trained_models/ycb' #folder to save trained models | |
opt.log_dir = 'experiments/logs/ycb' #folder to save logs | |
opt.repeat_epoch = 1 #number of repeat times for one epoch training | |
elif opt.dataset == 'linemod': | |
opt.num_objects = 13 | |
opt.num_points = 500 | |
opt.outf = 'trained_models/linemod' | |
opt.log_dir = 'experiments/logs/linemod' | |
opt.repeat_epoch = 20 | |
else: | |
print('Unknown dataset') | |
return | |
estimator = PoseNet(num_points = opt.num_points, num_obj = opt.num_objects) | |
estimator.cuda() | |
refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_objects) | |
refiner.cuda() | |
if opt.resume_posenet != '': | |
estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet))) | |
if opt.resume_refinenet != '': | |
refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet))) | |
opt.refine_start = True | |
opt.decay_start = True | |
opt.lr *= opt.lr_rate | |
opt.w *= opt.w_rate | |
opt.batch_size = int(opt.batch_size / opt.iteration) | |
optimizer = optim.Adam(refiner.parameters(), lr=opt.lr) | |
else: | |
opt.refine_start = False | |
opt.decay_start = False | |
optimizer = optim.Adam(estimator.parameters(), lr=opt.lr) | |
if opt.dataset == 'ycb': | |
dataset = PoseDataset_ycb('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start) | |
elif opt.dataset == 'linemod': | |
dataset = PoseDataset_linemod('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start) | |
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers) | |
if opt.dataset == 'ycb': | |
test_dataset = PoseDataset_ycb('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start) | |
elif opt.dataset == 'linemod': | |
test_dataset = PoseDataset_linemod('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start) | |
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers) | |
opt.sym_list = dataset.get_sym_list() | |
opt.num_points_mesh = dataset.get_num_points_mesh() | |
print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'.format(len(dataset), len(test_dataset), opt.num_points_mesh, opt.sym_list)) | |
criterion = Loss(opt.num_points_mesh, opt.sym_list) | |
criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list) | |
best_test = np.Inf | |
if opt.start_epoch == 1: | |
for log in os.listdir(opt.log_dir): | |
os.remove(os.path.join(opt.log_dir, log)) | |
st_time = time.time() | |
for epoch in range(opt.start_epoch, opt.nepoch): | |
logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch)) | |
logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started')) | |
train_count = 0 | |
train_dis_avg = 0.0 | |
if opt.refine_start: | |
estimator.eval() | |
refiner.train() | |
else: | |
estimator.train() | |
optimizer.zero_grad() | |
for rep in range(opt.repeat_epoch): | |
for i, data in enumerate(dataloader, 0): | |
points, choose, img, target, model_points, idx = data | |
points, choose, img, target, model_points, idx = Variable(points).cuda(), \ | |
Variable(choose).cuda(), \ | |
Variable(img).cuda(), \ | |
Variable(target).cuda(), \ | |
Variable(model_points).cuda(), \ | |
Variable(idx).cuda() | |
pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx) | |
loss, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start) | |
if opt.refine_start: | |
for ite in range(0, opt.iteration): | |
pred_r, pred_t = refiner(new_points, emb, idx) | |
dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points) | |
dis.backward() | |
else: | |
loss.backward() | |
train_dis_avg += dis.item() | |
train_count += 1 | |
if train_count % opt.batch_size == 0: | |
logger.info('Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, int(train_count / opt.batch_size), train_count, train_dis_avg / opt.batch_size)) | |
optimizer.step() | |
optimizer.zero_grad() | |
train_dis_avg = 0 | |
if train_count != 0 and train_count % 1000 == 0: | |
if opt.refine_start: | |
torch.save(refiner.state_dict(), '{0}/pose_refine_model_current.pth'.format(opt.outf)) | |
else: | |
torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(opt.outf)) | |
print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch)) | |
logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch)) | |
logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started')) | |
test_dis = 0.0 | |
test_count = 0 | |
estimator.eval() | |
refiner.eval() | |
for j, data in enumerate(testdataloader, 0): | |
points, choose, img, target, model_points, idx = data | |
points, choose, img, target, model_points, idx = Variable(points).cuda(), \ | |
Variable(choose).cuda(), \ | |
Variable(img).cuda(), \ | |
Variable(target).cuda(), \ | |
Variable(model_points).cuda(), \ | |
Variable(idx).cuda() | |
pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx) | |
_, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start) | |
if opt.refine_start: | |
for ite in range(0, opt.iteration): | |
pred_r, pred_t = refiner(new_points, emb, idx) | |
dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points) | |
test_dis += dis.item() | |
logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_count, dis)) | |
test_count += 1 | |
test_dis = test_dis / test_count | |
logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, test_dis)) | |
if test_dis <= best_test: | |
best_test = test_dis | |
if opt.refine_start: | |
torch.save(refiner.state_dict(), '{0}/pose_refine_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis)) | |
else: | |
torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis)) | |
print(epoch, '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<') | |
if best_test < opt.decay_margin and not opt.decay_start: | |
opt.decay_start = True | |
opt.lr *= opt.lr_rate | |
opt.w *= opt.w_rate | |
optimizer = optim.Adam(estimator.parameters(), lr=opt.lr) | |
if best_test < opt.refine_margin and not opt.refine_start: | |
opt.refine_start = True | |
opt.batch_size = int(opt.batch_size / opt.iteration) | |
optimizer = optim.Adam(refiner.parameters(), lr=opt.lr) | |
if opt.dataset == 'ycb': | |
dataset = PoseDataset_ycb('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start) | |
elif opt.dataset == 'linemod': | |
dataset = PoseDataset_linemod('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start) | |
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers) | |
if opt.dataset == 'ycb': | |
test_dataset = PoseDataset_ycb('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start) | |
elif opt.dataset == 'linemod': | |
test_dataset = PoseDataset_linemod('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start) | |
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers) | |
opt.sym_list = dataset.get_sym_list() | |
opt.num_points_mesh = dataset.get_num_points_mesh() | |
print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'.format(len(dataset), len(test_dataset), opt.num_points_mesh, opt.sym_list)) | |
criterion = Loss(opt.num_points_mesh, opt.sym_list) | |
criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list) | |
if __name__ == '__main__': | |
main() | |