Spaces:
Running
Running
File size: 5,599 Bytes
984b1c3 ac99583 984b1c3 d3f9ca8 984b1c3 1c2f991 e49358d 1c2f991 984b1c3 d3f9ca8 984b1c3 d3f9ca8 984b1c3 d3f9ca8 984b1c3 d3f9ca8 984b1c3 d3f9ca8 984b1c3 d3f9ca8 bb4e136 d3f9ca8 bb4e136 d3f9ca8 f6a6474 d3f9ca8 95d0b08 984b1c3 d3f9ca8 984b1c3 9cebca9 d3f9ca8 bb4e136 984b1c3 e49358d 984b1c3 d3f9ca8 984b1c3 e49358d d3f9ca8 95d0b08 1f7ad23 984b1c3 af8c4a2 984b1c3 d3f9ca8 984b1c3 d3f9ca8 984b1c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
# Load the U²-Net model once at import time; segment_dress() reuses this
# module-level instance for every request.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
# NOTE(review): torch.load unpickles the checkpoint file, so only point
# model_path at trusted weights.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Strip a leading 'module.' from each key (presumably the checkpoint was saved
# from a DataParallel-wrapped model — confirm). Only the prefix is removed:
# str.replace would also rewrite a 'module.' occurring inside a key name.
state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # inference mode for all subsequent forward passes
def detect_design(image_np, low_threshold=50, high_threshold=150, density_threshold=0.02):
    """Heuristically decide whether the image contains a printed design/pattern.

    Runs Canny edge detection on the grayscale image, thickens the edges with
    a 3x3 dilation, and compares the fraction of edge pixels with
    ``density_threshold``.

    Args:
        image_np: RGB image array (the caller passes ``np.array`` of a PIL
            RGB image).
        low_threshold: lower Canny hysteresis threshold.
        high_threshold: upper Canny hysteresis threshold.
        density_threshold: edge-pixel fraction above which a design is
            assumed to be present.

    Returns:
        Tuple ``(design_present, edges)``: a boolean flag and the dilated
        uint8 edge map, which has the same height/width as ``image_np``.
    """
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    # Dilate so thin pattern lines become a contiguous region that the
    # recoloring step can exclude.
    kernel = np.ones((3, 3), np.uint8)
    edges = cv2.dilate(edges, kernel, iterations=1)
    # Fraction of pixels that lie on a (dilated) edge.
    design_ratio = np.count_nonzero(edges) / (image_np.shape[0] * image_np.shape[1])
    return design_ratio > density_threshold, edges
def _kmeans_lab_labels(image_np, k=3):
    """Cluster the image's pixels in Lab space with K-means; return an (H, W) label map."""
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    pixel_values = img_lab.reshape((-1, 3)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, _ = cv2.kmeans(pixel_values, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    return labels.reshape(image_np.shape[:2])


def _u2net_foreground(image_np, threshold=0.5):
    """Run the module-level U²-Net and return a 0/255 uint8 mask at the input's size."""
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)
    with torch.no_grad():
        # NOTE(review): assumes the first output is a per-pixel score where
        # > threshold means foreground — confirm against the U2NET definition.
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()
    mask = (output > threshold).astype(np.uint8) * 255
    # Nearest-neighbour keeps the resized mask strictly binary.
    return cv2.resize(mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)


def segment_dress(image_np):
    """Segment the dress by combining Lab-space K-means clusters with U²-Net.

    The K-means cluster chosen as "dress" is the one with the most pixels
    inside the U²-Net foreground. K-means label numbers are arbitrary, so a
    fixed label cannot be assumed to be background. The chosen cluster is
    intersected with the U²-Net mask, closed with a 5x5 kernel and Gaussian
    blurred.

    Args:
        image_np: RGB uint8 image array of shape (H, W, 3).

    Returns:
        uint8 array of shape (H, W). The blur gives soft edges, so values lie
        in 0..255 rather than being strictly binary. The mask is all zeros
        when U²-Net finds no foreground.
    """
    k = 3  # intended clusters: background, skin, dress
    labels = _kmeans_lab_labels(image_np, k)
    u2net_mask = _u2net_foreground(image_np)
    # Count, per cluster, how many of its pixels U²-Net marks as foreground.
    # minlength keeps this well-defined even if K-means returns fewer clusters
    # (e.g. a solid-colour image), which used to crash np.argmax.
    overlap = np.bincount(labels[u2net_mask > 0].ravel(), minlength=k)
    dress_label = int(np.argmax(overlap))
    mask = (labels == dress_label).astype(np.uint8) * 255
    # Keep only pixels that both methods agree belong to the dress.
    refined_mask = cv2.bitwise_and(mask, u2net_mask)
    # Morphological closing fills small holes; the blur softens the boundary.
    kernel = np.ones((5, 5), np.uint8)
    refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel)
    refined_mask = cv2.GaussianBlur(refined_mask, (15, 15), 5)
    return refined_mask
