File size: 3,888 Bytes
00b93d3
 
 
 
 
 
9a5bad4
00b93d3
9a5bad4
b22be33
7ba5ba4
 
 
 
00b93d3
 
9a5bad4
d55d4b0
 
 
00b93d3
d55d4b0
 
3327a73
00b93d3
9b06638
00b93d3
9a5bad4
 
 
 
 
 
 
 
cc5f70a
 
9a5bad4
 
 
 
 
 
 
 
 
 
 
 
 
 
cc5f70a
9144506
2201b09
7bda381
00b93d3
 
b938ae3
9144506
cc5f70a
9144506
7bda381
7d41451
 
9144506
cc5f70a
 
 
b938ae3
9a5bad4
 
 
cc5f70a
00b93d3
 
b22be33
 
 
9144506
00b93d3
 
 
 
 
 
b938ae3
00b93d3
 
 
c60b02d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import torch
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
from scipy.interpolate import Rbf

# Load the MiDaS DPT_Hybrid depth-estimation model once at import time, switch
# it to inference mode and move it to the GPU when one is available.
midas_model = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid")
midas_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas_model.to(device)
# The MiDaS hub repo ships one preprocessing pipeline per model family:
# `dpt_transform` for DPT_Large / DPT_Hybrid, `small_transform` for
# MiDaS_small and `default_transform` for the v2.1 "MiDaS" model. Use the DPT
# pipeline so the input normalisation matches the DPT_Hybrid weights above.
midas_transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

def estimate_depth(image):
    """Estimate a relative depth map for *image* with the loaded MiDaS model.

    Args:
        image: PIL image; converted to RGB before inference.

    Returns:
        ``uint8`` array of shape ``(height, width)`` matching ``image.size``,
        min-max scaled to 0..255. A constant prediction (zero value range)
        yields an all-zero map instead of NaN-derived garbage.
    """
    image = image.convert("RGB")
    width, height = image.size
    # Use the hub-provided MiDaS transform rather than a hand-made tensor: it
    # scales to [0, 1], resizes so both sides are multiples of 32 (needed by
    # the MiDaS/DPT decoders for arbitrary upload sizes) and applies the
    # model's normalisation, returning a (1, 3, H', W') tensor.
    input_batch = midas_transform(np.array(image)).to(device)

    with torch.no_grad():
        depth = midas_model(input_batch).squeeze().cpu().numpy()

    # Back to the input resolution; cv2.resize takes (width, height).
    depth = cv2.resize(depth, (width, height))
    depth_min = float(depth.min())
    value_range = float(depth.max()) - depth_min
    if value_range <= 0:
        # Guard the division below: a flat prediction has no relative depth.
        return np.zeros((height, width), dtype=np.uint8)
    depth = (depth - depth_min) / value_range * 255
    return depth.astype(np.uint8)

def compute_optical_flow(depth):
    """Run Farneback flow from a Gaussian-blurred copy of *depth* to *depth*.

    Args:
        depth: image passed straight to OpenCV; Farneback expects an 8-bit
            single-channel input (e.g. the uint8 map from estimate_depth).

    Returns:
        Tuple ``(displacement_x, displacement_y)``: the horizontal and vertical
        flow components, each min-max stretched to the range [-5, 5].
    """
    # The 5x5 blurred copy serves as the "previous frame" for the tracker.
    previous = cv2.GaussianBlur(depth, (5, 5), 0)
    # Positional args: pyr_scale=0.5, levels=3, winsize=15, iterations=3,
    # poly_n=5, poly_sigma=1.2, flags=0.
    flow_field = cv2.calcOpticalFlowFarneback(previous, depth, None, 0.5, 3, 15, 3, 5, 1.2, 0)
    # Stretch each component independently to [-5, 5].
    dx, dy = (
        cv2.normalize(flow_field[..., channel], None, -5, 5, cv2.NORM_MINMAX)
        for channel in (0, 1)
    )
    return dx, dy

