# Design_warper / warp_design_on_dress.py
# Author: gaur3009 — last commit a4a6754 ("Update warp_design_on_dress.py", verified), 3.96 kB.
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from networks import GMM, TOM, load_checkpoint, Options # Updated imports
from preprocessing import pad_to_22_channels
def _open_rgb(path):
    """Open *path* with Pillow and return a fully loaded RGB copy, closing the file."""
    # convert() loads the pixel data into a new image, so the underlying file
    # handle can be released as soon as the `with` block exits.
    with Image.open(path) as img:
        return img.convert("RGB")


def _sobel_edges(gray):
    """Return the absolute Sobel x / y responses of a single-channel (N, 1, H, W) batch.

    ``padding=1`` keeps the responses at H x W so they can be concatenated with
    the other full-resolution features. Without it a 3x3 kernel shrinks each
    spatial dimension by 2.
    """
    # Kernels are created on the input's device/dtype so the conv never mixes devices.
    kx = torch.tensor(
        [[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=gray.dtype, device=gray.device
    ).view(1, 1, 3, 3)
    ky = torch.tensor(
        [[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=gray.dtype, device=gray.device
    ).view(1, 1, 3, 3)
    edges_x = F.conv2d(gray, kx, padding=1).abs()
    edges_y = F.conv2d(gray, ky, padding=1).abs()
    return edges_x, edges_y


def _to_uint8_image(img):
    """Map a (3, H, W) tensor normalized to [-1, 1] to an (H, W, 3) uint8 array."""
    arr = (img.clamp(-1, 1).permute(1, 2, 0).cpu().numpy() + 1) * 127.5
    return arr.clip(0, 255).astype("uint8")


def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir) -> str:
    """Warp a design image onto a dress image with a GMM + TOM pipeline and save the result.

    Args:
        dress_path: Path of the dress image. It is also used as the "agnostic" input.
        design_path: Path of the design image to warp onto the dress.
        gmm_ckpt: Checkpoint passed to ``load_checkpoint`` for the GMM (non-strict).
        tom_ckpt: Checkpoint passed to ``load_checkpoint`` for the TOM (non-strict).
        output_dir: Directory that receives ``tryon.jpg``. Created if missing.

    Returns:
        The path of the written ``tryon.jpg``.

    Raises:
        ValueError: If either input image cannot be loaded, or if the output
            image cannot be saved. The original exception is chained as the cause.
    """
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resize to a fixed 256x192 and normalize every RGB channel to [-1, 1].
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    try:
        dress_img = _open_rgb(dress_path)
        design_img = _open_rgb(design_path)
    except Exception as e:
        raise ValueError(f"Error loading images: {str(e)}") from e

    # Batch of one: (1, 3, H, W).
    agnostic = tf(dress_img).unsqueeze(0)
    design_tensor = tf(design_img).unsqueeze(0)
    # All-ones mask: the whole design image is treated as foreground.
    design_mask = torch.ones_like(design_tensor[:, :1])

    # pad_to_22_channels is a project helper. It is fed CPU tensors (as before)
    # and its output is moved to `device` afterwards.
    # NOTE(review): presumably it pads the channel dim to the 22 channels GMM expects — confirm.
    agnostic_22ch = pad_to_22_channels(agnostic).contiguous().to(device)
    design_mask_22ch = pad_to_22_channels(design_mask).contiguous().to(device)

    # From here on every tensor lives on `device`. Leaving `agnostic` on the CPU
    # breaks the Sobel convolutions and torch.cat on CUDA machines.
    agnostic = agnostic.to(device)
    design_tensor = design_tensor.to(device)
    design_mask = design_mask.to(device)

    opt = Options()

    # --- Geometric matching: predict a sampling grid and warp design + mask. ---
    gmm = GMM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.eval()
    with torch.no_grad():
        # NOTE(review): `grid` is assumed to be an (N, H, W, 2) grid in [-1, 1], as F.grid_sample expects.
        grid, _ = gmm(agnostic_22ch, design_mask_22ch)
        warped_design = F.grid_sample(
            design_tensor, grid, padding_mode="border", align_corners=True
        )
        # Zero padding so areas sampled outside the design are masked out.
        warped_mask = F.grid_sample(
            design_mask, grid, padding_mode="zeros", align_corners=True
        )

    # --- Try-on module: render and composite the warped design. ---
    tom = TOM(opt).to(device)
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.eval()
    tom_in_channels = 26  # channel count this pipeline feeds to TOM
    with torch.no_grad():
        # Hand-made auxiliary features (replace with real feature extraction if available).
        gray = agnostic.mean(dim=1, keepdim=True)
        edges_x, edges_y = _sobel_edges(gray)
        # 3 (agnostic) + 3 (warped design) + 1 (warped mask) + 1 (gray) + 1 + 1 (edges) = 10.
        features = torch.cat(
            [agnostic, warped_design, warped_mask, gray, edges_x, edges_y], dim=1
        )
        # Zero-fill the remaining 16 placeholder channels so TOM receives 26.
        # A slice of zeros_like(agnostic) would give only 3 channels.
        n, c, h, w = features.shape
        filler = features.new_zeros(n, tom_in_channels - c, h, w)
        tom_input = torch.cat([features, filler], dim=1)

        p_rendered, m_composite = tom(tom_input)
        # NOTE(review): m_composite is presumably a [0, 1] blend weight — confirm against TOM.
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

    # Denormalize [-1, 1] -> [0, 255] for the single image in the batch.
    out_img = _to_uint8_image(tryon[0])
    output_path = os.path.join(output_dir, "tryon.jpg")
    try:
        Image.fromarray(out_img).save(output_path)
    except Exception as e:
        raise ValueError(f"Error saving output image: {str(e)}") from e
    return output_path