def recolor_dress(image_np, mask, target_color, edges, blend_factor=0.7):
    """Shift the dress's colour toward ``target_color`` while keeping lightness.

    Only the Lab chroma channels (A and B) are blended. The L channel is left
    untouched, so shading and texture carry over. Pixels covered by ``edges``
    are excluded so that designs keep their original colours.

    Args:
        image_np: RGB uint8 image array of shape (H, W, 3).
        mask: uint8 (H, W) dress mask; pixels with value > 128 are recoloured.
        target_color: length-3 uint8 colour in BGR order (converted with
            COLOR_BGR2LAB below).
        edges: (H, W) array; any non-zero pixel is protected from recolouring.
        blend_factor: weight of the target chroma (0 keeps the original,
            1 fully replaces A/B).

    Returns:
        The recoloured RGB uint8 image array.
    """
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0].astype(np.float64)
    # Exclude design pixels from the recolouring mask.
    design_mask = (edges > 0).astype(np.uint8) * 255
    mask = cv2.bitwise_and(mask, cv2.bitwise_not(design_mask))
    selected = mask > 128
    for channel in (1, 2):  # A and B; channel 0 (L, lightness) is preserved
        original = img_lab[..., channel]
        blended = original * (1 - blend_factor) + target_color_lab[channel] * blend_factor
        # Round (rather than truncate) back to uint8, and clip in case
        # blend_factor lies outside [0, 1].
        img_lab[..., channel] = np.where(selected, np.clip(np.rint(blended), 0, 255), original)
    return cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
def change_dress_color(image_path, color):
    """Recolour the dress in an image file, keeping detected designs intact.

    Args:
        image_path: path to the uploaded image, or None if nothing was uploaded.
        color: one of the colour names in ``color_map`` below. Unknown values
            (including None) fall back to red.

    Returns:
        A PIL RGB image, or None when ``image_path`` is None. If segmentation
        yields no pixels that would be recoloured, the unmodified image is
        returned.
    """
    if image_path is None:
        return None
    # The context manager closes the file handle. convert() returns a new
    # in-memory image, which stays valid after the file is closed.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_np = np.array(img)
    # Detect if a design is present
    design_present, edges = detect_design(img_np)
    # Get dress segmentation mask
    mask = segment_dress(img_np)
    # recolor_dress only touches pixels where mask > 128. If there are none,
    # there is nothing to recolour.
    if mask is None or not np.any(mask > 128):
        return img  # No dress detected
    # Colour names mapped to BGR tuples (recolor_dress expects BGR order).
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0),
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red
    # With a design, protect its edge pixels; otherwise pass an all-zero map
    # so the whole dress is recoloured.
    if design_present:
        print("Design detected! Coloring only non-design areas.")
        design_edges = edges
    else:
        print("No design detected. Coloring entire dress.")
        design_edges = np.zeros_like(mask)
    img_recolored = recolor_dress(img_np, mask, new_color_bgr, design_edges)
    return Image.fromarray(img_recolored)
# Gradio Interface: the uploaded file path and the chosen colour name are
# passed to change_dress_color, which returns a PIL image for display.
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(
            ["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"],
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally while preserving designs.",
)

if __name__ == "__main__":
    demo.launch()