def apply_tps_interpolation(design, depth, max_points=500, chunk_elems=4_000_000):
    """Remap *design* through thin-plate RBFs anchored on the edges of *depth*.

    Control points are the Canny edges (thresholds 50/150) of ``depth`` cast to
    uint8. Two thin-plate ``Rbf``s give the source x and y coordinate for every
    pixel, and ``cv2.remap`` samples ``design`` there (bilinear, reflected
    borders).

    NOTE(review): each spline is fitted to return a control point's *own*
    coordinate, so the map stays close to the identity near the edges. Depth
    does not drive any displacement here, and scipy's legacy ``Rbf`` has no
    affine term, so the map drifts away from the anchors. Confirm this is the
    intended warp.

    Args:
        design: image array whose first two dimensions equal ``depth.shape``.
        depth: 2-D depth map.
        max_points: cap on control points. Edge pixels are subsampled evenly
            beyond it, because the RBF fit solves a dense N x N system and
            evaluation builds a (pixels x N) distance matrix. ``None`` disables
            the cap.
        chunk_elems: approximate cap on (pixels x control points) evaluated
            per band of rows, bounding peak memory during evaluation.

    Returns:
        The remapped design, same shape and dtype as ``design``. If fewer than
        three edge pixels exist (too few to fit a spline), a copy of ``design``
        is returned unchanged.
    """
    h, w = depth.shape
    edges = cv2.Canny(depth.astype(np.uint8), 50, 150)
    rows, cols = np.nonzero(edges)
    if rows.size < 3:
        # Rbf cannot be fit on 0 or 1 points (and 2 is degenerate): nothing
        # to anchor the warp to, so leave the design untouched.
        return design.copy()
    if max_points is not None and rows.size > max_points:
        # Evenly spaced, strictly increasing indices -> no duplicate anchors,
        # which would make the RBF system singular.
        keep = np.linspace(0, rows.size - 1, max_points).astype(int)
        rows, cols = rows[keep], cols[keep]

    xs = cols.astype(np.float64)
    ys = rows.astype(np.float64)
    tps_x = Rbf(xs, ys, xs, function="thin_plate")
    tps_y = Rbf(xs, ys, ys, function="thin_plate")

    grid_x, grid_y = np.meshgrid(np.arange(w, dtype=np.float64), np.arange(h, dtype=np.float64))
    map_x = np.empty((h, w), dtype=np.float32)
    map_y = np.empty((h, w), dtype=np.float32)
    # Evaluate in horizontal bands so the (pixels x anchors) matrix stays small.
    band_rows = max(1, chunk_elems // (w * xs.size))
    for start in range(0, h, band_rows):
        band = slice(start, start + band_rows)
        map_x[band] = tps_x(grid_x[band], grid_y[band])
        map_y[band] = tps_y(grid_x[band], grid_y[band])
    return cv2.remap(design, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)

def compute_adaptive_alpha(depth):
    """Turn depth-gradient strength into a blend weight in [0, 1].

    3x3 Sobel derivatives along x and y are combined into a gradient
    magnitude, which is then min-max rescaled to [0, 1]. Steep depth changes
    get weights near 1 and flat regions weights near 0.

    Returns:
        float32 array with the same spatial shape as ``depth``.
    """
    gx, gy = (
        cv2.Sobel(depth, cv2.CV_32F, order_x, order_y, ksize=3)
        for order_x, order_y in ((1, 0), (0, 1))
    )
    strength = np.sqrt(gx * gx + gy * gy)
    return cv2.normalize(strength, None, 0, 1, cv2.NORM_MINMAX)

def _place_design_canvas(design_img, height, width):
    """Return a ``(height, width, 4)`` uint8 canvas holding the scaled design.

    The design keeps its aspect ratio and is scaled to 40% of the limiting
    cloth dimension (never below 1 px per side). It is centred horizontally,
    with its top edge at 35% of the canvas height. Channels 0-2 hold the
    design colour and channel 3 its own alpha; everything outside the design
    is 0.
    """
    design_rgba = np.array(design_img.convert("RGBA"))
    dh, dw = design_rgba.shape[:2]
    scale = min(width / dw, height / dh) * 0.4
    new_w = max(1, int(dw * scale))
    new_h = max(1, int(dh * scale))
    design_rgba = cv2.resize(design_rgba, (new_w, new_h), interpolation=cv2.INTER_AREA)
    x0 = (width - new_w) // 2
    y0 = int(height * 0.35)
    canvas = np.zeros((height, width, 4), dtype=np.uint8)
    canvas[y0:y0 + new_h, x0:x0 + new_w] = design_rgba
    return canvas


def blend_design(cloth_img, design_img):
    """Composite *design_img* onto *cloth_img*, warped using the cloth's depth.

    Args:
        cloth_img: PIL image of the garment (converted to RGB).
        design_img: PIL image of the design (converted to RGBA; its alpha
            channel limits where the design is drawn).

    Returns:
        PIL RGB image with the same size as ``cloth_img``.
    """
    cloth_img = cloth_img.convert("RGB")
    cloth_np = np.array(cloth_img)
    h, w = cloth_np.shape[:2]
    canvas = _place_design_canvas(design_img, h, w)

    depth_map = estimate_depth(cloth_img)
    # Warp colour and coverage in one remap so the mask moves with the design.
    warped = apply_tps_interpolation(canvas, depth_map)
    design_rgb = warped[..., :3].astype(np.float32)
    coverage = warped[..., 3].astype(np.float32) / 255.0

    # Restrict blending to the design's own footprint. Without this, the empty
    # (black) canvas outside the design would darken the cloth. The weight is
    # then modulated by the depth-gradient alpha.
    # NOTE(review): gradient-magnitude alpha shows the design mostly on steep
    # folds and hides it on flat fabric -- confirm that is intended.
    alpha = coverage * compute_adaptive_alpha(depth_map)
    alpha = alpha[..., np.newaxis]  # (h, w, 1) broadcasts over the RGB channels
    blended = cloth_np.astype(np.float32) * (1.0 - alpha) + design_rgb * alpha
    return Image.fromarray(np.clip(np.rint(blended), 0, 255).astype(np.uint8))

def main(cloth, design):
    """Gradio callback: blend *design* onto *cloth* and return the result.

    Args:
        cloth: PIL image of the garment, or None if the input was left empty.
        design: PIL image of the design, or None if the input was left empty.

    Returns:
        The blended PIL image from ``blend_design``.

    Raises:
        gr.Error: if either image is missing, so the UI shows a readable
            message instead of an AttributeError from deep inside the pipeline.
    """
    global midas_model
    if cloth is None or design is None:
        raise gr.Error("Please upload both a clothing image and a design image.")
    # NOTE(review): midas_model is assigned at import time, so this fallback
    # only runs if something resets it to None. midas_transform is not
    # reloaded here, so it would keep the preprocessing chosen at import time
    # -- verify it suits MiDaS_small.
    if midas_model is None:
        midas_model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small").to(device).eval()
    return blend_design(cloth, design)

# Inputs are passed to main() in order: (cloth, design). The design input is
# loaded as RGBA because gr.Image converts uploads to RGB by default, which
# would flatten a transparent PNG design before blend_design reads its alpha.
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="pil", label="Clothing image"),
        gr.Image(type="pil", image_mode="RGBA", label="Design"),
    ],
    outputs=gr.Image(type="pil"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design to blend it naturally, ensuring it stays centered and follows fabric folds."
)

if __name__ == "__main__":
    # NOTE: share=True also publishes the app on a temporary public URL.
    iface.launch(share=True, debug=True)