File size: 4,479 Bytes
984b1c3
 
 
 
 
 
d00e30a
984b1c3
d00e30a
d3f9ca8
984b1c3
1c2f991
d00e30a
1c2f991
984b1c3
 
d00e30a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
984b1c3
d00e30a
984b1c3
 
 
 
d3f9ca8
984b1c3
 
d3f9ca8
984b1c3
 
d00e30a
 
 
 
 
 
 
 
9bbf95a
d00e30a
d3f9ca8
d00e30a
 
bb4e136
d3f9ca8
 
d00e30a
 
 
 
 
9bbf95a
d00e30a
 
d3f9ca8
 
 
95d0b08
490bf43
d00e30a
984b1c3
 
 
 
 
9cebca9
d3f9ca8
d00e30a
d3f9ca8
d00e30a
490bf43
d00e30a
 
 
 
490bf43
984b1c3
d3f9ca8
 
 
984b1c3
e49358d
 
490bf43
d00e30a
95d0b08
1f7ad23
984b1c3
af8c4a2
984b1c3
490bf43
984b1c3
 
d3f9ca8
984b1c3
490bf43
 
 
984b1c3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  # Import U²-Net

# Load U²-Net model (module-level: runs once at import time, on CPU).
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
# NOTE(review): torch.load unpickles the checkpoint file, so only load trusted
# checkpoints. If this file is a plain state_dict, weights_only=True (torch>=1.13)
# would be safer; confirm before switching.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Strip only a *leading* 'module.' prefix (presumably added when the checkpoint
# was saved from a DataParallel-wrapped model — TODO confirm). str.replace would
# also remove 'module.' from the middle of a key and corrupt it.
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # inference mode for layers such as dropout/batch-norm

def detect_design(image_np):
    """Return a binary mask of high-contrast detail (the "design") in an RGB image.

    The mask is computed over the whole image, not only the dress. It is the
    union of an inverted Gaussian adaptive threshold and Canny edges on the
    grayscale image, followed by a 3x3 morphological close.

    Args:
        image_np: RGB image array (converted with ``cv2.COLOR_RGB2GRAY``).

    Returns:
        Single-channel uint8 mask whose non-zero pixels mark detected design.
    """
    grayscale = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)

    # THRESH_BINARY_INV: pixels not brighter than their 11x11 Gaussian-weighted
    # neighbourhood mean minus 2 become 255, i.e. locally dark strokes light up.
    local_dark = cv2.adaptiveThreshold(
        grayscale,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11,
        2,
    )
    edge_map = cv2.Canny(grayscale, 50, 150)

    # Closing bridges small gaps between neighbouring detected strokes.
    closing_kernel = np.ones((3, 3), dtype=np.uint8)
    return cv2.morphologyEx(
        cv2.bitwise_or(local_dark, edge_map), cv2.MORPH_CLOSE, closing_kernel
    )

def segment_dress(image_np):
    """Predict a binary dress mask for ``image_np`` with the global U²-Net ``model``.

    Steps:
    1. Convert the image to an RGB float tensor and resize it to 320x320.
    2. Run the model.
    3. Threshold the first output map at 0.5.
    4. Scale the mask back to the input resolution with nearest-neighbour
       sampling.
    5. Smooth it with a 5x5 morphological close.

    Args:
        image_np: image array accepted by ``PIL.Image.fromarray``.

    Returns:
        uint8 mask with values 0 or 255, sized to ``image_np``'s height and width.
    """
    # NOTE(review): no mean/std normalisation is applied before inference;
    # confirm this matches how the checkpoint was trained.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])
    rgb_image = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        predictions = model(batch)
        # assumes predictions[0][0] squeezes to a 2-D probability map
        # (sigmoid applied inside U2NET) — TODO confirm
        probability_map = predictions[0][0].squeeze().cpu().numpy()

    height, width = image_np.shape[:2]
    binary = np.where(probability_map > 0.5, 255, 0).astype(np.uint8)
    full_size = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)

    # Closing fills small holes so the mask edge is smoother.
    return cv2.morphologyEx(full_size, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))

