import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
from networks import GMM, TOM, load_checkpoint, Options
from preprocessing import pad_to_22_channels

def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
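    """Warp a design image onto a dress photo and composite a try-on result.

    A two-stage, CP-VTON-style pipeline: the Geometric Matching Module (GMM)
    predicts a sampling grid that warps the design onto the dress, and the
    Try-On Module (TOM) blends the warped design with the dress image.
    Returns the path of the saved try-on image.
    """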
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Preprocessing
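    # (Normalize maps pixel values to [-1, 1]; the inverse, (x + 1) * 127.5,
    # is applied before saving the output)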
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load images
    dress_img = Image.open(dress_path).convert("RGB")
    design_img = Image.open(design_path).convert("RGB")
    
    # Convert to tensors
    dress_tensor = tf(dress_img).unsqueeze(0).to(device)
    design_tensor = tf(design_img).unsqueeze(0).to(device)
    
    # Create an all-ones design mask (ones_like already matches the device)
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # Pad the 3-channel dress tensor out to the 22-channel person-agnostic
    # representation the GMM expects
    agnostic = pad_to_22_channels(dress_tensor).to(device)

    # GMM Processing
    opt = Options()
    gmm = GMM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.eval()

    with torch.no_grad():
        grid, _ = gmm(agnostic, design_mask)
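        # Border padding extends the design's edge pixels; zeros padding
        # treats out-of-grid mask samples as background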
        warped_design = F.grid_sample(
            design_tensor, 
            grid, 
            padding_mode='border',
            align_corners=True
        )
        warped_mask = F.grid_sample(
            design_mask,
            grid,
            padding_mode='zeros',
            align_corners=True
        )

    # TOM Processing
    tom = TOM(opt).to(device)
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.eval()

    with torch.no_grad():
        # Create simplified feature inputs
        gray = dress_tensor.mean(dim=1, keepdim=True)
        
        # 3x3 Laplacian-style edge-detection kernel; F.conv2d requires a 4D
        # (out_channels, in_channels, kH, kW) weight, hence shape (1, 1, 3, 3)
        kernel = torch.tensor(
            [[[[-1, -1, -1],
               [-1,  8, -1],
               [-1, -1, -1]]]], dtype=torch.float32, device=device)
        
        # Calculate edges
        edges = torch.abs(F.conv2d(gray, kernel, padding=1))
        
        # Combine inputs
        tom_input = torch.cat([
            dress_tensor,          # 3 channels
            warped_design,         # 3 channels
            warped_mask,           # 1 channel
            gray,                  # 1 channel
            edges,                 # 1 channel
            dress_tensor.new_zeros(
                dress_tensor.size(0), 17, im_h, im_w)  # 17 dummy channels
        ], dim=1)  # Total: 3+3+1+1+1+17 = 26 channels

        # Generate try-on result
        p_rendered, m_composite = tom(tom_input)
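        # Alpha-composite: m_composite keeps warped-design pixels,
        # p_rendered fills in the rest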
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

    # Convert to PIL image
    tryon = tryon.squeeze().detach().cpu()
    tryon = (tryon.permute(1, 2, 0).numpy() + 1) * 127.5
    tryon = np.clip(tryon, 0, 255).astype("uint8")
    out_pil = Image.fromarray(tryon)
    
    # Save output
    output_path = os.path.join(output_dir, "tryon.jpg")
    out_pil.save(output_path)
    return output_path
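
# Example invocation (a minimal sketch; the image and checkpoint paths below
# are hypothetical placeholders, not files shipped with this repo):
if __name__ == "__main__":
    result = run_design_warp_on_dress(
        dress_path="data/dress.jpg",               # hypothetical dress photo
        design_path="data/design.png",             # hypothetical design image
        gmm_ckpt="checkpoints/gmm_final.pth",      # hypothetical GMM weights
        tom_ckpt="checkpoints/tom_final.pth",      # hypothetical TOM weights
        output_dir="output",
    )
    print(f"Saved try-on result to {result}")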