# Hosting-page metadata captured with this file (not part of the program):
# Blends / app.py
# gaur3009's picture
# Update app.py
# f189d0f verified
# raw
# history blame
# 2.77 kB
import gradio as gr
import torch
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
from transformers import DPTForDepthEstimation, DPTFeatureExtractor
# Load the depth-estimation model once at import time; estimate_depth() reads
# these module-level objects on every call instead of reloading them.
# "Intel/dpt-large" is the checkpoint identifier handed to `from_pretrained`.
model_name = "Intel/dpt-large"
# Preprocessor that converts a PIL image into the tensor inputs the model takes.
feature_extractor = DPTFeatureExtractor.from_pretrained(model_name)
depth_model = DPTForDepthEstimation.from_pretrained(model_name)
# Put the network in evaluation mode: this app only runs inference.
depth_model.eval()
def estimate_depth(image):
    """Return a smoothed 8-bit depth map aligned pixel-for-pixel with *image*.

    Args:
        image: PIL image; it is converted to RGB before inference.

    Returns:
        2-D ``numpy.uint8`` array of shape ``(image.height, image.width)``
        holding the predicted depth rescaled to 0..255 and lightly blurred.
        A prediction with no depth variation yields an all-zero map.
    """
    image = image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = depth_model(**inputs)
        # The model predicts at its own working resolution, which generally
        # differs from the input size. Resample to the image so callers can
        # combine the map with per-pixel coordinate grids.
        # PIL's ``size`` is (width, height); interpolate wants (height, width).
        resized = torch.nn.functional.interpolate(
            outputs.predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        )

    # Index instead of squeeze() so a 1-pixel-wide/tall image keeps two dims.
    depth = resized[0, 0].cpu().numpy()

    lowest = depth.min()
    span = depth.max() - lowest
    if span > 0:
        depth = (depth - lowest) / span * 255
    else:
        # Constant prediction: avoid 0/0 -> NaN and an undefined uint8 cast.
        depth = np.zeros_like(depth)
    depth = depth.astype(np.uint8)

    # Smooth the depth map to reduce noise before it is differentiated.
    return cv2.GaussianBlur(depth, (5, 5), 0)
def _scale_displacement(gradient, max_shift):
    """Scale *gradient* so its largest magnitude becomes *max_shift* pixels."""
    peak = float(np.abs(gradient).max())
    if peak == 0.0:
        # No depth variation along this axis: nothing to displace.
        return np.zeros_like(gradient)
    return gradient * np.float32(max_shift / peak)


def warp_design(cloth_img, design_img, *, max_shift=3.0, design_opacity=0.3):
    """Warp the design onto the clothing while preserving folds.

    The cloth's depth map is differentiated to find folds; the design is
    shifted along those gradients and then alpha-blended over the cloth.

    Args:
        cloth_img: PIL image of the garment.
        design_img: PIL image of the design; resized to the cloth's size.
        max_shift: Largest displacement, in pixels, applied to the design
            where the depth gradient is strongest.
        design_opacity: Blend weight of the warped design, expected in
            0..1; the cloth receives ``1 - design_opacity``.

    Returns:
        PIL RGB image with the same size as *cloth_img*.
    """
    cloth_img = cloth_img.convert("RGB")
    design_img = design_img.convert("RGB")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img)
    h, w = cloth_np.shape[:2]

    # Ensure both images have the same dimensions (cv2 takes width, height).
    design_np = cv2.resize(design_np, (w, h))

    # Estimate depth for fold detection.
    depth_map = estimate_depth(cloth_img)
    if depth_map.shape[:2] != (h, w):
        # The displacement field is added to an (h, w) coordinate grid below,
        # so the depth map must match the cloth image exactly.
        depth_map = cv2.resize(depth_map, (w, h), interpolation=cv2.INTER_LINEAR)

    # Depth gradients mark where the fabric folds.
    grad_x = cv2.Sobel(depth_map, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(depth_map, cv2.CV_32F, 0, 1, ksize=3)

    # Scale by peak magnitude rather than min-max: a zero gradient (flat
    # fabric) must stay a zero shift, otherwise the whole design drifts.
    shift_x = _scale_displacement(grad_x, max_shift)
    shift_y = _scale_displacement(grad_y, max_shift)

    # Build the sampling maps; cv2.remap needs float32 coordinates.
    grid_x, grid_y = np.meshgrid(
        np.arange(w, dtype=np.float32),
        np.arange(h, dtype=np.float32),
    )
    map_x = np.clip(grid_x + shift_x, 0, w - 1).astype(np.float32, copy=False)
    map_y = np.clip(grid_y + shift_y, 0, h - 1).astype(np.float32, copy=False)

    # remap keeps the source dtype (uint8), so no cast is needed afterwards.
    warped_design = cv2.remap(
        design_np,
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT,
    )

    # Blend images.
    blended = cv2.addWeighted(
        cloth_np, 1.0 - design_opacity, warped_design, design_opacity, 0
    )
    return Image.fromarray(blended)
def main(cloth, design):
    """Gradio callback: blend *design* onto *cloth* and return the result.

    Args:
        cloth: PIL image of the garment, or ``None`` if nothing was uploaded.
        design: PIL image of the design, or ``None`` if nothing was uploaded.

    Returns:
        The PIL image produced by ``warp_design``.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of an internal ``AttributeError``.
    """
    if cloth is None or design is None:
        raise gr.Error("Please upload both a clothing image and a design image.")
    return warp_design(cloth, design)
# Web UI: two PIL image inputs (cloth, then design) are passed to main(),
# and the blended PIL image it returns is displayed.
iface = gr.Interface(
    fn=main,
    inputs=[gr.Image(type="pil") for _ in range(2)],
    outputs=gr.Image(type="pil"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design to blend it naturally, considering fabric folds.",
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()