import os
import copy
import random
import argparse
import time
import numpy as np
from PIL import Image
import scipy.io as scio
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
from torch.backends import cudnn
from data_controller import SegDataset
from loss import Loss
from segnet import SegNet as segnet
import sys
sys.path.append("..")
from lib.utils import setup_logger
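# Command-line options: dataset location, training schedule, and output paths.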
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', default='/home/data1/jeremy/YCB_Video_Dataset', help="dataset root dir (YCB_Video_Dataset)")
parser.add_argument('--batch_size', type=int, default=3, help="batch size")
parser.add_argument('--n_epochs', type=int, default=600, help="epochs to train")
parser.add_argument('--workers', type=int, default=10, help='number of data loading workers')
parser.add_argument('--lr', type=float, default=0.0001, help="learning rate")
parser.add_argument('--logs_path', default='logs/', help="path to save logs")
parser.add_argument('--model_save_path', default='trained_models/', help="path to save models")
parser.add_argument('--log_dir', default='logs/', help="path to save logs")
parser.add_argument('--resume_model', default='', help="resume model name")
opt = parser.parse_args()
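# Entry point: seed the RNGs, build the data pipeline and model, then run the train/eval loop.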
if __name__ == '__main__':
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
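    # Training and test splits; the last two SegDataset arguments appear to toggle augmentation
    # and set the number of samples drawn per epoch.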
    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.workers)
    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
    print(len(dataset), len(test_dataset))
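    # Build SegNet on the GPU and optionally resume from a previously saved state dict.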
    model = segnet()
    model = model.cuda()
    if opt.resume_model != '':
        checkpoint = torch.load('{0}/{1}'.format(opt.model_save_path, opt.resume_model))
        model.load_state_dict(checkpoint)
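    # Remove per-epoch logs left over from previous runs.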
    for log in os.listdir(opt.log_dir):
        os.remove(os.path.join(opt.log_dir, log))
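    # Adam optimizer and the project's segmentation criterion (logged below as CEloss).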
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    criterion = Loss()
    best_val_cost = np.inf
    st_time = time.time()
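    # Main loop: train for one epoch, evaluate on the test split, and keep the best checkpoint.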
    for epoch in range(1, opt.n_epochs):
        model.train()
        train_all_cost = 0.0
        train_time = 0
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
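        # One optimization step per batch: forward pass, loss, backward, parameter update.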
        for i, data in enumerate(dataloader, 0):
            rgb, target = data
            rgb, target = Variable(rgb).cuda(), Variable(target).cuda()
            semantic = model(rgb)
            optimizer.zero_grad()
            semantic_loss = criterion(semantic, target)
            train_all_cost += semantic_loss.item()
            semantic_loss.backward()
            optimizer.step()
            logger.info('Train time {0} Batch {1} CEloss {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), train_time, semantic_loss.item()))
            if train_time != 0 and train_time % 1000 == 0:
                # Periodically checkpoint the latest weights within long epochs.
                torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_current.pth'))
            train_time += 1
        train_all_cost = train_all_cost / train_time
        logger.info('Train Finish Avg CEloss: {0}'.format(train_all_cost))
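        # Evaluate on the held-out split; the average loss decides whether this epoch's weights are kept.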
        model.eval()
        test_all_cost = 0.0
        test_time = 0
        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
        for j, data in enumerate(test_dataloader, 0):
            rgb, target = data
            rgb, target = Variable(rgb).cuda(), Variable(target).cuda()
            with torch.no_grad():  # no gradients are needed during evaluation
                semantic = model(rgb)
                semantic_loss = criterion(semantic, target)
            test_all_cost += semantic_loss.item()
            test_time += 1
            logger.info('Test time {0} Batch {1} CEloss {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_time, semantic_loss.item()))
        test_all_cost = test_all_cost / test_time
        logger.info('Test Finish Avg CEloss: {0}'.format(test_all_cost))
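        # Save a checkpoint whenever the validation loss matches or improves on the best seen so far.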
        if test_all_cost <= best_val_cost:
            best_val_cost = test_all_cost
            torch.save(model.state_dict(), os.path.join(opt.model_save_path, 'model_{}_{}.pth'.format(epoch, test_all_cost)))
            print('----------->BEST SAVED<-----------')