import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from networks import GMM, TOM, load_checkpoint, Options  # Updated imports
from preprocessing import pad_to_22_channels

def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
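    """Warp a design image onto a dress image and composite a try-on result.

    The design is warped onto the dress (agnostic) image with the GMM network,
    then blended with the TOM network. The result is saved as
    <output_dir>/tryon.jpg and its path is returned.
    """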
    os.makedirs(output_dir, exist_ok=True)

    # Preprocessing with enhanced normalization
    im_h, im_w = 256, 192
    tf = transforms.Compose([
        transforms.Resize((im_h, im_w)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load and prepare images with error handling
    try:
        dress_img = Image.open(dress_path).convert("RGB")
        design_img = Image.open(design_path).convert("RGB")
    except Exception as e:
        raise ValueError(f"Error loading images: {e}") from e

    # Convert to tensors
    dress_tensor = tf(dress_img).unsqueeze(0).cpu()
    design_tensor = tf(design_img).unsqueeze(0).cpu()
    design_mask = torch.ones_like(design_tensor[:, :1, :, :])

    # Prepare agnostic (dress image)
    agnostic = dress_tensor.clone()
    
    # Initialize models with proper device handling
    opt = Options()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # GMM Processing
    gmm = GMM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt, strict=False)
    gmm.eval()

    # Convert to required channels and move everything to the target device
    agnostic = agnostic.to(device)
    agnostic_22ch = pad_to_22_channels(agnostic).contiguous().to(device)
    design_mask_22ch = pad_to_22_channels(design_mask).contiguous().to(device)
    design_tensor = design_tensor.to(device)
    design_mask = design_mask.to(device)

    with torch.no_grad():
        # Process through GMM with align_corners
        grid, _ = gmm(agnostic_22ch, design_mask_22ch)
        warped_design = F.grid_sample(
            design_tensor, 
            grid, 
            padding_mode='border',
            align_corners=True
        )
        warped_mask = F.grid_sample(
            design_mask,
            grid,
            padding_mode='zeros',
            align_corners=True
        )

    # TOM Processing
    tom = TOM(opt).to(device)  # Using the new TOM class
    load_checkpoint(tom, tom_ckpt, strict=False)
    tom.eval()

    with torch.no_grad():
        # Prepare proper 26-channel input
        # Generate additional features (replace with actual feature extraction if available)
        gray = agnostic.mean(dim=1, keepdim=True)
        sobel_x = torch.tensor([[[[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]]]], device=device)
        sobel_y = torch.tensor([[[[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]]]], device=device)
        # padding=1 keeps the 256x192 spatial size so the features can be concatenated below
        edges_x = torch.abs(F.conv2d(gray, sobel_x, padding=1))
        edges_y = torch.abs(F.conv2d(gray, sobel_y, padding=1))
        
        # Combine all features: 3 + 3 + 1 + 1 + 1 + 1 + 16 = 26 channels
        dummy = torch.zeros(agnostic.size(0), 16, im_h, im_w, device=device)
        tom_input = torch.cat([
            agnostic,          # 3 channels
            warped_design,     # 3 channels
            warped_mask,       # 1 channel
            gray,              # 1 channel
            edges_x,           # 1 channel
            edges_y,           # 1 channel
            dummy              # 16 placeholder channels (replace with real features if available)
        ], dim=1)

        # Process through TOM
        p_rendered, m_composite = tom(tom_input)
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

        # Save output with proper denormalization
        tryon = tryon.clamp(-1, 1)  # Ensure valid range
        out_img = (tryon.squeeze().permute(1, 2, 0).cpu().numpy() + 1) * 127.5
        out_img = out_img.clip(0, 255).astype("uint8")
        
        try:
            out_pil = Image.fromarray(out_img)
            output_path = os.path.join(output_dir, "tryon.jpg")
            out_pil.save(output_path)
            return output_path
        except Exception as e:
            raise ValueError(f"Error saving output image: {e}") from e
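

# Example invocation (illustrative only): the image paths and checkpoint files
# below are placeholders, not files shipped with this script; replace them with
# real inputs for your setup.
if __name__ == "__main__":
    result_path = run_design_warp_on_dress(
        dress_path="data/dress.jpg",           # placeholder dress image
        design_path="data/design.jpg",         # placeholder design image
        gmm_ckpt="checkpoints/gmm_final.pth",  # placeholder GMM checkpoint
        tom_ckpt="checkpoints/tom_final.pth",  # placeholder TOM checkpoint
        output_dir="output"
    )
    print(f"Saved try-on image to {result_path}")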