Spaces:
Sleeping
Sleeping
Update warp_design_on_dress.py
Browse files- warp_design_on_dress.py +11 -6
warp_design_on_dress.py
CHANGED
@@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|
4 |
from torchvision import transforms
|
5 |
from PIL import Image
|
6 |
from networks import GMM, UnetGenerator, load_checkpoint, Options
|
7 |
-
|
8 |
|
9 |
def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
|
10 |
os.makedirs(output_dir, exist_ok=True)
|
@@ -13,7 +13,8 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
|
|
13 |
im_h, im_w = 256, 192
|
14 |
tf = transforms.Compose([
|
15 |
transforms.Resize((im_h, im_w)),
|
16 |
-
transforms.ToTensor()
|
|
|
17 |
])
|
18 |
|
19 |
dress_img = Image.open(dress_path).convert("RGB")
|
@@ -28,11 +29,15 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
|
|
28 |
|
29 |
opt = Options()
|
30 |
gmm = GMM(opt)
|
31 |
-
load_checkpoint(gmm, gmm_ckpt, strict
|
32 |
gmm.cpu().eval()
|
33 |
|
|
|
|
|
|
|
|
|
34 |
with torch.no_grad():
|
35 |
-
grid, _ = gmm(
|
36 |
warped_design = F.grid_sample(design_tensor, grid, padding_mode='border')
|
37 |
warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros')
|
38 |
|
@@ -52,8 +57,8 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
|
|
52 |
tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
|
53 |
|
54 |
# Save output
|
55 |
-
out_img = tryon.squeeze().permute(1, 2, 0).cpu().numpy()
|
56 |
-
out_img = (
|
57 |
out_pil = Image.fromarray(out_img)
|
58 |
|
59 |
output_path = os.path.join(output_dir, "tryon.jpg")
|
|
|
4 |
from torchvision import transforms
|
5 |
from PIL import Image
|
6 |
from networks import GMM, UnetGenerator, load_checkpoint, Options
|
7 |
+
from preprocessing import pad_to_22_channels
|
8 |
|
9 |
def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
|
10 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
13 |
im_h, im_w = 256, 192
|
14 |
tf = transforms.Compose([
|
15 |
transforms.Resize((im_h, im_w)),
|
16 |
+
transforms.ToTensor(),
|
17 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Added normalization
|
18 |
])
|
19 |
|
20 |
dress_img = Image.open(dress_path).convert("RGB")
|
|
|
29 |
|
30 |
opt = Options()
|
31 |
gmm = GMM(opt)
|
32 |
+
load_checkpoint(gmm, gmm_ckpt, strict=False)
|
33 |
gmm.cpu().eval()
|
34 |
|
35 |
+
# Convert agnostic to 22 channels before passing to GMM
|
36 |
+
agnostic_22ch = pad_to_22_channels(agnostic)
|
37 |
+
design_mask_22ch = pad_to_22_channels(design_mask)
|
38 |
+
|
39 |
with torch.no_grad():
|
40 |
+
grid, _ = gmm(agnostic_22ch, design_mask_22ch) # Use padded inputs
|
41 |
warped_design = F.grid_sample(design_tensor, grid, padding_mode='border')
|
42 |
warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros')
|
43 |
|
|
|
57 |
tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
|
58 |
|
59 |
# Save output
|
60 |
+
out_img = (tryon.squeeze().permute(1, 2, 0).cpu().numpy() + 1) * 127.5 # Denormalize
|
61 |
+
out_img = out_img.clip(0, 255).astype("uint8")
|
62 |
out_pil = Image.fromarray(out_img)
|
63 |
|
64 |
output_path = os.path.join(output_dir, "tryon.jpg")
|