File size: 3,343 Bytes
215921b
 
93827d9
 
215921b
93827d9
41c7336
f62de59
41c7336
 
 
 
93827d9
 
 
41c7336
93827d9
41c7336
93827d9
41c7336
 
 
 
93827d9
41c7336
 
f62de59
41c7336
 
 
93827d9
215921b
80364f7
215921b
 
 
 
 
93827d9
215921b
93827d9
 
f62de59
 
 
93827d9
 
f62de59
93827d9
 
 
 
80364f7
f62de59
0be295c
 
93827d9
80364f7
f62de59
41c7336
93827d9
 
d004f7c
f62de59
0be295c
 
 
f62de59
215921b
0be295c
215921b
f62de59
215921b
 
 
80364f7
 
215921b
 
 
 
 
 
80364f7
 
215921b
 
 
93827d9
215921b
 
f62de59
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import numpy as np
from PIL import Image
import torch
import cv2
from torchvision import transforms
from transformers import AutoModel, AutoProcessor  # Import Hugging Face model

# Hugging Face Hub id of the pre-trained U2-Net saliency checkpoint used to
# segment the garment.
model_name = "netradrishti/u2net-saliency"

# trust_remote_code=True lets transformers execute modelling code stored in the
# checkpoint repository (NOTE(review): only enable this for trusted repos).
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()  # switch to inference mode (affects dropout / batch-norm layers)

# Pre-processing pipeline that turns a PIL image into model-ready tensors.
processor = AutoProcessor.from_pretrained(model_name)

def saliency_to_mask(output, height, width, threshold=0.5):
    """Convert a raw segmentation output tensor into a binary mask.

    Args:
        output: Model output tensor. Batch item 0, channel 0 is used, so it is
            expected to be laid out as (batch, channel, H, W).
        height: Target mask height in pixels.
        width: Target mask width in pixels.
        threshold: Values strictly greater than this become 1.

    Returns:
        A ``(height, width)`` ``np.uint8`` array of 0/1 values.

    Raises:
        ValueError: If the selected map is not two-dimensional.
    """
    saliency = output[0, 0]
    if saliency.ndim > 2:
        # Drop stray singleton axes, but never collapse a real 1-pixel side.
        saliency = saliency.squeeze()
    if saliency.ndim != 2:
        raise ValueError(
            f"expected a 2-D saliency map, got shape {tuple(saliency.shape)}"
        )
    saliency = saliency.detach().float()
    if tuple(saliency.shape) != (height, width):
        # The prediction may be on the model's working grid (the processor can
        # resize its input -- TODO confirm target size). Upsample it so mask
        # pixels line up with the caller's image pixels.
        saliency = torch.nn.functional.interpolate(
            saliency[None, None],
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )[0, 0]
    return (saliency.cpu().numpy() > threshold).astype(np.uint8)


def get_dress_mask(image_np, threshold=0.5):
    """Segment the garment in an image with the Hugging Face U2-Net model.

    Args:
        image_np: Image as a NumPy array accepted by ``PIL.Image.fromarray``
            (e.g. H x W x 3 RGB or H x W x 4 RGBA); it is converted to RGB.
        threshold: Cut-off applied to the model output. NOTE(review): the
            value is compared against ``.logits`` directly with no sigmoid;
            confirm whether the remote model already returns probabilities.

    Returns:
        An ``np.uint8`` 0/1 mask with the same height and width as
        ``image_np``, so its coordinates can be used directly on the image.
    """
    image = Image.fromarray(image_np).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    height, width = image_np.shape[:2]
    return saliency_to_mask(logits, height, width, threshold)