def recolor_dress(image_np, dress_mask, design_mask, target_color, blend_factor=0.8):
    """Shift dress pixels toward ``target_color`` while leaving design pixels untouched.

    Works in OpenCV's 8-bit LAB space:
    - Only the a/b (chroma) channels are blended toward the target colour.
    - The L (lightness) channel is kept, so shading survives.
    - A pixel is recoloured only if it is set in ``dress_mask`` and not set in
      ``design_mask``.

    Args:
        image_np: RGB uint8 image.
        dress_mask: uint8 mask with the same height/width; 255 marks dress pixels.
        design_mask: uint8 mask with the same height/width; 255 marks design
            pixels to protect.
        target_color: BGR triple (anything ``np.uint8([[target_color]])`` accepts).
        blend_factor: weight of the target chroma, in [0, 1]. Defaults to 0.8.

    Returns:
        Recoloured RGB uint8 image with the same shape as ``image_np``.

    Raises:
        ValueError: if ``blend_factor`` is outside [0, 1].

    Note:
        Because lightness is preserved, neutral targets such as white or black
        (a = b = 128 in 8-bit LAB) only desaturate the dress. They do not make
        it lighter or darker.
    """
    if not 0.0 <= blend_factor <= 1.0:
        raise ValueError(f"blend_factor must be in [0, 1], got {blend_factor!r}")

    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]

    # Recolour only dress pixels that are not part of the detected design.
    recolor_mask = cv2.bitwise_and(dress_mask, cv2.bitwise_not(design_mask))
    selected = recolor_mask > 128

    # Blend a/b in float, then round and clip. Writing the float result straight
    # into the uint8 LAB array would truncate (floor) and bias chroma downward.
    chroma = img_lab[..., 1:3].astype(np.float64)
    target_chroma = target_lab[1:3].astype(np.float64)
    blended = chroma * (1.0 - blend_factor) + target_chroma * blend_factor
    blended = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    img_lab[..., 1:3] = np.where(selected[..., None], blended, img_lab[..., 1:3])

    return cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)

def change_dress_color(image_path, color):
    """Recolour the dress in the image at ``image_path`` while keeping designs intact.

    Args:
        image_path: path to the uploaded image, or None when nothing was uploaded.
        color: one of the colour names in ``color_map`` below. Unknown values
            (including None) fall back to red.

    Returns:
        A PIL RGB image: the recoloured result, or the unmodified input when the
        segmentation mask is empty. Returns None when ``image_path`` is None.
    """
    if image_path is None:
        return None

    # The context manager closes the underlying file. convert() returns an
    # independent in-memory copy, so ``img`` stays valid after the block.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img_np = np.array(img)

    # Get dress segmentation mask
    dress_mask = segment_dress(img_np)

    # segment_dress returns an array, so a plain None check never fires.
    # Treat an all-zero mask as "no dress detected". Return the input untouched
    # rather than pushing it through a no-op RGB->LAB->RGB round trip that only
    # adds rounding error.
    if dress_mask is None or not np.any(dress_mask):
        return img

    # Detect design on the dress
    design_mask = detect_design(img_np)

    # Colour names -> BGR triples (recolor_dress converts them with COLOR_BGR2LAB).
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red

    # Apply recoloring logic
    img_recolored = recolor_dress(img_np, dress_mask, design_mask, new_color_bgr)

    return Image.fromarray(img_recolored)

# Gradio Interface
# The radio defaults to "Red". Without a default, an unselected submission sends
# None, which change_dress_color silently maps to red anyway; this makes that
# default visible to the user.
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(
            ["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"],
            value="Red",
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally while preserving designs."
)

if __name__ == "__main__":
    demo.launch()