import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

from networks import GMM, TOM, load_checkpoint, Options  # Updated imports
from preprocessing import pad_to_22_channels

def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Preprocessing: resize to the model resolution and normalize to [-1, 1]
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Load and prepare images with error handling
    try:
        dress_img = Image.open(dress_path).convert("RGB")
        design_img = Image.open(design_path).convert("RGB")
    except Exception as e:
        raise ValueError(f"Error loading images: {e}")

    # Convert to batched tensors; the design mask is all ones since no mask image is provided
    dress_tensor = tf(dress_img).unsqueeze(0)
    design_tensor = tf(design_img).unsqueeze(0)
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # Prepare the agnostic representation (here simply the dress image itself)
    agnostic = dress_tensor.clone()
    # Initialize models with proper device handling
    opt = Options()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # GMM processing
    gmm = GMM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.eval()

    # Pad to the 22 channels the GMM expects, then move all inputs to the device
    agnostic_22ch = pad_to_22_channels(agnostic).contiguous().to(device)
    design_mask_22ch = pad_to_22_channels(design_mask).contiguous().to(device)
    agnostic = agnostic.to(device)
    design_tensor = design_tensor.to(device)
    design_mask = design_mask.to(device)
    with torch.no_grad():
        # Predict the warping grid with the GMM and sample the design with it
        grid, _ = gmm(agnostic_22ch, design_mask_22ch)
        warped_design = F.grid_sample(
            design_tensor,
            grid,
            padding_mode='border',
            align_corners=True,
        )
        warped_mask = F.grid_sample(
            design_mask,
            grid,
            padding_mode='zeros',
            align_corners=True,
        )
    # TOM processing
    tom = TOM(opt).to(device)  # Using the new TOM class
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.eval()
    with torch.no_grad():
        # Generate simple hand-crafted features (replace with real feature extraction if available)
        gray = agnostic.mean(dim=1, keepdim=True)
        sobel_x = torch.tensor([[[[1, 0, -1], [2, 0, -2], [1, 0, -1]]]],
                               dtype=torch.float32, device=device)
        sobel_y = torch.tensor([[[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]],
                               dtype=torch.float32, device=device)
        edges_x = torch.abs(F.conv2d(gray, sobel_x, padding=1))  # padding=1 keeps H x W unchanged
        edges_y = torch.abs(F.conv2d(gray, sobel_y, padding=1))

        # Assemble the 26-channel TOM input: 3 + 3 + 1 + 1 + 1 + 1 + 16 = 26
        dummy = torch.zeros(agnostic.size(0), 16, im_h, im_w, device=device)
        tom_input = torch.cat([
            agnostic,        # 3 channels
            warped_design,   # 3 channels
            warped_mask,     # 1 channel
            gray,            # 1 channel
            edges_x,         # 1 channel
            edges_y,         # 1 channel
            dummy,           # 16 dummy channels (replace with real features)
        ], dim=1)
        # Process through TOM and composite the rendered person with the warped design
        p_rendered, m_composite = tom(tom_input)
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
    # Denormalize from [-1, 1] back to [0, 255] and save the result
    tryon = tryon.clamp(-1, 1)
    out_img = (tryon.squeeze(0).permute(1, 2, 0).cpu().numpy() + 1) * 127.5
    out_img = out_img.clip(0, 255).astype("uint8")
    try:
        out_pil = Image.fromarray(out_img)
        output_path = os.path.join(output_dir, "tryon.jpg")
        out_pil.save(output_path)
        return output_path
    except Exception as e:
        raise ValueError(f"Error saving output image: {e}")