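"""Blends: warp and blend a design image onto a clothing photo.

Pipeline: MiDaS depth estimation, thin-plate-spline warping along
depth edges, and depth-gradient alpha blending, served via Gradio.
"""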
import gradio as gr
import torch
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
from scipy.interpolate import Rbf

# Load the MiDaS depth estimation model and its matching preprocessing transform.
# DPT_Hybrid pairs with dpt_transform (default_transform targets the original MiDaS model).
midas_model = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid")
midas_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas_model.to(device)
midas_transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

def estimate_depth(image):
    """Estimate a depth map for a PIL image, normalized to 0-255 uint8."""
    image = image.convert("RGB")
    image_np = np.array(image)  # the MiDaS transform expects an RGB uint8 NumPy array
    image_tensor = midas_transform(image_np).to(device)  # handles resize and normalization
    with torch.no_grad():
        depth = midas_model(image_tensor).squeeze().cpu().numpy()
    depth = cv2.resize(depth, (image.size[0], image.size[1]))  # back to original resolution
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255
    return depth.astype(np.uint8)

def compute_optical_flow(depth):
    """Estimate per-pixel displacement from the depth map (currently unused by the pipeline)."""
    depth_blurred = cv2.GaussianBlur(depth, (5, 5), 0)
    flow = cv2.calcOpticalFlowFarneback(depth_blurred, depth, None, 0.5, 3, 15, 3, 5, 1.2, 0)
    displacement_x = cv2.normalize(flow[..., 0], None, -5, 5, cv2.NORM_MINMAX)
    displacement_y = cv2.normalize(flow[..., 1], None, -5, 5, cv2.NORM_MINMAX)
    return displacement_x, displacement_y

def apply_tps_interpolation(design, depth):
    """Warp the design with a thin-plate spline anchored on depth-map edges."""
    h, w = depth.shape
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
    edges = cv2.Canny(depth.astype(np.uint8), 50, 150)
    points = np.column_stack(np.where(edges > 0))
    if len(points) == 0:
        return design  # no edges to anchor the spline; leave the design unwarped
    # Subsample control points: Rbf fitting is cubic in the point count and
    # stalls on dense edge maps. The ~200-point budget is a heuristic choice.
    points = points[:: max(1, len(points) // 200)]
    tps_x = Rbf(points[:, 1], points[:, 0], grid_x[points[:, 0], points[:, 1]], function="thin_plate")
    tps_y = Rbf(points[:, 1], points[:, 0], grid_y[points[:, 0], points[:, 1]], function="thin_plate")
    map_x = tps_x(grid_x, grid_y).astype(np.float32)
    map_y = tps_y(grid_x, grid_y).astype(np.float32)
    return cv2.remap(design, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)

def compute_adaptive_alpha(depth):
    """Derive a per-pixel blend weight from the depth gradient magnitude."""
    grad_x = cv2.Sobel(depth, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(depth, cv2.CV_32F, 0, 1, ksize=3)
    grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
    alpha = cv2.normalize(grad_magnitude, None, 0, 1, cv2.NORM_MINMAX)
    return alpha

def blend_design(cloth_img, design_img):
    """Place the design on the cloth, warp it along depth edges, and blend adaptively."""
    cloth_img = cloth_img.convert("RGB")
    design_img = design_img.convert("RGBA")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img)
    h, w, _ = cloth_np.shape
    dh, dw, _ = design_np.shape
    # Scale the design to 40% of the best-fit size, preserving its aspect ratio.
    scale_factor = min(w / dw, h / dh) * 0.4
    new_w, new_h = int(dw * scale_factor), int(dh * scale_factor)
    design_np = cv2.resize(design_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
    alpha_channel = design_np[:, :, 3] / 255.0
    design_np = design_np[:, :, :3]
    # Center horizontally, place at 35% of the height (roughly the chest area).
    x_offset = (w - new_w) // 2
    y_offset = int(h * 0.35)
    design_canvas = np.zeros_like(cloth_np)
    design_canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = design_np
    # Mirror the design's transparency onto a full-size mask so the blend
    # only affects pixels the design actually covers.
    alpha_canvas = np.zeros((h, w), dtype=np.float32)
    alpha_canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = alpha_channel
    depth_map = estimate_depth(cloth_img)
    warped_design = apply_tps_interpolation(design_canvas, depth_map)
    warped_alpha = apply_tps_interpolation(alpha_canvas, depth_map)  # keep mask aligned with the warp
    adaptive_alpha = (compute_adaptive_alpha(depth_map) * warped_alpha)[:, :, None]  # broadcast over RGB
    cloth_np = (cloth_np * (1 - adaptive_alpha) + warped_design * adaptive_alpha).astype(np.uint8)
    return Image.fromarray(cloth_np)

def main(cloth, design):
    global midas_model
    if midas_model is None:  # defensive fallback; the model is normally loaded at import time
        midas_model = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid").to(device).eval()
    return blend_design(cloth, design)
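
# Minimal local-usage sketch, bypassing the Gradio UI (the filenames are
# illustrative, not shipped with the app):
#
#     from PIL import Image
#     out = blend_design(Image.open("cloth.jpg"), Image.open("design.png"))
#     out.save("blended.jpg")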

iface = gr.Interface(
    fn=main,
    inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
    outputs=gr.Image(type="pil"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design; the design is centered on the garment and warped to follow the fabric folds.",
)

if __name__ == "__main__":
    iface.launch(share=False)