Spaces:
Build error
Build error
import argparse | |
import os | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.optim as optim | |
import torch.utils.data | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
from torch.autograd import Variable | |
from PIL import Image | |
import numpy as np | |
import pdb | |
import torch.nn.functional as F | |
from lib.pspnet import PSPNet | |
def _psp_builder(backend, psp_size, deep_features_size):
    """Return a zero-argument callable that constructs a PSPNet for ``backend``."""
    def build():
        return PSPNet(
            sizes=(1, 2, 3, 6),
            psp_size=psp_size,
            deep_features_size=deep_features_size,
            backend=backend
        )
    return build


# Maps a ResNet backbone name to a zero-argument factory for the matching PSPNet.
# resnet18/34 use psp_size=512 and deep_features_size=256;
# resnet50/101/152 use psp_size=2048 and deep_features_size=1024.
psp_models = {
    'resnet18': _psp_builder('resnet18', 512, 256),
    'resnet34': _psp_builder('resnet34', 512, 256),
    'resnet50': _psp_builder('resnet50', 2048, 1024),
    'resnet101': _psp_builder('resnet101', 2048, 1024),
    'resnet152': _psp_builder('resnet152', 2048, 1024)
}
class ModifiedResnet(nn.Module):
    """Image feature extractor: the 'resnet18' PSPNet wrapped in ``nn.DataParallel``.

    Args:
        usegpu: Accepted for interface compatibility; this class does not read it.
    """

    def __init__(self, usegpu=True):
        super(ModifiedResnet, self).__init__()
        build_backbone = psp_models['resnet18']
        # The backbone is always stored behind a DataParallel wrapper.
        self.model = nn.DataParallel(build_backbone())

    def forward(self, x):
        """Return the wrapped network's output for input ``x``."""
        return self.model(x)
class PoseNetFeat(nn.Module):
    """Fuse per-point geometry and embedding features with a pooled global feature.

    Two parallel stacks of 1x1 ``Conv1d`` layers process the 3-channel input
    ``x`` and the 32-channel input ``emb``. Their outputs are concatenated at
    two depths (128 and 256 channels); the deeper one is lifted to 1024
    channels, average-pooled, and repeated ``num_points`` times.

    Args:
        num_points: Kernel size of the average pool and the repeat count of
            the pooled feature. Assumes the last axis of both inputs has
            exactly this length -- TODO confirm against callers.
    """

    def __init__(self, num_points):
        super(PoseNetFeat, self).__init__()
        # Branch fed by ``x`` (3 input channels).
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)

        # Branch fed by ``emb`` (32 input channels).
        self.e_conv1 = nn.Conv1d(32, 64, 1)
        self.e_conv2 = nn.Conv1d(64, 128, 1)

        # Layers applied to the 256-channel concatenation.
        self.conv5 = nn.Conv1d(256, 512, 1)
        self.conv6 = nn.Conv1d(512, 1024, 1)

        self.ap1 = nn.AvgPool1d(num_points)
        self.num_points = num_points

    def forward(self, x, emb):
        """Return the concatenation of both per-point features and the global feature.

        Args:
            x: Tensor with 3 channels on dim 1.
            emb: Tensor with 32 channels on dim 1.

        Returns:
            Tensor with 128 + 256 + 1024 = 1408 channels on dim 1.
        """
        geo_64 = F.relu(self.conv1(x))
        emb_64 = F.relu(self.e_conv1(emb))
        pointfeat_1 = torch.cat((geo_64, emb_64), dim=1)

        geo_128 = F.relu(self.conv2(geo_64))
        emb_128 = F.relu(self.e_conv2(emb_64))
        pointfeat_2 = torch.cat((geo_128, emb_128), dim=1)

        lifted = F.relu(self.conv6(F.relu(self.conv5(pointfeat_2))))

        # Pool over the point axis, then repeat so every point carries the global feature.
        pooled = self.ap1(lifted).view(-1, 1024, 1)
        global_feat = pooled.repeat(1, 1, self.num_points)

        return torch.cat([pointfeat_1, pointfeat_2, global_feat], 1)  # 128 + 256 + 1024
