# Blends / app.py — Hugging Face Space by gaur3009
# Commit 109a86d ("Update app.py", verified), 3.29 kB
import gradio as gr
import torch
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
from transformers import DPTForDepthEstimation, DPTFeatureExtractor
# Depth-estimation backbone, loaded once at import time and shared by every
# request: the DPT-Large checkpoint plus its matching preprocessor.
model_name = "Intel/dpt-large"
feature_extractor = DPTFeatureExtractor.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so inference mode is set inline.
depth_model = DPTForDepthEstimation.from_pretrained(model_name).eval()
def estimate_depth(image):
    """Estimate a normalized depth map for ``image``.

    The image is converted to RGB and resized to 384x384 before being handed
    to the DPT preprocessor/model. The predicted depth is resized back to the
    image's ORIGINAL size and min-max scaled to the 0-255 range.

    Args:
        image: PIL image of any mode and size.

    Returns:
        ``np.uint8`` array of shape (height, width) matching ``image``. If the
        prediction is constant (no depth variation), an all-zero map is
        returned instead of dividing by zero.
    """
    image = image.convert("RGB")
    # Capture the size BEFORE resizing for the model; reading image.width /
    # image.height after the resize made "resize back" a no-op (always 384x384).
    orig_w, orig_h = image.size
    model_input = image.resize((384, 384))  # input size used for the model
    inputs = feature_extractor(images=model_input, return_tensors="pt")
    with torch.no_grad():
        outputs = depth_model(**inputs)
    # squeeze() drops size-1 dims (presumably the batch dim) to get a 2-D map.
    depth = outputs.predicted_depth.squeeze().cpu().numpy()
    depth = cv2.resize(depth, (orig_w, orig_h))  # back to the original size

    d_min = float(depth.min())
    span = float(depth.max()) - d_min
    if span <= 0.0:
        # Constant prediction: min-max scaling would be 0/0 -> NaN.
        return np.zeros((orig_h, orig_w), dtype=np.uint8)
    depth = (depth - d_min) / span * 255
    return depth.astype(np.uint8)
def _center_design_layer(cloth_shape, design_np, design_scale):
    """Resize the design to ``design_scale`` of the cloth and centre it on an empty layer.

    Args:
        cloth_shape: shape of the cloth array; only (height, width) are used.
        design_np: RGB uint8 design array.
        design_scale: fraction of the cloth's width/height the design occupies,
            in (0, 1].

    Returns:
        ``(layer, mask)``: ``layer`` is a uint8 (H, W, 3) array that is zero
        except where the design sits; ``mask`` is a float32 (H, W) array that
        is 1.0 on the design and 0.0 elsewhere.
    """
    cloth_h, cloth_w = cloth_shape[:2]
    # max(1, ...) avoids a zero-size target, which cv2.resize rejects.
    target_w = max(1, int(cloth_w * design_scale))
    target_h = max(1, int(cloth_h * design_scale))
    resized = cv2.resize(design_np, (target_w, target_h))
    # Centre minus half the design size (fits because design_scale <= 1).
    y0 = cloth_h // 2 - target_h // 2
    x0 = cloth_w // 2 - target_w // 2
    layer = np.zeros((cloth_h, cloth_w, 3), dtype=np.uint8)
    mask = np.zeros((cloth_h, cloth_w), dtype=np.float32)
    layer[y0:y0 + target_h, x0:x0 + target_w] = resized
    mask[y0:y0 + target_h, x0:x0 + target_w] = 1.0
    return layer, mask


def _fold_displacement(depth_map, max_shift):
    """Turn depth gradients into per-pixel (dx, dy) offsets in [-max_shift, max_shift].

    Each Sobel gradient is divided by its own largest absolute value, so zero
    gradient (flat fabric) maps to zero displacement. NORM_MINMAX, used
    previously, maps the minimum to -max_shift, which shifts flat regions too
    (and shifts everything by -max_shift when the depth map is flat).
    """
    offsets = []
    for x_order, y_order in ((1, 0), (0, 1)):
        grad = cv2.Sobel(depth_map, cv2.CV_32F, x_order, y_order, ksize=5)
        peak = float(np.abs(grad).max())
        offsets.append(grad * (max_shift / peak) if peak > 0.0 else np.zeros_like(grad))
    return offsets[0], offsets[1]


