import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from networks import GMM, TOM, load_checkpoint, Options
from preprocessing import pad_to_22_channels


def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Preprocessing: resize to the 256x192 working resolution and normalize RGB to [-1, 1]
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load images
    dress_img = Image.open(dress_path).convert("RGB")
    design_img = Image.open(design_path).convert("RGB")

    # Convert to batched tensors on the target device
    dress_tensor = tf(dress_img).unsqueeze(0).to(device)
    design_tensor = tf(design_img).unsqueeze(0).to(device)

    # Full-coverage, single-channel design mask (ones_like already inherits the device)
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # Pad the dress tensor to the 22-channel agnostic representation the GMM expects
    agnostic = pad_to_22_channels(dress_tensor).to(device)
    # GMM: estimate the warping grid, then warp the design and its mask with it
    opt = Options()
    gmm = GMM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.eval()
    with torch.no_grad():
        grid, _ = gmm(agnostic, design_mask)
        # Warp the design with border padding so edge pixels extend outward
        warped_design = F.grid_sample(
            design_tensor,
            grid,
            padding_mode='border',
            align_corners=True
        )
        # Warp the mask with zero padding so out-of-bounds samples stay masked out
        warped_mask = F.grid_sample(
            design_mask,
            grid,
            padding_mode='zeros',
            align_corners=True
        )
    # TOM: blend the warped design into the dress image
    tom = TOM(opt).to(device)
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.eval()
    with torch.no_grad():
        # Simplified feature inputs: a grayscale channel plus Laplacian-style edges
        gray = dress_tensor.mean(dim=1, keepdim=True)
        # Edge-detection kernel; F.conv2d needs a 4D weight of shape (out_ch, in_ch, kH, kW)
        kernel = torch.tensor(
            [[[[-1, -1, -1],
               [-1, 8, -1],
               [-1, -1, -1]]]], dtype=torch.float32, device=device)
        # Calculate edges
        edges = torch.abs(F.conv2d(gray, kernel, padding=1))
        # Combine inputs. Note: zeros_like(dress_tensor)[:, :17] would yield only
        # 3 channels, so the 17 dummy channels are allocated explicitly.
        tom_input = torch.cat([
            dress_tensor,   # 3 channels
            warped_design,  # 3 channels
            warped_mask,    # 1 channel
            gray,           # 1 channel
            edges,          # 1 channel
            dress_tensor.new_zeros((dress_tensor.size(0), 17, im_h, im_w))  # 17 dummy channels
        ], dim=1)  # Total: 3+3+1+1+1+17 = 26 channels

        # Generate the try-on result: composite the warped design over the rendered person
        p_rendered, m_composite = tom(tom_input)
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
    # Convert the [-1, 1] tensor back to an 8-bit PIL image
    tryon = tryon.squeeze().detach().cpu()
    tryon = (tryon.permute(1, 2, 0).numpy() + 1) * 127.5
    tryon = np.clip(tryon, 0, 255).astype("uint8")
    out_pil = Image.fromarray(tryon)

    # Save output
    output_path = os.path.join(output_dir, "tryon.jpg")
    out_pil.save(output_path)
    return output_path
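

# Example invocation (sketch): the image and checkpoint paths below are
# placeholders, not files known to ship with this repo; point them at your
# own assets before running.
if __name__ == "__main__":
    result = run_design_warp_on_dress(
        dress_path="data/dress.jpg",           # hypothetical dress photo
        design_path="data/design.png",         # hypothetical design image
        gmm_ckpt="checkpoints/gmm_final.pth",  # hypothetical GMM checkpoint
        tom_ckpt="checkpoints/tom_final.pth",  # hypothetical TOM checkpoint
        output_dir="output",
    )
    print(f"Saved try-on image to {result}")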