gaur3009 commited on
Commit
ddec91c
·
verified ·
1 Parent(s): 580136d

Update warp_design_on_dress.py

Browse files
Files changed (1) hide show
  1. warp_design_on_dress.py +11 -6
warp_design_on_dress.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
  from networks import GMM, UnetGenerator, load_checkpoint, Options
7
-
8
 
9
  def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
10
  os.makedirs(output_dir, exist_ok=True)
@@ -13,7 +13,8 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
13
  im_h, im_w = 256, 192
14
  tf = transforms.Compose([
15
  transforms.Resize((im_h, im_w)),
16
- transforms.ToTensor()
 
17
  ])
18
 
19
  dress_img = Image.open(dress_path).convert("RGB")
@@ -28,11 +29,15 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
28
 
29
  opt = Options()
30
  gmm = GMM(opt)
31
- load_checkpoint(gmm, gmm_ckpt, strict = False)
32
  gmm.cpu().eval()
33
 
 
 
 
 
34
  with torch.no_grad():
35
- grid, _ = gmm(agnostic, design_mask)
36
  warped_design = F.grid_sample(design_tensor, grid, padding_mode='border')
37
  warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros')
38
 
@@ -52,8 +57,8 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
52
  tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
53
 
54
  # Save output
55
- out_img = tryon.squeeze().permute(1, 2, 0).cpu().numpy()
56
- out_img = (out_img * 255).astype("uint8")
57
  out_pil = Image.fromarray(out_img)
58
 
59
  output_path = os.path.join(output_dir, "tryon.jpg")
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  from networks import GMM, UnetGenerator, load_checkpoint, Options
7
+ from preprocessing import pad_to_22_channels
8
 
9
  def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
10
  os.makedirs(output_dir, exist_ok=True)
 
13
  im_h, im_w = 256, 192
14
  tf = transforms.Compose([
15
  transforms.Resize((im_h, im_w)),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Added normalization
18
  ])
19
 
20
  dress_img = Image.open(dress_path).convert("RGB")
 
29
 
30
  opt = Options()
31
  gmm = GMM(opt)
32
+ load_checkpoint(gmm, gmm_ckpt, strict=False)
33
  gmm.cpu().eval()
34
 
35
+ # Convert agnostic to 22 channels before passing to GMM
36
+ agnostic_22ch = pad_to_22_channels(agnostic)
37
+ design_mask_22ch = pad_to_22_channels(design_mask)
38
+
39
  with torch.no_grad():
40
+ grid, _ = gmm(agnostic_22ch, design_mask_22ch) # Use padded inputs
41
  warped_design = F.grid_sample(design_tensor, grid, padding_mode='border')
42
  warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros')
43
 
 
57
  tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
58
 
59
  # Save output
60
+ out_img = (tryon.squeeze().permute(1, 2, 0).cpu().numpy() + 1) * 127.5 # Denormalize
61
+ out_img = out_img.clip(0, 255).astype("uint8")
62
  out_pil = Image.fromarray(out_img)
63
 
64
  output_path = os.path.join(output_dir, "tryon.jpg")