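"""Gradio demo: warp a 2D design onto a garment photo so it follows fabric folds.

Pipeline: MiDaS monocular depth -> thin-plate-spline warp anchored on depth
edges -> depth-gradient-weighted alpha blend.
"""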
import gradio as gr
import torch
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
from scipy.interpolate import Rbf
# Load the MiDaS depth estimation model (DPT_Hybrid) once at import time
midas_model = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid")
midas_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas_model.to(device)
# DPT models expect dpt_transform; default_transform targets the original MiDaS architecture
midas_transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform
def estimate_depth(image):
    """Return an 8-bit inverse-depth map at the input image's resolution."""
    image = image.convert("RGB")
    image_np = np.array(image)
    # The hub transform handles scaling, resizing, and normalization for the model
    input_batch = midas_transform(image_np).to(device)
    with torch.no_grad():
        depth = midas_model(input_batch).squeeze().cpu().numpy()
    depth = cv2.resize(depth, (image.size[0], image.size[1]))
    # Normalize to [0, 255]; the epsilon guards against a constant depth map
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255
    return depth.astype(np.uint8)
def compute_optical_flow(depth):
    """Estimate a smooth displacement field from the depth map (currently unused).

    Farneback flow between a blurred copy and the original depth map is kept
    here as an alternative displacement estimate to the TPS warp below.
    """
    depth_blurred = cv2.GaussianBlur(depth, (5, 5), 0)
    flow = cv2.calcOpticalFlowFarneback(depth_blurred, depth, None, 0.5, 3, 15, 3, 5, 1.2, 0)
    # Clamp displacements to roughly +/- 5 pixels
    displacement_x = cv2.normalize(flow[..., 0], None, -5, 5, cv2.NORM_MINMAX)
    displacement_y = cv2.normalize(flow[..., 1], None, -5, 5, cv2.NORM_MINMAX)
    return displacement_x, displacement_y
def apply_tps_interpolation(design, depth):
    """Warp the design with a thin-plate spline anchored on depth edges."""
    h, w = depth.shape
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
    edges = cv2.Canny(depth.astype(np.uint8), 50, 150)
    points = np.column_stack(np.where(edges > 0))
    if len(points) == 0:
        return design  # no depth edges to anchor on; leave the design flat
    # Subsample control points: Rbf fitting/evaluation is O(N * H * W)
    points = points[:: max(1, len(points) // 200)]
    tps_x = Rbf(points[:, 1], points[:, 0], grid_x[points[:, 0], points[:, 1]], function="thin_plate")
    tps_y = Rbf(points[:, 1], points[:, 0], grid_y[points[:, 0], points[:, 1]], function="thin_plate")
    map_x = tps_x(grid_x, grid_y).astype(np.float32)
    map_y = tps_y(grid_x, grid_y).astype(np.float32)
    return cv2.remap(design, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
def compute_adaptive_alpha(depth):
    """Blend weight in [0, 1] from the depth gradient magnitude (fold strength)."""
    grad_x = cv2.Sobel(depth, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(depth, cv2.CV_32F, 0, 1, ksize=3)
    grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
    alpha = cv2.normalize(grad_magnitude, None, 0, 1, cv2.NORM_MINMAX)
    return alpha
def blend_design(cloth_img, design_img):
    """Place, warp, and blend the design onto the garment image."""
    cloth_img = cloth_img.convert("RGB")
    design_img = design_img.convert("RGBA")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img)
    h, w, _ = cloth_np.shape
    dh, dw, _ = design_np.shape
    # Scale the design to about 40% of the garment's limiting dimension
    scale_factor = min(w / dw, h / dh) * 0.4
    new_w, new_h = int(dw * scale_factor), int(dh * scale_factor)
    design_np = cv2.resize(design_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
    # Paste the RGBA design onto a transparent canvas, centered horizontally
    # at roughly chest height (35% from the top)
    x_offset = (w - new_w) // 2
    y_offset = int(h * 0.35)
    design_canvas = np.zeros((h, w, 4), dtype=np.uint8)
    design_canvas[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = design_np
    # Warp the full RGBA canvas so the design and its alpha move together
    depth_map = estimate_depth(cloth_img)
    warped = apply_tps_interpolation(design_canvas, depth_map)
    warped_design = warped[:, :, :3]
    design_alpha = warped[:, :, 3:4] / 255.0
    # Weight the design's own alpha by the depth-gradient alpha so blending
    # stays confined to the design footprint instead of darkening the whole garment
    adaptive_alpha = compute_adaptive_alpha(depth_map)[:, :, np.newaxis]
    alpha = design_alpha * adaptive_alpha
    cloth_np = (cloth_np * (1 - alpha) + warped_design * alpha).astype(np.uint8)
    return Image.fromarray(cloth_np)
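# Programmatic use without the UI (hypothetical file names):
#   result = blend_design(Image.open("shirt.jpg"), Image.open("logo.png"))
#   result.save("shirt_with_logo.png")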
def main(cloth, design):
    global midas_model
    # Defensive fallback: reload a lightweight model if the global was cleared
    if midas_model is None:
        midas_model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small").to(device).eval()
    return blend_design(cloth, design)
iface = gr.Interface(
    fn=main,
    inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
    outputs=gr.Image(type="pil"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design to blend it naturally, ensuring it stays centered and follows fabric folds.",
)

if __name__ == "__main__":
    iface.launch(share=False)
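    # Pass share=True to launch() for a temporary public Gradio URL.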