detecting_dress / app.py
gaur3009's picture
Update app.py
d00e30a verified
raw
history blame
4.48 kB
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
# Build the U²-Net network and restore its weights from disk (loaded onto the CPU).
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)  # presumably (input channels, output channels) -- confirm against U2NET
# Parameter names are stored without any 'module.' fragment so they match
# the keys expected by the bare model instance created above.
state_dict = {
    param_name.replace('module.', ''): param_value
    for param_name, param_value in torch.load(
        model_path, map_location=torch.device('cpu')
    ).items()
}
model.load_state_dict(state_dict)
# Inference only: switch the network to evaluation mode.
model.eval()
def detect_design(image_np):
    """Build a mask of candidate design (pattern/detail) pixels.

    The mask is the union of two cues computed on the grayscale image:
    an inverted Gaussian adaptive threshold (block size 11, constant 2)
    and a Canny edge map (thresholds 50 and 150). The union is then
    cleaned up with a 3x3 morphological closing.

    Args:
        image_np: Image array in RGB channel order.

    Returns:
        Single-channel mask in which non-zero pixels mark detected design.
    """
    grayscale = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # The two cues are independent of each other; both read only the grayscale image.
    edge_map = cv2.Canny(grayscale, 50, 150)
    local_contrast = cv2.adaptiveThreshold(
        grayscale,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11,
        2,
    )

    union_mask = cv2.bitwise_or(local_contrast, edge_map)

    # Closing fills small gaps left inside the combined mask.
    closing_kernel = np.ones((3, 3), np.uint8)
    return cv2.morphologyEx(union_mask, cv2.MORPH_CLOSE, closing_kernel)
def segment_dress(image_np):
    """Produce a binary dress mask for an image using the module-level U²-Net.

    The image is converted to an RGB tensor, resized to 320x320 and passed
    through ``model`` without gradient tracking. The prediction is
    thresholded at 0.5, scaled back to the input's width and height with
    nearest-neighbour interpolation, and smoothed with a 5x5 closing.

    Args:
        image_np: Image array accepted by ``PIL.Image.fromarray``; its first
            two dimensions are used as (height, width) of the returned mask.

    Returns:
        ``uint8`` mask with values 0 or 255 at the input's resolution.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])
    rgb_image = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        # First element of the model output, first item of the batch
        # (presumably the fused saliency map -- confirm against U2NET.forward).
        prediction = model(batch)[0][0].squeeze().cpu().numpy()

    # Binarise, then bring the mask back to the original resolution.
    binary_mask = (prediction > 0.5).astype(np.uint8) * 255
    original_size = (image_np.shape[1], image_np.shape[0])  # cv2 expects (width, height)
    binary_mask = cv2.resize(binary_mask, original_size, interpolation=cv2.INTER_NEAREST)

    # Closing smooths the mask by filling small holes.
    smoothing_kernel = np.ones((5, 5), np.uint8)
    return cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, smoothing_kernel)
def recolor_dress(image_np, dress_mask, design_mask, target_color):
    """Shift the dress towards ``target_color`` while leaving design pixels alone.

    Works in LAB space: only the two chroma channels (a, b) are blended
    towards the target colour (80% target, 20% original); the L channel is
    left unchanged, so the original shading is kept.

    Args:
        image_np: Source image in RGB channel order.
        dress_mask: Single-channel mask; values above 128 count as dress.
        design_mask: Single-channel mask of design pixels to exclude.
        target_color: Target colour as three values in BGR order.

    Returns:
        Recoloured image in RGB channel order.
    """
    lab_image = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    # The target is given in BGR, hence the different conversion code.
    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]

    # Dress pixels that are not part of the design are the only ones recoloured.
    paintable = cv2.bitwise_and(dress_mask, cv2.bitwise_not(design_mask)) > 128

    blend_factor = 0.8
    for channel in (1, 2):  # a and b chroma channels
        current = lab_image[..., channel]
        blended = current * (1 - blend_factor) + target_lab[channel] * blend_factor
        # Assignment casts the float blend back to the LAB image's dtype.
        lab_image[..., channel] = np.where(paintable, blended, current)

    return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
def change_dress_color(image_path, color):
    """Recolour the dress in an image file while keeping its designs intact.

    Args:
        image_path: Path of the image to process, or ``None`` when nothing
            was uploaded.
        color: Name of the target colour (e.g. ``"Red"``). Names that are
            not in the colour table, including ``None``, fall back to red.

    Returns:
        ``None`` when ``image_path`` is ``None``; the unmodified RGB image
        when no dress pixels were found; otherwise the recoloured image as
        a ``PIL.Image.Image``.
    """
    if image_path is None:
        return None

    # Convert inside the context manager: the pixel data is read while the
    # file is open and the handle is released as soon as the block exits.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img_np = np.array(img)

    # Get dress segmentation mask
    dress_mask = segment_dress(img_np)
    # segment_dress returns an array, so "no dress" shows up as an all-zero
    # mask rather than None; check both so this guard is actually reachable.
    if dress_mask is None or not dress_mask.any():
        return img  # No dress detected

    # Detect design on the dress
    design_mask = detect_design(img_np)

    # Colour names mapped to BGR triples (recolor_dress expects BGR).
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0),
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red

    # Apply recoloring logic
    img_recolored = recolor_dress(img_np, dress_mask, design_mask, new_color_bgr)
    return Image.fromarray(img_recolored)
# Gradio Interface
# Components are created first (inputs, then output) and wired together below.
_image_input = gr.Image(type="filepath", label="Upload Dress Image")
_color_input = gr.Radio(
    [
        "Red",
        "Blue",
        "Green",
        "Yellow",
        "Purple",
        "Orange",
        "Cyan",
        "Magenta",
        "White",
        "Black",
    ],
    label="Choose New Dress Color",
)
_image_output = gr.Image(type="pil", label="Color Changed Dress")

demo = gr.Interface(
    fn=change_dress_color,
    inputs=[_image_input, _color_input],
    outputs=_image_output,
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally while preserving designs.",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()