Spaces:
Running
Running
import os | |
import torch | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from PIL import Image | |
import numpy as np | |
from networks import GMM, TOM, load_checkpoint, Options | |
from preprocessing import pad_to_22_channels | |
def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    """Warp a design image onto a dress image and save the blended result.

    The dress and design images are resized to 256x192, normalised to
    [-1, 1], and passed through the project's ``GMM`` network to obtain a
    sampling grid. The design (and an all-ones mask) is warped with that
    grid, then a 26-channel tensor is assembled and fed to the project's
    ``TOM`` network. Its two outputs are blended with the warped design and
    written to ``<output_dir>/tryon.jpg``.

    Args:
        dress_path: Path of the dress image (any format PIL can open).
        design_path: Path of the design image to be warped onto the dress.
        gmm_ckpt: Checkpoint handed to ``load_checkpoint`` for the GMM model.
        tom_ckpt: Checkpoint handed to ``load_checkpoint`` for the TOM model.
        output_dir: Directory for the result; created if it does not exist.

    Returns:
        The path of the saved JPEG (``os.path.join(output_dir, "tryon.jpg")``).
    """
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Preprocessing: fixed working resolution, pixel values mapped to [-1, 1].
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load images. The context managers close the underlying files as soon as
    # the RGB copies exist instead of leaving that to garbage collection.
    with Image.open(dress_path) as dress_file:
        dress_img = dress_file.convert("RGB")
    with Image.open(design_path) as design_file:
        design_img = design_file.convert("RGB")

    # Convert to batched tensors of shape (1, 3, im_h, im_w).
    dress_tensor = tf(dress_img).unsqueeze(0).to(device)
    design_tensor = tf(design_img).unsqueeze(0).to(device)

    # Single-channel all-ones mask: the whole design image counts as "design".
    # ones_like already allocates on design_tensor's device.
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # Agnostic input for the GMM.
    # NOTE(review): presumably pads the dress tensor to 22 channels, going by
    # the helper's name - confirm against preprocessing.pad_to_22_channels.
    agnostic = pad_to_22_channels(dress_tensor).to(device)

    # GMM processing: predict a sampling grid and warp the design with it.
    opt = Options()
    gmm = GMM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.eval()
    with torch.no_grad():
        # NOTE(review): the GMM receives the 1-channel mask, not the RGB
        # design - confirm this matches the input the checkpoint expects.
        grid, _ = gmm(agnostic, design_mask)
        warped_design = F.grid_sample(
            design_tensor,
            grid,
            padding_mode='border',
            align_corners=True,
        )
        warped_mask = F.grid_sample(
            design_mask,
            grid,
            padding_mode='zeros',
            align_corners=True,
        )

    # TOM processing.
    tom = TOM(opt).to(device)
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.eval()
    with torch.no_grad():
        # Simplified feature inputs derived from the dress image itself.
        gray = dress_tensor.mean(dim=1, keepdim=True)

        # 3x3 Laplacian-style edge kernel. conv2d requires a 4-D weight of
        # shape (out_channels, in_channels, kH, kW), hence the (1, 1, 3, 3)
        # view; a 3-D weight makes the call fail.
        kernel = torch.tensor(
            [[-1, -1, -1],
             [-1, 8, -1],
             [-1, -1, -1]],
            dtype=torch.float32,
            device=device,
        ).view(1, 1, 3, 3)
        edges = torch.abs(F.conv2d(gray, kernel, padding=1))

        # 17 zero-filled placeholder channels. They are allocated explicitly:
        # slicing a zeros_like of the 3-channel dress tensor can never yield
        # more than 3 channels.
        batch, _, height, width = dress_tensor.shape
        dummy = dress_tensor.new_zeros((batch, 17, height, width))

        # Combine inputs.
        tom_input = torch.cat([
            dress_tensor,   # 3 channels
            warped_design,  # 3 channels
            warped_mask,    # 1 channel
            gray,           # 1 channel
            edges,          # 1 channel
            dummy,          # 17 dummy channels
        ], dim=1)  # Total: 3+3+1+1+1+17 = 26 channels

        # Generate the try-on result: where the composite mask is high the
        # warped design is used, elsewhere the rendered image.
        p_rendered, m_composite = tom(tom_input)
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

    # Convert to a PIL image: drop only the batch axis, move channels last and
    # map [-1, 1] back to [0, 255].
    tryon = tryon.squeeze(0).detach().cpu()
    tryon = (tryon.permute(1, 2, 0).numpy() + 1) * 127.5
    tryon = np.clip(tryon, 0, 255).astype("uint8")
    out_pil = Image.fromarray(tryon)

    # Save output.
    output_path = os.path.join(output_dir, "tryon.jpg")
    out_pil.save(output_path)
    return output_path