# Blends / app.py — Hugging Face Space file by gaur3009
# Last commit: 2201b09 "Update app.py" (verified), 2.54 kB
import gradio as gr
import torch
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
from transformers import DPTForDepthEstimation, DPTFeatureExtractor
# Load the DPT depth-estimation checkpoint once, at import time, so every
# request reuses the same preprocessor and weights.
model_name = "Intel/dpt-large"
feature_extractor = DPTFeatureExtractor.from_pretrained(model_name)
# nn.Module.eval() returns the module itself; it switches layers such as
# dropout to inference behaviour.
depth_model = DPTForDepthEstimation.from_pretrained(model_name).eval()
def estimate_depth(image):
    """Estimate a normalized depth map for ``image`` with the DPT model.

    Args:
        image: PIL image; converted to RGB before preprocessing.

    Returns:
        2-D ``np.uint8`` array of shape ``(image.height, image.width)``.
        The model's prediction is min-max scaled so its smallest value maps
        to 0 and its largest to 255 (fractions are truncated, not rounded).
        A constant prediction has no range to stretch and yields all zeros.
    """
    image = image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = depth_model(**inputs)
        # predicted_depth is (batch, h, w) at the preprocessor's working
        # resolution, not the input's. Resize it back to the image size so
        # callers can combine it pixel-for-pixel with the original image.
        depth = torch.nn.functional.interpolate(
            outputs.predicted_depth.unsqueeze(1),
            size=(image.height, image.width),
            mode="bicubic",
            align_corners=False,
        )[0, 0]
    depth = depth.cpu().numpy()

    lo = float(depth.min())
    span = float(depth.max()) - lo
    if span <= 0:
        # Guard the 0/0 -> NaN case; casting NaN to uint8 is undefined.
        return np.zeros(depth.shape, dtype=np.uint8)
    depth = (depth - lo) / span * 255
    return depth.astype(np.uint8)
def _fold_displacement(depth_map, dx, dy, max_shift):
    """Return the Sobel derivative (order dx, dy) of ``depth_map`` scaled to at most ``max_shift`` px."""
    grad = cv2.Sobel(depth_map, cv2.CV_32F, dx, dy, ksize=5)
    peak = float(np.abs(grad).max())
    if peak == 0.0:
        # Perfectly flat depth: no folds, so no displacement.
        return np.zeros_like(grad)
    # Scale by the peak magnitude rather than min-max normalizing: this keeps
    # zero gradient at zero displacement, so flat fabric is not shifted.
    return grad * (max_shift / peak)


def warp_design(cloth_img, design_img, max_shift=5.0, design_weight=0.4):
    """Warp the design onto the clothing while preserving folds.

    The cloth's estimated depth map is differentiated with a Sobel filter.
    The horizontal and vertical gradients become per-pixel displacements used
    to resample the design, so the pattern bends where the depth changes.
    The warped design is then alpha-blended over the cloth.

    Args:
        cloth_img: PIL image of the garment; converted to RGB.
        design_img: PIL image of the design; converted to RGB and resized to
            the cloth's width and height.
        max_shift: Largest displacement in pixels, applied where the depth
            gradient is steepest. Defaults to 5.
        design_weight: Weight of the warped design in the blend; the cloth
            gets ``1 - design_weight``. Defaults to 0.4.

    Returns:
        PIL RGB image the same size as ``cloth_img``.
    """
    cloth_img = cloth_img.convert("RGB")
    design_img = design_img.convert("RGB")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img)
    h, w = cloth_np.shape[:2]

    # remap/addWeighted need both images at the same size (cv2 takes (w, h)).
    design_np = cv2.resize(design_np, (w, h))

    # Estimate depth for fold detection, then align it with the cloth grid so
    # the displacement field matches the sampling coordinates below.
    depth_map = estimate_depth(cloth_img)
    if depth_map.shape[:2] != (h, w):
        depth_map = cv2.resize(depth_map, (w, h), interpolation=cv2.INTER_LINEAR)

    displacement_x = _fold_displacement(depth_map, 1, 0, max_shift)
    displacement_y = _fold_displacement(depth_map, 0, 1, max_shift)

    # Sample the design at (x + dx, y + dy); cv2.remap requires float32 maps.
    grid_x, grid_y = np.meshgrid(
        np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32)
    )
    map_x = (grid_x + displacement_x).astype(np.float32)
    map_y = (grid_y + displacement_y).astype(np.float32)
    warped_design = cv2.remap(
        design_np,
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT,
    )

    blended = cv2.addWeighted(
        cloth_np, 1.0 - design_weight, warped_design, design_weight, 0
    )
    return Image.fromarray(blended)
def main(cloth, design):
    """Gradio callback: warp ``design`` onto ``cloth`` and return the blend.

    Args:
        cloth: PIL image of the garment, or None if nothing was uploaded.
        design: PIL image of the design, or None if nothing was uploaded.

    Returns:
        The blended PIL image produced by ``warp_design``.

    Raises:
        gr.Error: if either upload is missing. The UI then shows a readable
            message instead of an AttributeError traceback.
    """
    if cloth is None or design is None:
        raise gr.Error("Please upload both a clothing image and a design image.")
    return warp_design(cloth, design)
# Build the web UI. The input order must match main(cloth, design).
iface = gr.Interface(
    fn=main,
    # Label the two otherwise-identical image slots so users know which
    # upload is the garment and which is the design.
    inputs=[
        gr.Image(type="pil", label="Clothing image"),
        gr.Image(type="pil", label="Design image"),
    ],
    outputs=gr.Image(type="pil", label="Blended result"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design to blend it naturally, considering fabric folds."
)

if __name__ == "__main__":
    iface.launch()