# Blends / app.py — Hugging Face Space by gaur3009
# Revision 9a5bad4 ("Update app.py", verified), 3.59 kB
import gradio as gr
import torch
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
from scipy.interpolate import Rbf
# Load the MiDaS depth-estimation model once at import time (downloads via
# torch.hub on first run) and move it to the GPU when one is available.
midas_model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
midas_model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas_model.to(device)
# The MiDaS usage docs pair MiDaS_small with ``small_transform`` (256 px
# input); ``default_transform`` is the 384 px pipeline meant for the larger
# MiDaS v2.1 model. The transform expects an RGB ndarray, not a PIL image.
midas_transform = torch.hub.load("intel-isl/MiDaS", "transforms").small_transform
def estimate_depth(image):
    """Run MiDaS on *image* and return its prediction as an 8-bit map.

    Args:
        image: PIL image in any mode; it is converted to RGB first.

    Returns:
        ``uint8`` array of shape ``(height, width)`` matching *image*. The raw
        model output is min-max scaled to 0..255. A constant prediction
        (zero range) maps to all zeros instead of NaN.
    """
    image = image.convert("RGB")
    # The MiDaS hub transforms operate on an HxWx3 RGB ndarray (they divide by
    # 255 and resize with OpenCV); passing the PIL image directly raises.
    img_tensor = midas_transform(np.asarray(image)).to(device)
    with torch.no_grad():
        depth = midas_model(img_tensor).squeeze().cpu().numpy()
    # PIL's size is (width, height), the same order cv2.resize expects.
    depth = cv2.resize(depth, image.size)
    d_min = float(depth.min())
    d_max = float(depth.max())
    if d_max > d_min:
        depth = (depth - d_min) / (d_max - d_min) * 255
    else:
        # Flat prediction: avoid 0/0 producing NaN before the uint8 cast.
        depth = np.zeros_like(depth)
    return depth.astype(np.uint8)
def compute_optical_flow(depth):
    """Farneback flow from a Gaussian-smoothed copy of *depth* to *depth* itself.

    Each flow component is min-max rescaled into the range [-5, 5].

    Returns:
        Tuple ``(displacement_x, displacement_y)`` of per-pixel arrays.
    """
    smoothed = cv2.GaussianBlur(depth, (5, 5), 0)
    # Positional args: pyr_scale, levels, winsize, iterations, poly_n,
    # poly_sigma, flags.
    flow = cv2.calcOpticalFlowFarneback(
        smoothed, depth, None, 0.5, 3, 15, 3, 5, 1.2, 0
    )

    def _rescale(component):
        return cv2.normalize(component, None, -5, 5, cv2.NORM_MINMAX)

    return _rescale(flow[..., 0]), _rescale(flow[..., 1])
