import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
from transformers import DPTForDepthEstimation, DPTFeatureExtractor
# Load depth estimation model
model_name = "Intel/dpt-large"
feature_extractor = DPTFeatureExtractor.from_pretrained(model_name)
depth_model = DPTForDepthEstimation.from_pretrained(model_name)
depth_model.eval()
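# Note: recent transformers releases deprecate DPTFeatureExtractor in favor of
# DPTImageProcessor; the alias used above still works but may emit a warning.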

def estimate_depth(image):
    """Estimate a normalized depth map (uint8) from a PIL image."""
    orig_w, orig_h = image.size
    image = image.convert("RGB").resize((384, 384))  # DPT expects 384x384 input
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = depth_model(**inputs)
    depth = outputs.predicted_depth.squeeze().cpu().numpy()
    depth = cv2.resize(depth, (orig_w, orig_h))  # Resize back to the original size
    # Normalize to 0-255; the epsilon guards against a constant depth map
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255
    return depth.astype(np.uint8)
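
# warp_design below approximates cloth-aware printing: Sobel gradients of the
# depth map act as a displacement field, cv2.remap bends the flat design along
# that field, and the non-black pixels of the warped design are copied onto
# the cloth.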
def warp_design(cloth_img, design_img):
    """Warp the design onto the clothing while preserving folds."""
    cloth_img = cloth_img.convert("RGB")
    design_img = design_img.convert("RGB")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img)
    h, w, _ = cloth_np.shape
    dh, dw, _ = design_np.shape
    # Resize the design to fit within 70% of the clothing area, keeping its aspect ratio
    scale_factor = min(w / dw, h / dh) * 0.7
    new_w, new_h = int(dw * scale_factor), int(dh * scale_factor)
    design_np = cv2.resize(design_np, (new_w, new_h))

    # Center the design on a black canvas the size of the cloth;
    # black pixels are treated as background during blending
    design_canvas = np.zeros_like(cloth_np, dtype=np.uint8)
    x_offset = (w - new_w) // 2
    y_offset = (h - new_h) // 2
    design_canvas[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = design_np
    # Estimate the cloth's depth map and match it to the cloth resolution
    depth_map = estimate_depth(cloth_img)
    depth_map = cv2.resize(depth_map, (w, h))
    # Depth gradients (Sobel) approximate the fold direction; normalizing to
    # [-3, 3] px bounds the warp strength so the design is nudged, not torn
    displacement_x = cv2.Sobel(depth_map, cv2.CV_32F, 1, 0, ksize=5)
    displacement_y = cv2.Sobel(depth_map, cv2.CV_32F, 0, 1, ksize=5)
    displacement_x = cv2.normalize(displacement_x, None, -3, 3, cv2.NORM_MINMAX)
    displacement_y = cv2.normalize(displacement_y, None, -3, 3, cv2.NORM_MINMAX)

    # Build per-pixel sampling maps and warp the design along the folds
    map_x, map_y = np.meshgrid(np.arange(w), np.arange(h))
    map_x = np.clip(np.float32(map_x + displacement_x), 0, w - 1)
    map_y = np.clip(np.float32(map_y + displacement_y), 0, h - 1)
    warped_design = cv2.remap(design_canvas, map_x, map_y,
                              interpolation=cv2.INTER_LINEAR,
                              borderMode=cv2.BORDER_REFLECT)
    # Composite: copy warped design pixels onto the cloth wherever the design
    # is non-black (note: pure-black areas of the design are treated as background)
    mask = np.any(warped_design > 0, axis=-1).astype(np.uint8) * 255
    blended = cloth_np.copy()
    np.copyto(blended, warped_design, where=(mask[..., None] > 0))
    return Image.fromarray(blended)

def main(cloth, design):
    return warp_design(cloth, design)

iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="pil", label="Clothing image"),
        gr.Image(type="pil", label="Design image"),
    ],
    outputs=gr.Image(type="pil", label="Blended result"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design; the design is centered on the garment and warped to follow the fabric folds.",
)

if __name__ == "__main__":
    iface.launch(share=True)
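# share=True asks Gradio to open a temporary public link in addition to the
# local server; drop it (or pass share=False) for local-only use.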