detecting_dress / app.py
gaur3009's picture
Update app.py
3758aa6 verified
raw
history blame
5.37 kB
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
# Load the U²-Net weights once at import time; the segmentation functions
# below share this single module-level `model`.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Drop a leading 'module.' from each key (presumably left by a checkpoint
# saved from a DataParallel-wrapped model -- confirm against the training
# script). Only the prefix is stripped: a blanket str.replace would also
# mangle keys that merely contain the substring, e.g. 'submodule.weight'.
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
# Inference only: switch layers such as dropout/batch-norm to eval behaviour.
model.eval()
def remove_background(image_np):
    """Mask out the background with U²-Net and paint it white.

    Args:
        image_np: Image as a NumPy array accepted by ``Image.fromarray``.

    Returns:
        A ``(segmented_image, mask)`` tuple. ``segmented_image`` is the input
        with every pixel outside the predicted foreground replaced by 255.
        ``mask`` is a uint8 array holding 0 or 255, resized to the input's
        height and width.
    """
    # The network is fed a fixed 320x320 tensor regardless of input size.
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])
    rgb_image = Image.fromarray(image_np).convert("RGB")
    batch = to_model_input(rgb_image).unsqueeze(0)

    with torch.no_grad():
        # First element of the model output, first item of the batch.
        saliency = model(batch)[0][0].squeeze().cpu().numpy()

    binary_mask = (saliency > 0.5).astype(np.uint8) * 255

    # cv2.resize takes (width, height); nearest-neighbour keeps the mask binary.
    target_size = (image_np.shape[1], image_np.shape[0])
    mask = cv2.resize(binary_mask, target_size, interpolation=cv2.INTER_NEAREST)

    white_bg = np.ones_like(image_np) * 255  # White background
    keep_pixel = mask[..., None] > 128
    segmented_image = np.where(keep_pixel, image_np, white_bg)
    return segmented_image, mask
def segment_dress(image_np):
    """Segment the dress with K-means colour clustering gated by U²-Net.

    The image is clustered into three colour groups in Lab space. The group
    covering the most pixels inside the U²-Net foreground mask is taken to be
    the dress, intersected with that foreground mask, and then smoothed.

    Args:
        image_np: RGB image array (converted with ``cv2.COLOR_RGB2LAB``).

    Returns:
        A blurred uint8 mask with the same height and width as ``image_np``,
        or ``None`` when U²-Net reports no foreground at all (no dress found).
    """
    # Convert to Lab color space
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    pixel_values = img_lab.reshape((-1, 3)).astype(np.float32)

    # K-means clustering to detect dress region
    k = 3  # Three clusters, intended as: background, skin, dress
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, _ = cv2.kmeans(
        pixel_values, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )
    labels = labels.reshape(image_np.shape[:2])

    # U²-Net foreground prediction, computed at 320x320 and scaled back up.
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()
    u2net_mask = (output > 0.5).astype(np.uint8) * 255
    u2net_mask = cv2.resize(
        u2net_mask,
        (image_np.shape[1], image_np.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

    foreground = u2net_mask > 0
    if not foreground.any():
        # Nothing was classified as foreground, so there is no dress to mask.
        return None

    # K-means label indices are arbitrary (random initial centres), so label 0
    # cannot be assumed to be the background. Choose the cluster that owns the
    # most foreground pixels instead; minlength keeps the count vector length k
    # even when some cluster never appears in the foreground.
    foreground_counts = np.bincount(labels[foreground], minlength=k)
    dress_label = int(np.argmax(foreground_counts))

    # Create dress mask and restrict it to the U²-Net foreground.
    mask = (labels == dress_label).astype(np.uint8) * 255
    refined_mask = cv2.bitwise_and(mask, u2net_mask)

    # Morphological closing fills small holes; the blur softens the edges.
    kernel = np.ones((5, 5), np.uint8)
    refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel)
    refined_mask = cv2.GaussianBlur(refined_mask, (15, 15), 5)
    return refined_mask
def recolor_dress(image_np, mask, target_color, blend_factor=0.7):
    """Shift the colour of the masked region toward ``target_color``.

    Lightness (L) is left untouched so shading and folds survive; only the
    chromatic channels (a, b) are blended toward the target.

    Args:
        image_np: RGB image array (converted with ``cv2.COLOR_RGB2LAB``);
            assumed to be uint8.
        mask: Single-channel mask with the same height and width as the image;
            pixels with a value above 128 are recoloured.
        target_color: Target colour as a BGR triple (converted with
            ``cv2.COLOR_BGR2LAB``).
        blend_factor: Weight of the target chroma in the blend. 0 keeps the
            original colour, 1 replaces the chroma entirely. Defaults to 0.7.

    Returns:
        The recoloured image as an RGB array.
    """
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    target_color_lab = cv2.cvtColor(
        np.uint8([[target_color]]), cv2.COLOR_BGR2LAB
    )[0][0]

    dress_pixels = mask > 128
    for channel in (1, 2):  # a and b; channel 0 (L) is preserved
        original = img_lab[..., channel]
        blended = (
            original * (1 - blend_factor)
            + float(target_color_lab[channel]) * blend_factor
        )
        # Round and clip explicitly: assigning floats straight into the uint8
        # array would truncate them, and an out-of-range blend_factor would
        # otherwise wrap around.
        blended = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
        img_lab[..., channel] = np.where(dress_pixels, blended, original)

    img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
    return img_recolored
def process_image(image_path, color):
    """Remove the background, segment the dress, and recolor it.

    Args:
        image_path: Path of the uploaded image, or ``None`` when nothing was
            uploaded.
        color: Name of the target colour; unknown names (including ``None``)
            fall back to red.

    Returns:
        A PIL image with a white background and the recoloured dress, the
        background-removed image alone when no dress mask is available, or
        ``None`` when ``image_path`` is ``None``.
    """
    if image_path is None:
        return None

    # Image.open keeps the underlying file open; the context manager closes it
    # as soon as the pixel data has been copied into the array.
    with Image.open(image_path) as img:
        img_np = np.array(img.convert("RGB"))

    # Remove background
    img_segmented, _ = remove_background(img_np)

    # Get dress segmentation mask
    mask = segment_dress(img_np)
    if mask is None:
        # No dress detected, return only background removal
        return Image.fromarray(img_segmented)

    # Colour names mapped to BGR triples (recolor_dress expects BGR).
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red

    # Apply recoloring on the background-removed image so the backdrop stays white.
    img_recolored = recolor_dress(img_segmented, mask, new_color_bgr)
    return Image.fromarray(img_recolored)
# Gradio UI: one image upload plus a colour picker in, one processed image out.
demo = gr.Interface(
    title="Dress Color Changer with Background Removal",
    description="Upload an image of a dress, remove its background, and recolor it naturally while keeping the background white.",
    fn=process_image,
    inputs=[
        # Passed to process_image as a file path.
        gr.Image(label="Upload Dress Image", type="filepath"),
        # Passed to process_image as the selected colour name.
        gr.Radio(
            choices=[
                "Red",
                "Blue",
                "Green",
                "Yellow",
                "Purple",
                "Orange",
                "Cyan",
                "Magenta",
                "White",
                "Black",
            ],
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(label="Final Dress Image", type="pil"),
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()