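"""Warp a 2-D design onto a dress image with a two-stage virtual try-on
pipeline: a GMM (geometric matching module) predicts a warping grid, and a
TOM (try-on module) composites the warped design onto the dress."""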
import os

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from networks import GMM, TOM, load_checkpoint, Options
from preprocessing import pad_to_22_channels


def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Preprocessing: resize to the model's working resolution and
    # normalize pixel values to [-1, 1].
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Load the input images with error handling.
    try:
        dress_img = Image.open(dress_path).convert("RGB")
        design_img = Image.open(design_path).convert("RGB")
    except Exception as e:
        raise ValueError(f"Error loading images: {e}") from e
    # Select the device first so every tensor is created in one place;
    # otherwise the CPU-side grayscale map below would not match the
    # GPU-side convolution kernels.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert to batched tensors on the target device.
    dress_tensor = tf(dress_img).unsqueeze(0).to(device)
    design_tensor = tf(design_img).unsqueeze(0).to(device)
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # The "agnostic" representation here is simply the dress image itself.
    agnostic = dress_tensor.clone()

    opt = Options()
    # --- GMM: predict the warping grid ---
    gmm = GMM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.eval()

    # Pad the inputs to the 22 channels the GMM expects.
    agnostic_22ch = pad_to_22_channels(agnostic).contiguous().to(device)
    design_mask_22ch = pad_to_22_channels(design_mask).contiguous().to(device)
    with torch.no_grad():
        # Predict the sampling grid and warp the design onto the dress.
        grid, _ = gmm(agnostic_22ch, design_mask_22ch)
        warped_design = F.grid_sample(
            design_tensor, grid,
            padding_mode='border', align_corners=True,
        )
        warped_mask = F.grid_sample(
            design_mask, grid,
            padding_mode='zeros', align_corners=True,
        )
    # --- TOM: compose the final try-on image ---
    tom = TOM(opt).to(device)
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.eval()

    with torch.no_grad():
        # Build hand-crafted extra features: a grayscale map plus Sobel
        # edge maps. padding=1 keeps them at (im_h, im_w) so they can be
        # concatenated with the other inputs below.
        gray = agnostic.mean(dim=1, keepdim=True)
        sobel_x = torch.tensor([[[[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]]]], device=device)
        sobel_y = torch.tensor([[[[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]]]], device=device)
        edges_x = torch.abs(F.conv2d(gray, sobel_x, padding=1))
        edges_y = torch.abs(F.conv2d(gray, sobel_y, padding=1))
        # Concatenate all features: 3 + 3 + 1 + 1 + 1 + 1 + 16 = 26 channels.
        tom_input = torch.cat([
            agnostic,        # 3 channels
            warped_design,   # 3 channels
            warped_mask,     # 1 channel
            gray,            # 1 channel
            edges_x,         # 1 channel
            edges_y,         # 1 channel
            # 16 dummy channels (replace with real features if available)
            torch.zeros(agnostic.size(0), 16, im_h, im_w, device=device),
        ], dim=1)

        # Render and composite: blend the warped design into the rendered
        # person image using the predicted composition mask.
        p_rendered, m_composite = tom(tom_input)
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
    # Denormalize from [-1, 1] back to [0, 255] and save.
    tryon = tryon.clamp(-1, 1)
    out_img = (tryon.squeeze(0).permute(1, 2, 0).cpu().numpy() + 1) * 127.5
    out_img = out_img.clip(0, 255).astype("uint8")
    try:
        out_pil = Image.fromarray(out_img)
        output_path = os.path.join(output_dir, "tryon.jpg")
        out_pil.save(output_path)
        return output_path
    except Exception as e:
        raise ValueError(f"Error saving output image: {e}") from e