import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  # Import U²-Net

# Load U²-Net model
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}  # Remove 'module.' prefix
model.load_state_dict(state_dict)
model.eval()
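# Inference runs on CPU here; if a GPU is available, model.to("cuda") plus moving
# the input tensor to the same device would speed up segmentation considerably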

def detect_design(image_np):
    """Detects if a design exists on the dress using edge detection & clustering."""
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 50, 150)

    # Dilation to highlight patterns
    kernel = np.ones((3, 3), np.uint8)
    edges = cv2.dilate(edges, kernel, iterations=1)

    # Count edge density
    design_ratio = np.sum(edges > 0) / (image_np.shape[0] * image_np.shape[1])

    return design_ratio > 0.02, edges  # If edge density is high, assume a design is present

def segment_dress(image_np):
    """Segment the dress using U²-Net & refine with Lab color space."""
    
    # Convert to Lab space for perceptually meaningful clustering
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)

    # Use K-means clustering to detect dominant dress region
    pixel_values = img_lab.reshape((-1, 3)).astype(np.float32)
    k = 3  # Three clusters: background, skin, dress
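    # termination criteria: stop after 10 iterations or once centres move by less than 1.0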
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(pixel_values, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    labels = labels.reshape(image_np.shape[:2])
    
    # k-means label order is arbitrary; treat the most common cluster along the
    # image border as background, then take the largest remaining cluster as the dress
    border = np.concatenate([labels[0], labels[-1], labels[:, 0], labels[:, -1]])
    unique_labels, counts = np.unique(labels, return_counts=True)
    counts[unique_labels == np.bincount(border).argmax()] = 0
    dress_label = unique_labels[np.argmax(counts)]
    
    # Create dress mask
    mask = (labels == dress_label).astype(np.uint8) * 255

    # Use U²-Net prediction to refine segmentation
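    # (320x320 matches the resolution U²-Net is typically trained and run at)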
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])
    
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)

    with torch.no_grad():
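        # U²-Net returns several side outputs; in the standard implementation,
        # index 0 is the fused (most refined) prediction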
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    u2net_mask = (output > 0.5).astype(np.uint8) * 255
    u2net_mask = cv2.resize(u2net_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
    
    # Combine K-means and U²-Net masks
    refined_mask = cv2.bitwise_and(mask, u2net_mask)

    # Morphological operations for smoothness
    kernel = np.ones((5, 5), np.uint8)
    refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel)
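    # feather the mask edge so the recolouring blends softly into surrounding pixels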
    refined_mask = cv2.GaussianBlur(refined_mask, (15, 15), 5)
    
    return refined_mask

def recolor_dress(image_np, mask, target_color, edges):
    """Change dress color while preserving texture, shadows, and designs."""
    
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0].astype(np.float32)
    
    # Exclude design from recoloring
    design_mask = (edges > 0).astype(np.uint8) * 255
    mask = cv2.bitwise_and(mask, cv2.bitwise_not(design_mask))

    # Preserve lightness (L) and change only chromatic channels (A & B)
    blend_factor = 0.7
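    # blending 70% toward the target chroma keeps some of the garment's original
    # colour variation, which reads more naturally than a flat replacement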
    img_lab[..., 1] = np.where(mask > 128, img_lab[..., 1] * (1 - blend_factor) + target_color_lab[1] * blend_factor, img_lab[..., 1])
    img_lab[..., 2] = np.where(mask > 128, img_lab[..., 2] * (1 - blend_factor) + target_color_lab[2] * blend_factor, img_lab[..., 2])

    img_recolored = cv2.cvtColor(np.clip(img_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2RGB)
    return img_recolored

def change_dress_color(image_path, color):
    """Change the dress color naturally while keeping designs intact."""
    if image_path is None:
        return None

    img = Image.open(image_path).convert("RGB")
    img_np = np.array(img)

    # Detect if a design is present
    design_present, edges = detect_design(img_np)

    # Get dress segmentation mask
    mask = segment_dress(img_np)
    
    if mask is None or not np.any(mask):
        return img  # No dress region detected
    
    # Convert the selected color to BGR
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red
    
    # Apply recoloring logic
    if design_present:
        print("Design detected! Coloring only non-design areas.")
        img_recolored = recolor_dress(img_np, mask, new_color_bgr, edges)
    else:
        print("No design detected. Coloring entire dress.")
        img_recolored = recolor_dress(img_np, mask, new_color_bgr, np.zeros_like(mask))  # No design mask

    return Image.fromarray(img_recolored)

# Gradio Interface
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"], label="Choose New Dress Color")
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally while preserving designs."
)

if __name__ == "__main__":
    demo.launch()
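    # demo.launch(share=True) would additionally create a temporary public link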