def detect_dress_and_warp(hoodie_img, design_img, warp_strength, scale):
    """Overlay a TPS-warped design onto the garment detected in a hoodie image.

    Args:
        hoodie_img: Path to the hoodie mockup image (Gradio ``type="filepath"``).
        design_img: Path to the design image to place on the garment.
        warp_strength: Forwarded unchanged to ``tps_warp.apply_tps_warp``.
        scale: Multiplier applied to the garment bounding-box size to obtain
            the design size before warping.

    Returns:
        An RGBA ``PIL.Image`` with the design composited over the hoodie, the
        unmodified RGBA hoodie if no garment pixels are found, or ``None`` if
        either input path is missing.
    """
    if hoodie_img is None or design_img is None:
        return None

    # Context managers close the underlying files; convert() returns a loaded copy.
    with Image.open(hoodie_img) as img:
        hoodie = img.convert("RGBA")
    with Image.open(design_img) as img:
        design = img.convert("RGBA")
    hoodie_np = np.array(hoodie)

    dress_mask = get_dress_mask(hoodie_np)
    if dress_mask is None:
        return hoodie  # No mask, return original

    # Make sure the mask sits on the hoodie's pixel grid; otherwise the bounding
    # box below would be in the segmentation model's coordinates.
    img_h, img_w = hoodie_np.shape[:2]
    if dress_mask.shape[:2] != (img_h, img_w):
        dress_mask = cv2.resize(
            (dress_mask > 0).astype(np.uint8),
            (img_w, img_h),
            interpolation=cv2.INTER_NEAREST,
        )

    # Bounding box of the mask.
    ys, xs = np.where(dress_mask > 0)
    if xs.size == 0:
        return hoodie  # Empty mask, return original

    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())
    center_x, center_y = (x_min + x_max) // 2, (y_min + y_max) // 2
    # +1 because min/max are inclusive pixel indices.
    box_w, box_h = x_max - x_min + 1, y_max - y_min + 1

    # Scale the design to the bounding box; clamp to 1px so PIL never
    # receives a zero-sized target.
    new_w = max(1, int(box_w * scale))
    new_h = max(1, int(box_h * scale))
    design_np = np.array(design.resize((new_w, new_h)))

    # Warp the design using TPS (project-local helper).
    from tps_warp import apply_tps_warp
    warped_design = apply_tps_warp(design_np, dress_mask, warp_strength)
    # NOTE(review): assumes apply_tps_warp returns an H x W x {3,4} uint8 array.
    # Forcing RGBA guarantees an alpha channel for use as the paste mask.
    design = Image.fromarray(warped_design).convert("RGBA")

    # Centre the warped design (whose size may differ from new_w x new_h)
    # on the bounding-box centre.
    paste_x = center_x - design.width // 2
    paste_y = center_y - design.height // 2

    # Paste onto a transparent layer, then alpha-composite over the hoodie.
    temp_layer = Image.new("RGBA", hoodie.size, (0, 0, 0, 0))
    temp_layer.paste(design, (paste_x, paste_y), design)
    return Image.alpha_composite(hoodie, temp_layer)

def generate_mockup(hoodie, design, warp_strength=5.0, scale=1.0):
    """Gradio callback: build a mockup from the uploaded hoodie and design.

    Thin adapter that forwards every argument unchanged to
    ``detect_dress_and_warp`` and returns whatever it produces.
    """
    mockup = detect_dress_and_warp(
        hoodie_img=hoodie,
        design_img=design,
        warp_strength=warp_strength,
        scale=scale,
    )
    return mockup

# Gradio front-end. The four input components map positionally onto
# generate_mockup(hoodie, design, warp_strength, scale).
demo = gr.Interface(
    fn=generate_mockup,
    inputs=[
        gr.Image(label="Upload Hoodie Mockup", type="filepath"),
        gr.Image(label="Upload Your Design", type="filepath"),
        gr.Slider(minimum=0.0, maximum=20.0, value=5.0, label="Warp Strength"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Design Scale"),
    ],
    outputs=gr.Image(label="Generated Mockup", type="pil"),
    title="Rookus Hoodie Mockup Generator",
    description=(
        "Upload a hoodie mockup and your design to generate a realistic preview "
        "with automatic dress detection and warping using TPS transformation."
    ),
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()