class PoseNet(nn.Module):
    """Predict per-point rotation, translation and confidence from an image and points.

    An image network (``ModifiedResnet``) produces a feature map from which
    per-point embeddings are gathered at the ``choose`` indices. These are
    fused with the points by ``PoseNetFeat`` and passed through three parallel
    stacks of 1x1 ``Conv1d`` layers (rotation ``_r``, translation ``_t``,
    confidence ``_c``).

    Args:
        num_points: Number of points per sample; used to reshape head outputs.
        num_obj: Number of object classes; each head emits one block per class.
    """

    def __init__(self, num_points, num_obj):
        super(PoseNet, self).__init__()
        self.num_points = num_points
        self.cnn = ModifiedResnet()
        self.feat = PoseNetFeat(num_points)

        # Stage 1: 1408 -> 640
        self.conv1_r = nn.Conv1d(1408, 640, 1)
        self.conv1_t = nn.Conv1d(1408, 640, 1)
        self.conv1_c = nn.Conv1d(1408, 640, 1)

        # Stage 2: 640 -> 256
        self.conv2_r = nn.Conv1d(640, 256, 1)
        self.conv2_t = nn.Conv1d(640, 256, 1)
        self.conv2_c = nn.Conv1d(640, 256, 1)

        # Stage 3: 256 -> 128
        self.conv3_r = nn.Conv1d(256, 128, 1)
        self.conv3_t = nn.Conv1d(256, 128, 1)
        self.conv3_c = nn.Conv1d(256, 128, 1)

        # Stage 4: one output block per object class.
        self.conv4_r = nn.Conv1d(128, num_obj * 4, 1)  # quaternion
        self.conv4_t = nn.Conv1d(128, num_obj * 3, 1)  # translation
        self.conv4_c = nn.Conv1d(128, num_obj * 1, 1)  # confidence

        self.num_obj = num_obj

    def forward(self, img, x, choose, obj):
        """Run the full network and return predictions for the classes in ``obj[0]``.

        Args:
            img: Input passed to the image network.
            x: Points; dims 1 and 2 are swapped before feature extraction.
            choose: Index tensor used to gather embeddings along the flattened
                spatial axis of the image feature map.
            obj: Index tensor; ``obj[0]`` selects the object classes to return.

        Returns:
            Tuple ``(out_rx, out_tx, out_cx, emb)`` where the first three have
            shape ``(len(obj[0]), num_points, k)`` with ``k`` = 4, 3 and 1,
            and ``emb`` is the gathered embedding detached from the graph.
            Only batch element 0 is returned.
        """
        out_img = self.cnn(img)
        bs, di, _height, _width = out_img.size()

        # Flatten the spatial dims and pick the embedding at each chosen pixel.
        flat_emb = out_img.view(bs, di, -1)
        emb = torch.gather(flat_emb, 2, choose.repeat(1, di, 1)).contiguous()

        points = x.transpose(2, 1).contiguous()
        ap_x = self.feat(points, emb)

        # Rotation head.
        rx = F.relu(self.conv1_r(ap_x))
        rx = F.relu(self.conv2_r(rx))
        rx = F.relu(self.conv3_r(rx))
        rx = self.conv4_r(rx).view(bs, self.num_obj, 4, self.num_points)

        # Translation head.
        tx = F.relu(self.conv1_t(ap_x))
        tx = F.relu(self.conv2_t(tx))
        tx = F.relu(self.conv3_t(tx))
        tx = self.conv4_t(tx).view(bs, self.num_obj, 3, self.num_points)

        # Confidence head (sigmoid-squashed).
        cx = F.relu(self.conv1_c(ap_x))
        cx = F.relu(self.conv2_c(cx))
        cx = F.relu(self.conv3_c(cx))
        cx = torch.sigmoid(self.conv4_c(cx)).view(bs, self.num_obj, 1, self.num_points)

        b = 0  # only the first batch element is read out
        wanted = obj[b]

        def pick(per_class):
            # Select the requested classes, then move the point axis before the value axis.
            chosen = torch.index_select(per_class[b], 0, wanted)
            return chosen.contiguous().transpose(2, 1).contiguous()

        out_rx = pick(rx)
        out_tx = pick(tx)
        out_cx = pick(cx)

        return out_rx, out_tx, out_cx, emb.detach()
