Design_warper / warp_design_on_dress.py
gaur3009's picture
Update warp_design_on_dress.py
1964060 verified
raw
history blame
3.13 kB
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
from networks import GMM, TOM, load_checkpoint, Options
from preprocessing import pad_to_22_channels
def _load_rgb_tensor(path, tf, device):
    """Open the image at *path*, force RGB, apply *tf* and add a batch axis."""
    # Image.open keeps the file handle open lazily; the context manager
    # releases it as soon as convert() has produced an in-memory copy.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    return tf(rgb).unsqueeze(0).to(device)


def _edge_map(gray):
    """Return the absolute response of a 3x3 edge-detection kernel on *gray*.

    *gray* must be a 4-D tensor with a single channel (N, 1, H, W).
    """
    # F.conv2d requires a 4-D weight (out_channels, in_channels, kH, kW);
    # a bare (1, 3, 3) tensor is rejected at runtime, hence the view().
    kernel = torch.tensor(
        [[-1, -1, -1],
         [-1, 8, -1],
         [-1, -1, -1]],
        dtype=gray.dtype,
        device=gray.device,
    ).view(1, 1, 3, 3)
    # padding=1 keeps the spatial size identical to the input.
    return torch.abs(F.conv2d(gray, kernel, padding=1))


def _build_tom_input(dress_tensor, warped_design, warped_mask, n_dummy=17):
    """Concatenate the 26-channel input tensor that is fed to the TOM network."""
    gray = dress_tensor.mean(dim=1, keepdim=True)
    edges = _edge_map(gray)

    # Allocate the zero padding explicitly. Slicing zeros_like(dress_tensor)
    # can never return more than the 3 channels the dress tensor has.
    batch, _, height, width = dress_tensor.shape
    dummy = dress_tensor.new_zeros((batch, n_dummy, height, width))

    return torch.cat(
        [
            dress_tensor,   # 3 channels
            warped_design,  # 3 channels
            warped_mask,    # 1 channel
            gray,           # 1 channel
            edges,          # 1 channel
            dummy,          # n_dummy (17) zero channels
        ],
        dim=1,
    )  # Total: 3 + 3 + 1 + 1 + 1 + 17 = 26 channels


def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    """Warp a design image onto a dress image and save the composite.

    The dress and design images are resized to 256x192, normalized to
    [-1, 1], passed through the GMM network to obtain a sampling grid,
    warped with that grid, and then blended by the TOM network.

    Args:
        dress_path: Path of the dress image.
        design_path: Path of the design image to be warped.
        gmm_ckpt: Checkpoint passed to ``load_checkpoint`` for the GMM model.
        tom_ckpt: Checkpoint passed to ``load_checkpoint`` for the TOM model.
        output_dir: Directory for the result; created if it does not exist.

    Returns:
        The path of the written image, ``<output_dir>/tryon.jpg``.

    Note:
        The exact input/output contracts of ``GMM``, ``TOM``,
        ``load_checkpoint`` and ``pad_to_22_channels`` live in project
        modules not shown here; this function assumes GMM returns
        ``(grid, _)`` and TOM returns ``(p_rendered, m_composite)``.
    """
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Preprocessing: resize, convert to a [0, 1] tensor, then map to [-1, 1].
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Load both images as (1, 3, im_h, im_w) tensors on the target device.
    dress_tensor = _load_rgb_tensor(dress_path, tf, device)
    design_tensor = _load_rgb_tensor(design_path, tf, device)

    # The whole design counts as foreground: an all-ones single-channel mask.
    # ones_like already allocates on design_tensor's device.
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # Prepare agnostic input
    agnostic = pad_to_22_channels(dress_tensor).to(device)

    # GMM processing: predict the sampling grid and warp design + mask.
    opt = Options()
    gmm = GMM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.eval()

    with torch.no_grad():
        grid, _ = gmm(agnostic, design_mask)
        # Border padding repeats edge pixels of the design where the grid
        # samples outside it; the mask uses zeros there instead.
        warped_design = F.grid_sample(
            design_tensor,
            grid,
            padding_mode='border',
            align_corners=True,
        )
        warped_mask = F.grid_sample(
            design_mask,
            grid,
            padding_mode='zeros',
            align_corners=True,
        )

    # TOM processing: render and blend.
    tom = TOM(opt).to(device)
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.eval()

    with torch.no_grad():
        tom_input = _build_tom_input(dress_tensor, warped_design, warped_mask)

        # Generate try-on result
        p_rendered, m_composite = tom(tom_input)
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

    # Convert to a PIL image: drop only the batch axis, go CHW -> HWC and
    # map [-1, 1] back to [0, 255].
    tryon = tryon.squeeze(0).detach().cpu()
    tryon = (tryon.permute(1, 2, 0).numpy() + 1) * 127.5
    tryon = np.clip(tryon, 0, 255).astype("uint8")
    out_pil = Image.fromarray(tryon)

    # Save output
    output_path = os.path.join(output_dir, "tryon.jpg")
    out_pil.save(output_path)
    return output_path