import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET

# Load U²-Net model (CPU inference)
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Checkpoints saved via nn.DataParallel prefix every key with 'module.'; strip it
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

def segment_dress(image_np):
    """Segment the dress using U²-Net and GrabCut."""
    # U²-Net expects 320x320 input; torchvision's Resize accepts tensors,
    # so it can follow ToTensor here
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])

    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)

    with torch.no_grad():
        # d0 is the fused side output; the second [0] drops the batch dimension
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    # Binarize at 0.5 (assumes the checkpoint emits probabilities in [0, 1]),
    # then scale the mask back to the original image resolution
    u2net_mask = (output > 0.5).astype(np.uint8) * 255
    u2net_mask = cv2.resize(u2net_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Refine the mask with GrabCut. Seed U²-Net's foreground as *probable*
    # foreground: hard-labeling every pixel as GC_FGD/GC_BGD would leave
    # GrabCut nothing to re-estimate
    mask = np.zeros(image_np.shape[:2], np.uint8)
    mask[u2net_mask > 128] = cv2.GC_PR_FGD
    mask[u2net_mask <= 128] = cv2.GC_BGD
    bg_model = np.zeros((1, 65), np.float64)
    fg_model = np.zeros((1, 65), np.float64)

    cv2.grabCut(image_np, mask, None, bg_model, fg_model, 5, cv2.GC_INIT_WITH_MASK)
    # Collapse GrabCut's four labels into a binary mask (definite/probable FG -> 255)
    mask = np.where((mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD), 0, 255).astype(np.uint8)
    
    return mask
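
# Optional cleanup (a sketch, not part of the original pipeline): GrabCut can
# leave small speckles and holes, so a morphological open/close pass is a
# common post-processing step if the masks come out noisy.
def clean_mask(mask, kernel_size=5):
    """Remove speckles (open) and fill small holes (close) in a binary mask."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    return cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel)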

def recolor_dress(image_np, mask, target_color):
    """Recolor the dress while keeping texture, shadows, and designs."""

    # Convert to LAB; cast to float32 so the blend below doesn't truncate in uint8
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB).astype(np.float32)

    # The target color arrives as a BGR triple (see color_map below), hence BGR2LAB
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0].astype(np.float32)

    # Preserve lightness (L) and blend only the chromatic channels (a and b),
    # so shadows, highlights, and printed designs survive the recoloring
    blend_factor = 0.8
    img_lab[..., 1] = np.where(mask > 128, img_lab[..., 1] * (1 - blend_factor) + target_color_lab[1] * blend_factor, img_lab[..., 1])
    img_lab[..., 2] = np.where(mask > 128, img_lab[..., 2] * (1 - blend_factor) + target_color_lab[2] * blend_factor, img_lab[..., 2])

    # Convert back to RGB (clip and return to uint8 first, since OpenCV's
    # LAB2RGB interprets float input on a different scale)
    img_recolored = cv2.cvtColor(np.clip(img_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2RGB)
    return img_recolored
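
# Worked example of the blend above (illustrative values): for a masked pixel
# with a-channel 100 and a target a of 160, blend_factor = 0.8 gives
# 100 * 0.2 + 160 * 0.8 = 148 — the chroma moves most of the way to the target
# while the untouched L channel keeps the original shading.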

def change_dress_color(image_path, color):
    """Change the dress color while preserving texture and design details."""
    if image_path is None:
        return None

    img = Image.open(image_path).convert("RGB")
    img_np = np.array(img)

    # Get dress segmentation mask
    mask = segment_dress(img_np)
    
    if mask is None or not mask.any():
        return img  # No dress detected
    
    # Convert the selected color to BGR
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red
    
    # Apply recoloring logic
    img_recolored = recolor_dress(img_np, mask, new_color_bgr)

    return Image.fromarray(img_recolored)
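
# Programmatic usage sketch (hypothetical paths; bypasses the UI):
#   recolored = change_dress_color("samples/dress.jpg", "Blue")
#   if recolored is not None:
#       recolored.save("samples/dress_blue.jpg")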

# Gradio Interface
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"], label="Choose New Dress Color")
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally while preserving designs."
)

if __name__ == "__main__":
    demo.launch()
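    # demo.launch(share=True) would additionally expose a temporary public URL
    # (a standard Gradio option), handy when testing from another device.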