class PoseRefineNetFeat(nn.Module):
    """Reduce points and embeddings to a single 1024-dim global feature per sample.

    Two parallel stacks of 1x1 ``Conv1d`` layers process the 3-channel input
    ``x`` and the 32-channel input ``emb``. The 128- and 256-channel
    concatenations are stacked into 384 channels, lifted to 1024 channels and
    average-pooled.

    Args:
        num_points: Kernel size of the average pool. Assumes the last axis of
            both inputs has exactly this length -- TODO confirm against callers.
    """

    def __init__(self, num_points):
        super(PoseRefineNetFeat, self).__init__()
        # Branch fed by ``x`` (3 input channels).
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)

        # Branch fed by ``emb`` (32 input channels).
        self.e_conv1 = nn.Conv1d(32, 64, 1)
        self.e_conv2 = nn.Conv1d(64, 128, 1)

        # Layers applied to the 384-channel (128 + 256) concatenation.
        self.conv5 = nn.Conv1d(384, 512, 1)
        self.conv6 = nn.Conv1d(512, 1024, 1)

        self.ap1 = nn.AvgPool1d(num_points)
        self.num_points = num_points

    def forward(self, x, emb):
        """Return the pooled global feature, reshaped to ``(-1, 1024)``.

        Args:
            x: Tensor with 3 channels on dim 1.
            emb: Tensor with 32 channels on dim 1.
        """
        geo_64 = F.relu(self.conv1(x))
        emb_64 = F.relu(self.e_conv1(emb))
        shallow = torch.cat([geo_64, emb_64], dim=1)

        geo_128 = F.relu(self.conv2(geo_64))
        emb_128 = F.relu(self.e_conv2(emb_64))
        deep = torch.cat([geo_128, emb_128], dim=1)

        stacked = torch.cat([shallow, deep], dim=1)
        lifted = F.relu(self.conv6(F.relu(self.conv5(stacked))))

        return self.ap1(lifted).view(-1, 1024)
class PoseRefineNet(nn.Module):
    """Predict a rotation and translation per object class from a global feature.

    ``PoseRefineNetFeat`` reduces the points and embeddings to a 1024-dim
    vector, which two parallel stacks of ``Linear`` layers turn into
    per-class quaternion (``_r``) and translation (``_t``) outputs.

    Args:
        num_points: Number of points per sample, forwarded to the feature net.
        num_obj: Number of object classes; each head emits one block per class.
    """

    def __init__(self, num_points, num_obj):
        super(PoseRefineNet, self).__init__()
        self.num_points = num_points
        self.feat = PoseRefineNetFeat(num_points)

        # Attribute names keep the "conv" prefix although these are Linear layers.
        self.conv1_r = nn.Linear(1024, 512)
        self.conv1_t = nn.Linear(1024, 512)

        self.conv2_r = nn.Linear(512, 128)
        self.conv2_t = nn.Linear(512, 128)

        self.conv3_r = nn.Linear(128, num_obj * 4)  # quaternion
        self.conv3_t = nn.Linear(128, num_obj * 3)  # translation

        self.num_obj = num_obj

    def forward(self, x, emb, obj):
        """Return the rotation and translation for the classes in ``obj[0]``.

        Args:
            x: Points; dims 1 and 2 are swapped before feature extraction.
            emb: Embedding tensor passed to the feature net together with ``x``.
            obj: Index tensor; ``obj[0]`` selects the object classes to return.

        Returns:
            Tuple ``(out_rx, out_tx)`` of shapes ``(len(obj[0]), 4)`` and
            ``(len(obj[0]), 3)``. Only batch element 0 is returned.
        """
        bs = x.size(0)

        points = x.transpose(2, 1).contiguous()
        ap_x = self.feat(points, emb)

        # Rotation head.
        rx = F.relu(self.conv1_r(ap_x))
        rx = F.relu(self.conv2_r(rx))
        rx = self.conv3_r(rx).view(bs, self.num_obj, 4)

        # Translation head.
        tx = F.relu(self.conv1_t(ap_x))
        tx = F.relu(self.conv2_t(tx))
        tx = self.conv3_t(tx).view(bs, self.num_obj, 3)

        b = 0  # only the first batch element is read out
        wanted = obj[b]
        out_rx = torch.index_select(rx[b], 0, wanted)
        out_tx = torch.index_select(tx[b], 0, wanted)

        return out_rx, out_tx