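# Gradio demo: segment the dress in an uploaded photo with U²-Net, then recolor
# it in LAB space so shading and fabric texture are preserved.
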
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  # Import U²-Net

# Load U²-Net model
model_path = "cloth_segmentation/networks/u2net.pth"  # Ensure this path is correct
model = U2NET(3, 1)

# Load the state dictionary
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}  # Strip the 'module.' prefix added by nn.DataParallel
model.load_state_dict(state_dict)
model.eval()
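# NOTE: the checkpoint must match the single-channel U2NET(3, 1) head built above;
# weights trained for a multi-class cloth-segmentation variant (e.g. U2NET(3, 4))
# would fail to load with a shape mismatch.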

def segment_dress(image_np):
    """Segment the dress from the image using U²-Net and refine the mask."""
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))  # U²-Net's standard 320x320 input resolution
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)
    
    with torch.no_grad():
        # Take d0, the fused full-resolution side output, and drop the batch dim
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    # Min-max normalize to [0, 1] (as in the official U²-Net normPRED helper) so the
    # 0.5 threshold works whether the model returns logits or probabilities
    output = (output - output.min()) / (output.max() - output.min() + 1e-8)
    mask = (output > 0.5).astype(np.uint8) * 255  # Binary mask (255 = dress)
    
    # Resize mask to original image size
    mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
    
    # Apply morphological operations for better segmentation
    kernel = np.ones((7, 7), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)  # Close small gaps
    mask = cv2.dilate(mask, kernel, iterations=2)  # Expand the detected dress area
    
    return mask
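
# Standalone usage sketch (hypothetical file names, for illustration only):
#   img_np = np.array(Image.open("dress.jpg").convert("RGB"))
#   mask = segment_dress(img_np)   # uint8 array: 255 = dress, 0 = background
#   cv2.imwrite("mask.png", mask)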

def change_dress_color(image_path, color):
    """Change the dress color based on the detected dress mask."""
    if image_path is None:
        return None

    img = Image.open(image_path).convert("RGB")
    img_np = np.array(img)
    mask = segment_dress(img_np)

    if not np.any(mask):  # segment_dress always returns an array, so test for an empty mask
        return img  # No dress detected
    
    # Map the selected color name to a BGR tuple (OpenCV channel order)
    color_map = {
        "Red": (0, 0, 255),
        "Blue": (255, 0, 0),
        "Green": (0, 255, 0),
        "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)  # Default to Red
    
    # Convert image to LAB color space for better blending
    img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    new_color_lab = cv2.cvtColor(np.uint8([[new_color_bgr]]), cv2.COLOR_BGR2LAB)[0][0]
    
    # Replace only the chroma channels inside the mask; leaving the L (lightness)
    # channel untouched preserves shading and fabric texture
    img_lab[..., 1] = np.where(mask == 255, new_color_lab[1], img_lab[..., 1])  # A-channel (green-red)
    img_lab[..., 2] = np.where(mask == 255, new_color_lab[2], img_lab[..., 2])  # B-channel (blue-yellow)

    # Convert back to RGB
    img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)

    return Image.fromarray(img_recolored)
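
# Headless usage sketch (hypothetical file names, bypassing the Gradio UI):
#   recolored = change_dress_color("dress.jpg", "Blue")
#   if recolored is not None:
#       recolored.save("dress_blue.jpg")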

# Gradio Interface
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple"], label="Choose New Dress Color")
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance."
)

if __name__ == "__main__":
    demo.launch()
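    # To expose a temporary public URL (e.g. when running from a notebook),
    # Gradio also supports: demo.launch(share=True)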