import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor

# Load depth estimation model (MiDaS 3.0 DPT-Hybrid; transformers exposes
# MiDaS checkpoints through its DPT classes)
model_name = "Intel/dpt-hybrid-midas"
processor = DPTImageProcessor.from_pretrained(model_name)
depth_model = DPTForDepthEstimation.from_pretrained(model_name)
depth_model.eval()
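
# Optional (sketch): run the depth model on GPU when available; if enabled,
# estimate_depth() must also move its inputs to the same device.
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   depth_model.to(device)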

def estimate_depth(image):
    """Estimate depth map from image using MiDaS v3."""
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = depth_model(**inputs)
        depth = outputs.predicted_depth.squeeze().cpu().numpy()
    depth = cv2.resize(depth, (image.width, image.height))
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255  # epsilon guards flat depth maps
    return depth.astype(np.uint8)
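
# Quick sanity check for the depth stage (illustrative sketch; "shirt.jpg"
# is a placeholder filename, not part of this repo):
#   depth = estimate_depth(Image.open("shirt.jpg"))
#   Image.fromarray(depth).save("depth_preview.png")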

def apply_tps_warping(design, depth):
    """Warp the design along fabric folds using depth-gradient displacement
    (a lightweight approximation of Thin Plate Spline warping)."""
    h, w = depth.shape
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
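    # Horizontal and vertical depth gradients act as the displacement field:
    # fold edges (strong gradients) push pixels the most.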
    displacement_x = cv2.Sobel(depth, cv2.CV_32F, 1, 0, ksize=5)
    displacement_y = cv2.Sobel(depth, cv2.CV_32F, 0, 1, ksize=5)
    displacement_x = cv2.normalize(displacement_x, None, -10, 10, cv2.NORM_MINMAX)
    displacement_y = cv2.normalize(displacement_y, None, -10, 10, cv2.NORM_MINMAX)
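    # Offset the identity sampling grid by the displacements; clipping keeps
    # every lookup inside the image for cv2.remap.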
    map_x = np.clip(grid_x + displacement_x, 0, w - 1).astype(np.float32)
    map_y = np.clip(grid_y + displacement_y, 0, h - 1).astype(np.float32)
    warped_design = cv2.remap(design, map_x, map_y, interpolation=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REFLECT)
    return warped_design
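
# Standalone sketch of the warp on synthetic arrays (illustrative only):
#   dummy_depth = np.tile(np.linspace(0, 255, 128).astype(np.uint8), (128, 1))
#   dummy_design = np.full((128, 128, 3), 255, dtype=np.uint8)
#   warped = apply_tps_warping(dummy_design, dummy_depth)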

def blend_design(cloth_img, design_img):
    """Blend design onto clothing naturally with fold adaptation using TPS warping."""
    cloth_img = cloth_img.convert("RGB")
    design_img = design_img.convert("RGBA")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img)
    
    # Resize design to ~40% of the garment's limiting dimension
    h, w, _ = cloth_np.shape
    dh, dw, _ = design_np.shape
    scale_factor = min(w / dw, h / dh) * 0.4
    new_w, new_h = int(dw * scale_factor), int(dh * scale_factor)
    design_np = cv2.resize(design_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
    
    # Extract alpha channel (float32 so it can be warped alongside the design)
    alpha_channel = design_np[:, :, 3].astype(np.float32) / 255.0
    design_np = design_np[:, :, :3]
    
    # Place design and alpha on full-size canvases (centered, ~35% from the
    # top) so both can be warped and blended against the whole garment image
    x_offset = (w - new_w) // 2
    y_offset = int(h * 0.35)
    design_canvas = np.zeros_like(cloth_np)
    design_canvas[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = design_np
    alpha_canvas = np.zeros((h, w), dtype=np.float32)
    alpha_canvas[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = alpha_channel
    
    # Estimate depth and warp design and alpha with the same displacement field
    depth_map = estimate_depth(cloth_img)
    warped_design = apply_tps_warping(design_canvas, depth_map)
    warped_alpha = np.clip(apply_tps_warping(alpha_canvas, depth_map), 0, 1)
    
    # Alpha-composite the warped design onto the cloth
    for c in range(3):
        cloth_np[:, :, c] = (cloth_np[:, :, c] * (1 - warped_alpha)
                             + warped_design[:, :, c] * warped_alpha).astype(np.uint8)
    
    return Image.fromarray(cloth_np)

def main(cloth, design):
    return blend_design(cloth, design)

iface = gr.Interface(
    fn=main,
    inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
    outputs=gr.Image(type="pil"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design to blend it naturally, ensuring it stays centered and follows fabric folds."
)

if __name__ == "__main__":
    iface.launch(share=True)
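
# Headless usage sketch (bypasses Gradio; filenames are illustrative):
#   from PIL import Image
#   out = blend_design(Image.open("shirt.jpg"), Image.open("logo.png"))
#   out.save("shirt_with_design.png")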