detecting_dress / app.py
gaur3009's picture
Update app.py
ba17c5b verified
raw
history blame
3.69 kB
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
# Build the U²-Net network and restore its weights once, at import time.
# Weights are mapped onto the CPU so no GPU is needed to load them.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)  # presumably (input channels, output channels) — TODO confirm
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Remove every 'module.' fragment from the parameter names so they line up
# with the bare model's keys (such a prefix is presumably left behind by a
# DataParallel-wrapped checkpoint — TODO confirm).
state_dict = dict(
    (param_name.replace('module.', ''), weights)
    for param_name, weights in state_dict.items()
)
model.load_state_dict(state_dict)
# Switch to inference mode.
model.eval()
def segment_dress(image_np):
    """Produce a binary dress mask for an image using the module-level U²-Net.

    Args:
        image_np: Image array accepted by ``Image.fromarray``; it is converted
            to RGB before being fed to the network.

    Returns:
        A ``uint8`` array containing only 0 and 255, resized to the height
        and width of ``image_np``.
    """
    # Tensor conversion happens first, so the resize operates on the tensor.
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])

    rgb_image = Image.fromarray(image_np).convert("RGB")
    batch = to_model_input(rgb_image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        # Only the first element of the model's output is used.
        prediction = model(batch)[0][0]
        scores = prediction.squeeze().cpu().numpy()

    # Binarise at 0.5 (assumes the scores are in [0, 1] — TODO confirm).
    small_mask = np.where(scores > 0.5, 255, 0).astype(np.uint8)

    # cv2.resize takes (width, height); nearest-neighbour keeps the mask binary.
    full_size = (image_np.shape[1], image_np.shape[0])
    return cv2.resize(small_mask, full_size, interpolation=cv2.INTER_NEAREST)
def recolor_dress(image_np, dress_mask, target_color):
    """Shift the chroma of the masked region toward a target color.

    The image is converted to LAB; inside the mask the A and B channels are
    re-centred so that their mean equals the target color's A and B, while
    the L channel is left alone so shading and texture are retained. The
    result is then blended back into the original with ``cv2.seamlessClone``.

    Args:
        image_np: RGB ``uint8`` image array of shape (H, W, 3).
        dress_mask: Single-channel mask of shape (H, W); values above 128
            mark the region to recolor.
        target_color: Target color in BGR order (3 values, 0-255).

    Returns:
        A new RGB ``uint8`` array with the recolored region. When the mask
        selects no pixels, an unmodified copy of ``image_np`` is returned.
    """
    # One threshold for statistics, recoloring and blending alike.
    region = dress_mask > 128
    if not region.any():
        # Nothing to recolor: avoids a NaN mean and an invalid clone call.
        return image_np.copy()

    # Convert target color to LAB
    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    # Convert image to LAB for better color control
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)

    # Re-centre the A (index 1) and B (index 2) channels on the target chroma.
    for channel in (1, 2):
        values = img_lab[..., channel][region].astype(np.float32)
        shifted = values - values.mean() + float(target_lab[channel])
        # Clip before casting: an unclipped float -> uint8 assignment wraps
        # around and yields speckled, wrongly colored pixels.
        img_lab[..., channel][region] = np.clip(np.rint(shifted), 0, 255).astype(np.uint8)

    # Convert back to RGB
    recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)

    # seamlessClone centres the mask's bounding box on the given point, so
    # the point must be that box's own centre for the dress to stay in place.
    clone_mask = region.astype(np.uint8) * 255
    x, y, w, h = cv2.boundingRect(clone_mask)
    center = (x + w // 2, y + h // 2)

    # Smooth edges for natural blending
    return cv2.seamlessClone(recolored, image_np, clone_mask, center, cv2.NORMAL_CLONE)
def change_dress_color(image_path, color):
    """Recolor the dress found in an image file.

    Args:
        image_path: Path of the image to process, or ``None`` when nothing
            was uploaded.
        color: Name of the desired color; unknown names (including ``None``)
            fall back to red.

    Returns:
        ``None`` when ``image_path`` is ``None``; the unmodified RGB PIL image
        when the segmentation mask is empty; otherwise a PIL image with the
        dress recolored.
    """
    if image_path is None:
        return None

    # The context manager closes the file handle promptly; convert() yields
    # an image that no longer depends on the open file.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img_np = np.array(img)

    # Get dress segmentation mask
    dress_mask = segment_dress(img_np)
    # An all-zero mask is how "nothing detected" actually shows up; recoloring
    # with it would operate on an empty region.
    if dress_mask is None or not dress_mask.any():
        return img  # No dress detected

    # Color names mapped to BGR triples.
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0),
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)

    # Apply recoloring with blending
    img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
    return Image.fromarray(img_recolored)
# Gradio Interface: an image upload plus a color picker in, a recolored image out.
_image_input = gr.Image(type="filepath", label="Upload Dress Image")
_color_input = gr.Radio(
    ["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"],
    label="Choose New Dress Color",
)
_result_output = gr.Image(type="pil", label="Color Changed Dress")

demo = gr.Interface(
    fn=change_dress_color,
    inputs=[_image_input, _color_input],
    outputs=_result_output,
    title="Realistic Dress Color Changer",
    description="Upload an image of a dress and select a new color. The AI will change the dress color naturally while keeping the fabric texture.",
)

# Start the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()