import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image

# Load the MiDaS depth estimation model from torch.hub
midas_model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
midas_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas_model.to(device)
# Use the transform that matches the small model variant
midas_transform = torch.hub.load("intel-isl/MiDaS", "transforms").small_transform
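# Note: torch.hub downloads and caches the MiDaS repo and weights (under
# ~/.cache/torch/hub by default) on first load, so the first launch needs
# network access.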

def estimate_depth(image):
    """Estimate a normalized depth map for a PIL image using MiDaS."""
    image = image.convert("RGB")  # Ensure it's in RGB format
    # The MiDaS transform expects an HxWx3 NumPy array and already adds the
    # batch dimension, so convert from PIL and skip any extra unsqueeze
    img_tensor = midas_transform(np.array(image)).to(device)
    with torch.no_grad():
        depth = midas_model(img_tensor).squeeze().cpu().numpy()
    # Resize back to the input size and normalize to [0, 255]
    depth = cv2.resize(depth, (image.size[0], image.size[1]))
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255
    return depth.astype(np.uint8)
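
# Quick sanity check (a sketch; assumes a local file "shirt.jpg" exists and is
# run outside the Gradio app). MiDaS predicts inverse depth, so after the
# normalization above, brighter pixels are closer to the camera:
#   depth = estimate_depth(Image.open("shirt.jpg"))
#   Image.fromarray(depth).save("depth_preview.png")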

def apply_tps_warping(design, depth):
    """Warp an image with a depth-gradient displacement field.

    Note: despite the name, this is a dense pixel remap driven by Sobel
    gradients of the depth map, not a true Thin Plate Spline warp.
    """
    h, w = depth.shape
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
    # Depth gradients approximate how the fabric surface slopes at each pixel
    displacement_x = cv2.Sobel(depth, cv2.CV_32F, 1, 0, ksize=5)
    displacement_y = cv2.Sobel(depth, cv2.CV_32F, 0, 1, ksize=5)
    # Rescale displacements to at most +/- 5 pixels
    displacement_x = cv2.normalize(displacement_x, None, -5, 5, cv2.NORM_MINMAX)
    displacement_y = cv2.normalize(displacement_y, None, -5, 5, cv2.NORM_MINMAX)
    map_x = np.clip(grid_x + displacement_x, 0, w - 1).astype(np.float32)
    map_y = np.clip(grid_y + displacement_y, 0, h - 1).astype(np.float32)
    warped_design = cv2.remap(design, map_x, map_y,
                              interpolation=cv2.INTER_LINEAR,
                              borderMode=cv2.BORDER_REFLECT)
    return warped_design
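
# Visual check of the warp field (a sketch): remap a checkerboard with a
# synthetic left-to-right depth ramp so the distortion is easy to see.
#   h, w = 256, 256
#   checker = ((np.indices((h, w)).sum(axis=0) // 16) % 2 * 255).astype(np.uint8)
#   fake_depth = np.tile(np.linspace(0, 255, w, dtype=np.uint8), (h, 1))
#   cv2.imwrite("warp_preview.png", apply_tps_warping(checker, fake_depth))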

def blend_design(cloth_img, design_img):
    """Blend a design onto clothing, warping it to follow fabric folds."""
    cloth_img = cloth_img.convert("RGB")
    design_img = design_img.convert("RGBA")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img)
    # Resize the design to at most 40% of the clothing image
    h, w, _ = cloth_np.shape
    dh, dw, _ = design_np.shape
    scale_factor = min(w / dw, h / dh) * 0.4
    new_w, new_h = int(dw * scale_factor), int(dh * scale_factor)
    design_np = cv2.resize(design_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
    # Split off the alpha channel
    alpha_channel = design_np[:, :, 3] / 255.0
    design_np = design_np[:, :, :3]
    # Place the design on cloth-sized canvases, horizontally centered and
    # roughly at chest height
    x_offset = (w - new_w) // 2
    y_offset = int(h * 0.35)
    design_canvas = np.zeros_like(cloth_np)
    design_canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = design_np
    alpha_canvas = np.zeros((h, w), dtype=np.float32)
    alpha_canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = alpha_channel
    # Estimate depth and warp the design and its alpha mask with the same
    # displacement field, so the mask stays aligned with the warped design
    depth_map = estimate_depth(cloth_img)
    warped_design = apply_tps_warping(design_canvas, depth_map)
    warped_alpha = apply_tps_warping(alpha_canvas, depth_map)
    mask = warped_alpha[:, :, np.newaxis]
    cloth_np = (cloth_np * (1 - mask) + warped_design * mask).astype(np.uint8)
    return Image.fromarray(cloth_np)
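
# Example offline usage (a sketch; assumes "shirt.jpg" and "logo.png" exist):
#   result = blend_design(Image.open("shirt.jpg"), Image.open("logo.png"))
#   result.save("blended.png")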

def main(cloth, design):
    return blend_design(cloth, design)

iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="pil", label="Clothing image"),
        gr.Image(type="pil", label="Design image"),
    ],
    outputs=gr.Image(type="pil"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design to blend it naturally, keeping it centered and following the fabric folds.",
)

if __name__ == "__main__":
    iface.launch(share=True)