def apply_tps_interpolation(design, depth, max_control_points=300):
    """Remap *design* through thin-plate RBF maps fitted on depth-edge pixels.

    The control points are the Canny edges (thresholds 50/150) of *depth*. At
    each control point the x map is fitted to that point's own column and the
    y map to its own row. The warp therefore passes through the identity at
    the control points and is an RBF-interpolated approximation of it
    elsewhere.

    NOTE(review): because the fitted targets are the control points' own
    coordinates, no depth-driven displacement is introduced here. A
    fold-following warp would need non-identity target values; confirm
    intent.

    Args:
        design: image array with the same height/width as *depth*, with any
            channel count ``cv2.remap`` accepts.
        depth: 2-D depth map; cast to ``uint8`` for edge detection.
        max_control_points: upper bound (>= 2) on the number of edge pixels
            used for the fit. ``Rbf`` solves a dense N x N system, so using
            every edge pixel of a real photo is intractable. Above this
            count the edge pixels are subsampled evenly.

    Returns:
        The remapped design, using reflect border handling. When no fit is
        possible (no edges, or a singular RBF system) an unwarped copy of
        *design* is returned.
    """
    h, w = depth.shape
    edges = cv2.Canny(depth.astype(np.uint8), 50, 150)
    rows, cols = np.nonzero(edges)
    if rows.size == 0:
        # Flat depth map: Rbf cannot be fitted on zero points.
        return design.copy()
    if rows.size > max_control_points:
        keep = np.linspace(0, rows.size - 1, max_control_points).astype(np.intp)
        rows, cols = rows[keep], cols[keep]

    xs = cols.astype(np.float64)
    ys = rows.astype(np.float64)
    try:
        tps_x = Rbf(xs, ys, xs, function="thin_plate")
        tps_y = Rbf(xs, ys, ys, function="thin_plate")
    except np.linalg.LinAlgError:
        # The thin-plate kernel is 0 at r == 0 (and r == 1), so e.g. a single
        # control point gives a singular system. Fall back to no warp.
        return design.copy()

    # Evaluate in row bands: Rbf.__call__ builds a (pixels x points) distance
    # matrix, which for a whole image at once can need gigabytes.
    map_x = np.empty((h, w), dtype=np.float32)
    map_y = np.empty((h, w), dtype=np.float32)
    band = max(1, 2_000_000 // (w * xs.size))
    col_coords = np.arange(w, dtype=np.float64)
    for top in range(0, h, band):
        bottom = min(h, top + band)
        gx, gy = np.meshgrid(col_coords, np.arange(top, bottom, dtype=np.float64))
        map_x[top:bottom] = tps_x(gx, gy)
        map_y[top:bottom] = tps_y(gx, gy)
    return cv2.remap(design, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
def compute_adaptive_alpha(depth):
    """Turn the depth map's gradient strength into a per-pixel weight.

    A 3x3 Sobel operator gives float32 x/y derivatives. Their Euclidean
    magnitude is then min-max rescaled into [0, 1].

    Returns:
        float32 array with the same height/width as *depth*.
    """
    gx, gy = (
        cv2.Sobel(depth, cv2.CV_32F, dx, dy, ksize=3)
        for dx, dy in ((1, 0), (0, 1))
    )
    magnitude = np.sqrt(gx ** 2 + gy ** 2)
    return cv2.normalize(magnitude, None, 0, 1, cv2.NORM_MINMAX)
def blend_design(cloth_img, design_img):
    """Composite *design_img* onto *cloth_img*, warped and weighted by depth.

    The design is scaled to 40% of the cloth's limiting dimension, centred
    horizontally, and placed with its top edge 35% of the way down. The
    design's own alpha channel (fully opaque if it had none) restricts where
    it is drawn. Inside that footprint the opacity is further multiplied by
    the depth-gradient weight from ``compute_adaptive_alpha``.

    NOTE(review): with a gradient-magnitude weight, flat parts of the cloth
    show the design only faintly; confirm this is the intended look.

    Args:
        cloth_img: PIL image of the garment.
        design_img: PIL image of the design.

    Returns:
        PIL RGB image with the same size as *cloth_img*.
    """
    cloth_img = cloth_img.convert("RGB")
    design_img = design_img.convert("RGBA")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img)
    h, w, _ = cloth_np.shape
    dh, dw, _ = design_np.shape

    scale_factor = min(w / dw, h / dh) * 0.4
    # Clamp to at least 1 px: cv2.resize rejects a zero-sized target, which
    # extreme aspect ratios would otherwise produce.
    new_w = max(1, int(dw * scale_factor))
    new_h = max(1, int(dh * scale_factor))
    design_np = cv2.resize(design_np, (new_w, new_h), interpolation=cv2.INTER_AREA)

    x_offset = (w - new_w) // 2
    y_offset = int(h * 0.35)
    # RGBA canvas so the design's alpha is warped together with its colours
    # in a single remap. Everything outside the design stays transparent.
    design_canvas = np.zeros((h, w, 4), dtype=np.uint8)
    design_canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = design_np

    depth_map = estimate_depth(cloth_img)
    warped = apply_tps_interpolation(design_canvas, depth_map)
    warped_rgb = warped[..., :3].astype(np.float32)
    design_mask = warped[..., 3].astype(np.float32) / 255.0
    adaptive_alpha = compute_adaptive_alpha(depth_map)

    # Limit blending to the design footprint, so the black canvas never
    # darkens the cloth. The trailing axis lets the (h, w) weight broadcast
    # over the RGB channels; (h, w, 3) * (h, w) does not broadcast.
    alpha = (design_mask * adaptive_alpha)[..., None]
    blended = cloth_np.astype(np.float32) * (1.0 - alpha) + warped_rgb * alpha
    return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
def main(cloth, design):
    """Gradio callback: blend the uploaded *design* onto the uploaded *cloth*.

    Raises:
        gr.Error: if either image is missing. Gradio passes ``None`` for an
            empty image input, and this shows a readable message in the UI
            instead of an AttributeError traceback.
    """
    if cloth is None or design is None:
        raise gr.Error("Please upload both a clothing image and a design image.")
    return blend_design(cloth, design)
# Two PIL image inputs, in callback order (cloth, design); one PIL image output.
iface = gr.Interface(
    fn=main,
    inputs=[gr.Image(type="pil") for _ in range(2)],
    outputs=gr.Image(type="pil"),
    title="AI Cloth Design Warping",
    description=(
        "Upload a clothing image and a design to blend it naturally, "
        "ensuring it stays centered and follows fabric folds."
    ),
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    iface.launch(share=True)