import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torchvision import models
import os
import numpy as np
class Options:
    def __init__(self):
        # Image dimensions
        self.fine_height = 256
        self.fine_width = 192
        # GMM parameters
        self.grid_size = 5
        self.input_nc = 22   # For extractionA
        self.input_nc_B = 1  # For extractionB
        # TOM parameters
        self.tom_input_nc = 26  # 3(agnostic) + 3(warped) + 1(mask) + 19(features)
        self.tom_output_nc = 4  # 3(rendered) + 1(composite mask)
        # Training settings
        self.use_dropout = False
        self.norm_layer = nn.BatchNorm2d
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

def init_weights(net, init_type='normal'):
    print(f'initialization method [{init_type}]')
    net.apply(weights_init_normal)
class FeatureExtraction(nn.Module):
    def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(FeatureExtraction, self).__init__()
        # Build feature extraction layers
        layers = [
            nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            norm_layer(ngf)
        ]
        for i in range(n_layers):
            in_channels = min(2**i * ngf, 512)
            out_channels = min(2**(i+1) * ngf, 512)
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                nn.ReLU(True),
                norm_layer(out_channels)
            ]
        # Final processing blocks
        layers += [
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            norm_layer(512),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True)
        ]
        self.model = nn.Sequential(*layers)
        init_weights(self.model)

    def forward(self, x):
        return self.model(x)
class FeatureL2Norm(nn.Module):
    def __init__(self):
        super(FeatureL2Norm, self).__init__()

    def forward(self, feature):
        epsilon = 1e-6
        norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) + epsilon, 0.5).unsqueeze(1).expand_as(feature)
        return torch.div(feature, norm)
class FeatureCorrelation(nn.Module):
    def __init__(self):
        super(FeatureCorrelation, self).__init__()

    def forward(self, feature_A, feature_B):
        b, c, h, w = feature_A.size()
        feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h*w)
        feature_B = feature_B.view(b, c, h*w).transpose(1, 2)
        feature_mul = torch.bmm(feature_B, feature_A)
        return feature_mul.view(b, h, w, h*w).transpose(2, 3).transpose(1, 2)
class FeatureRegression(nn.Module):
    def __init__(self, input_nc=512, output_dim=6):
        super(FeatureRegression, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(64 * 4 * 3, output_dim)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.conv(x)
        x = x.contiguous().view(x.size(0), -1)
        x = self.linear(x)
        return self.tanh(x)
class TpsGridGen(nn.Module):
    def __init__(self, out_h=256, out_w=192, grid_size=5):
        super(TpsGridGen, self).__init__()
        self.out_h = out_h
        self.out_w = out_w
        self.grid_size = grid_size
        self.N = grid_size * grid_size
        # Create regular grid of control points
        axis_coords = np.linspace(-1, 1, grid_size)
        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
        P_X = torch.FloatTensor(P_X.reshape(-1, 1))  # (N, 1)
        P_Y = torch.FloatTensor(P_Y.reshape(-1, 1))  # (N, 1)
        # Register buffers so they persist through saving/loading
        self.register_buffer('P_X_base', P_X)
        self.register_buffer('P_Y_base', P_Y)
        # Compute inverse matrix L^-1
        Li = self.compute_L_inverse(P_X, P_Y)
        self.register_buffer('Li', Li)
        # Create sampling grid
        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
        self.register_buffer('grid_X', torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3))  # (1, H, W, 1)
        self.register_buffer('grid_Y', torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3))  # (1, H, W, 1)

    def compute_L_inverse(self, X, Y):
        N = X.size(0)
        device = X.device
        # Construct distance matrix
        Xmat = X.expand(N, N)
        Ymat = Y.expand(N, N)
        P_dist_squared = torch.pow(Xmat - Xmat.t(), 2) + torch.pow(Ymat - Ymat.t(), 2)
        P_dist_squared[P_dist_squared == 0] = 1  # Avoid log(0)
        K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
        # Construct L matrix
        O = torch.ones(N, 1, device=device)
        Z = torch.zeros(3, 3, device=device)
        P = torch.cat((O, X, Y), 1)
        L_top = torch.cat((K, P), 1)
        L_bottom = torch.cat((P.t(), Z), 1)
        L = torch.cat((L_top, L_bottom), 0)
        return torch.inverse(L)
    def forward(self, theta):
        batch_size = theta.size(0)
        # Split theta into x and y offsets of the control points
        Q_X = theta[:, :self.N].view(batch_size, self.N, 1)
        Q_Y = theta[:, self.N:].view(batch_size, self.N, 1)
        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
        # TPS coefficients: non-affine weights come from the top-left NxN block
        # of L^-1, affine coefficients from its bottom 3xN block
        Li_w = self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1)
        Li_a = self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1)
        W_X = torch.bmm(Li_w, Q_X)  # (B, N, 1)
        W_Y = torch.bmm(Li_w, Q_Y)  # (B, N, 1)
        A_X = torch.bmm(Li_a, Q_X)  # (B, 3, 1)
        A_Y = torch.bmm(Li_a, Q_Y)  # (B, 3, 1)
        # Prepare sampling grid tensors
        grid_X = self.grid_X.expand(batch_size, -1, -1, -1)  # (B, H, W, 1)
        grid_Y = self.grid_Y.expand(batch_size, -1, -1, -1)  # (B, H, W, 1)
        # Compute transformed coordinates
        points_X = self.transform_points(grid_X, grid_Y, W_X, A_X)
        points_Y = self.transform_points(grid_X, grid_Y, W_Y, A_Y)
        return torch.cat((points_X, points_Y), 3)

    def transform_points(self, grid_X, grid_Y, W, A):
        batch_size, h, w, _ = grid_X.size()
        n_points = h * w
        # Flatten grid coordinates to (B, H*W, 1)
        gx = grid_X.reshape(batch_size, n_points, 1)
        gy = grid_Y.reshape(batch_size, n_points, 1)
        # Control points as (1, 1, N) for broadcasting
        P_X = self.P_X_base.view(1, 1, self.N)
        P_Y = self.P_Y_base.view(1, 1, self.N)
        # Radial basis U = r^2 * log(r^2) between grid points and control points
        dist_squared = (gx - P_X) ** 2 + (gy - P_Y) ** 2  # (B, H*W, N)
        dist_squared[dist_squared == 0] = 1  # Avoid log(0); U(0) = 0
        U = dist_squared * torch.log(dist_squared)
        # Non-affine part: sum_i W_i * U(|p - P_i|)
        non_affine = torch.bmm(U, W)  # (B, H*W, 1)
        # Affine part: a0 + a1*x + a2*y
        ones = torch.ones_like(gx)
        affine = torch.bmm(torch.cat([ones, gx, gy], dim=2), A)  # (B, H*W, 1)
        # Combine components
        points = affine + non_affine
        return points.view(batch_size, h, w, 1)
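
# --- Illustrative sketch (not part of the original model code) -------------
# Minimal sanity check for TpsGridGen, assuming the default 256x192 output
# size and a 5x5 control grid. With theta = 0 the target control points
# coincide with the base grid, so the generated sampling grid should match
# the identity grid up to numerical error. `_tps_identity_check` is a
# hypothetical helper name added purely for illustration.
def _tps_identity_check():
    tps = TpsGridGen(out_h=256, out_w=192, grid_size=5)
    theta = torch.zeros(1, 2 * tps.N)  # zero offsets -> identity warp
    grid = tps(theta)                  # (1, 256, 192, 2)
    identity = torch.cat((tps.grid_X, tps.grid_Y), dim=3)
    print('TPS grid shape:', tuple(grid.shape))
    print('max deviation from identity:', (grid - identity).abs().max().item())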
class GMM(nn.Module):
    def __init__(self, opt=None):
        super(GMM, self).__init__()
        if opt is None:
            opt = Options()
        self.extractionA = FeatureExtraction(opt.input_nc)
        self.extractionB = FeatureExtraction(opt.input_nc_B)
        self.l2norm = FeatureL2Norm()
        self.correlation = FeatureCorrelation()
        self.regression = FeatureRegression(input_nc=192, output_dim=2*opt.grid_size**2)
        self.gridGen = TpsGridGen(opt.fine_height, opt.fine_width, opt.grid_size)

    def forward(self, inputA, inputB):
        featureA = self.extractionA(inputA)
        featureB = self.extractionB(inputB)
        featureA = self.l2norm(featureA)
        featureB = self.l2norm(featureB)
        correlation = self.correlation(featureA, featureB)
        theta = self.regression(correlation)
        grid = self.gridGen(theta)
        return grid, theta
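
# --- Illustrative sketch (not part of the original model code) -------------
# A minimal forward pass through GMM with the default Options (22-channel
# person representation, 1-channel cloth mask, 256x192 resolution), followed
# by warping a cloth image with the predicted grid. The tensor names
# (`agnostic`, `cloth_mask`, `cloth`) and the use of F.grid_sample are
# illustrative assumptions, not something this file prescribes.
def _gmm_demo():
    opt = Options()
    gmm = GMM(opt)
    agnostic = torch.randn(1, opt.input_nc, opt.fine_height, opt.fine_width)
    cloth_mask = torch.randn(1, opt.input_nc_B, opt.fine_height, opt.fine_width)
    cloth = torch.randn(1, 3, opt.fine_height, opt.fine_width)
    grid, theta = gmm(agnostic, cloth_mask)  # grid: (1,256,192,2), theta: (1,50)
    warped_cloth = F.grid_sample(cloth, grid, padding_mode='border',
                                 align_corners=True)
    print('grid:', tuple(grid.shape), 'theta:', tuple(theta.shape),
          'warped cloth:', tuple(warped_cloth.shape))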
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # Build UNet structure
        unet_block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None,
            norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(
            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
            norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(
            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
            norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(
            ngf, ngf * 2, input_nc=None, submodule=unet_block,
            norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=unet_block,
            outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        return self.model(input)
class TOM(nn.Module):
    def __init__(self, opt=None):
        super(TOM, self).__init__()
        if opt is None:
            opt = Options()
        self.unet = UnetGenerator(
            input_nc=opt.tom_input_nc,
            output_nc=opt.tom_output_nc,
            num_downs=6,
            norm_layer=nn.InstanceNorm2d
        )

    def forward(self, x):
        output = self.unet(x)
        p_rendered, m_composite = torch.split(output, [3, 1], dim=1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        return p_rendered, m_composite
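
# --- Illustrative sketch (not part of the original model code) -------------
# A minimal forward pass through TOM with the default 26-channel input. The
# final blend follows the usual CP-VTON-style composition (warped cloth
# pasted over the rendered person via the predicted composite mask); that
# formula is an assumption here, not something defined in this file.
def _tom_demo():
    opt = Options()
    tom = TOM(opt)
    x = torch.randn(1, opt.tom_input_nc, opt.fine_height, opt.fine_width)
    warped_cloth = torch.randn(1, 3, opt.fine_height, opt.fine_width)
    p_rendered, m_composite = tom(x)  # (1, 3, H, W), (1, 1, H, W)
    p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
    print('rendered:', tuple(p_rendered.shape),
          'mask:', tuple(m_composite.shape),
          'try-on:', tuple(p_tryon.shape))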
def save_checkpoint(model, save_path):
    save_dir = os.path.dirname(save_path)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
def load_checkpoint(model, checkpoint_path, strict=False):
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model_state = model.state_dict()
    # Build a state dict that matches the current model architecture
    new_state_dict = {}
    used_keys = set()
    for key, value in state_dict.items():
        # Handle renamed parameters here if needed
        new_key = key
        if 'gridGen' in key:
            # Map old TPS buffer names to the current ones
            if 'P_X' in key and 'base' not in key:
                new_key = key.replace('P_X', 'P_X_base')
            elif 'P_Y' in key and 'base' not in key:
                new_key = key.replace('P_Y', 'P_Y_base')
        # Only keep keys that exist in the current model
        if new_key in model_state:
            new_state_dict[new_key] = value
            used_keys.add(key)
    # Fall back to the model's own values for missing TPS buffers
    tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
                  'gridGen.grid_X', 'gridGen.grid_Y']
    for param in tps_params:
        if param not in new_state_dict and hasattr(model, 'gridGen'):
            print(f"Initializing missing TPS parameter: {param}")
            new_state_dict[param] = model_state[param]
    # Load the remapped state dict
    model.load_state_dict(new_state_dict, strict=strict)
    # Report keys that could not be matched
    missing = set(model_state.keys()) - set(new_state_dict.keys())
    unexpected = set(state_dict.keys()) - used_keys
    if missing:
        print(f"Missing keys: {sorted(missing)}")
    if unexpected:
        print(f"Unexpected keys: {sorted(unexpected)}")
    return model
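
# --- Illustrative sketch (not part of the original model code) -------------
# A small smoke test, assuming the demo helpers defined above. It runs the
# TPS, GMM and TOM sketches and round-trips a GMM checkpoint through
# save_checkpoint/load_checkpoint in a temporary directory; `gmm_demo.pth`
# is a hypothetical file name.
if __name__ == '__main__':
    import tempfile
    _tps_identity_check()
    _gmm_demo()
    _tom_demo()
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, 'checkpoints', 'gmm_demo.pth')
        save_checkpoint(GMM(), ckpt_path)
        load_checkpoint(GMM(), ckpt_path)
    print('checkpoint round trip OK')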