def warp_design(cloth_img, design_img, *, max_shift=5.0, opacity=0.4, design_scale=0.5):
    """Warp the design onto the clothing while preserving folds.

    The design is resized to ``design_scale`` of the cloth size and centred.
    A depth map of the cloth (from ``estimate_depth``) is turned into per-pixel
    offsets via Sobel gradients. The design and its placement mask are warped
    by those offsets and alpha-blended over the cloth at ``opacity``. Pixels
    outside the warped design keep the original cloth colour. Previously the
    whole composite was warped and re-blended with the unwarped cloth, which
    ghosted the entire garment.

    Args:
        cloth_img: PIL image of the garment.
        design_img: PIL image of the design to apply.
        max_shift: largest displacement, in pixels, along each axis.
        opacity: weight of the design over the cloth, in [0, 1].
        design_scale: design size as a fraction of the cloth size, in (0, 1].

    Returns:
        RGB PIL image with the same size as ``cloth_img``.

    Raises:
        ValueError: if ``opacity`` or ``design_scale`` is out of range.
    """
    if not 0.0 <= opacity <= 1.0:
        raise ValueError(f"opacity must be in [0, 1], got {opacity}")
    if not 0.0 < design_scale <= 1.0:
        raise ValueError(f"design_scale must be in (0, 1], got {design_scale}")

    cloth_img = cloth_img.convert("RGB")
    cloth_np = np.array(cloth_img)
    design_np = np.array(design_img.convert("RGB"))
    h, w = cloth_np.shape[:2]

    # Depth drives the fold displacement; resize so the map always matches
    # the cloth array exactly.
    depth_map = cv2.resize(estimate_depth(cloth_img), (w, h))
    disp_x, disp_y = _fold_displacement(depth_map, max_shift)

    layer, mask = _center_design_layer(cloth_np.shape, design_np, design_scale)

    grid_x, grid_y = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
    map_x = np.clip(grid_x + disp_x, 0, w - 1).astype(np.float32)
    map_y = np.clip(grid_y + disp_y, 0, h - 1).astype(np.float32)
    warped_layer = cv2.remap(layer, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    warped_mask = cv2.remap(mask, map_x, map_y, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)

    # The layer is zero outside the design, so after bilinear warping it is
    # effectively premultiplied by warped_mask. Composite accordingly so the
    # design edges don't pick up a dark fringe.
    alpha = (opacity * warped_mask)[..., np.newaxis]
    blended = cloth_np.astype(np.float32) * (1.0 - alpha) + opacity * warped_layer.astype(np.float32)
    blended = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    return Image.fromarray(blended)
def main(cloth, design):
    """Gradio callback: warp ``design`` onto ``cloth`` and return the result.

    Args:
        cloth: PIL image of the garment, or None if not uploaded.
        design: PIL image of the design, or None if not uploaded.

    Returns:
        The blended PIL image produced by ``warp_design``.

    Raises:
        gr.Error: if either image is missing. Gradio supplies None for an
            empty image input; this message is shown to the user instead of
            an AttributeError stack trace.
    """
    if cloth is None or design is None:
        raise gr.Error("Please upload both a clothing image and a design image.")
    return warp_design(cloth, design)
# Web UI: two image uploads (cloth first, design second, matching main's
# parameter order) and one image output. Labels make the upload order explicit.
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="pil", label="Clothing image"),
        gr.Image(type="pil", label="Design"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="AI Cloth Design Warping",
    description="Upload a clothing image and a design to blend it naturally at the center, considering fabric folds.",
)

if __name__ == "__main__":
    # share=True requests a public share link when run locally; hosted
    # platforms (e.g. Hugging Face Spaces) may ignore it.
    iface.launch(share=True)