File size: 3,238 Bytes
984b1c3
 
 
 
 
 
ac99583
984b1c3
 
fa3d84c
984b1c3
1c2f991
 
 
fa3d84c
1c2f991
984b1c3
 
 
9582f0e
984b1c3
 
 
 
 
 
 
 
 
 
9582f0e
 
 
 
 
fa3d84c
 
9582f0e
 
fa3d84c
9582f0e
984b1c3
 
 
65014ef
984b1c3
 
 
 
 
 
9582f0e
984b1c3
 
 
fa3d84c
 
 
 
984b1c3
fa3d84c
 
 
 
 
984b1c3
fa3d84c
 
 
 
af8c4a2
c8390c9
fa3d84c
af8c4a2
fa3d84c
 
 
65014ef
984b1c3
 
af8c4a2
984b1c3
 
 
 
 
 
 
 
65014ef
984b1c3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  # Import U²-Net

# Load U²-Net model (3 input channels, 1 output channel) on the CPU.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)

# Load the state dictionary.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; on torch versions that support it, weights_only=True
# restricts loading to tensors.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Checkpoints saved from an nn.DataParallel-wrapped model typically prefix
# every key with "module.". Strip only that *leading* prefix; str.replace would
# also mangle any key that merely contains "module." further in.
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # inference mode for the segmentation calls below

def segment_dress(image_np):
    """Predict a soft dress mask for an image array using the U²-Net ``model``.

    The image is resized to 320x320 and passed through the module-level
    ``model``. The first output is thresholded at 0.5, scaled back to the
    input resolution, and cleaned up with morphology and a Gaussian blur.

    NOTE(review): the 0.5 threshold assumes the network's first output is
    already sigmoid-activated; confirm against the U2NET definition.

    Args:
        image_np: Image array accepted by ``PIL.Image.fromarray``; its first
            two dimensions are taken as (height, width).

    Returns:
        A ``uint8`` mask of shape (height, width). Values lie in [0, 255]; the
        final blur softens the edges, so it is not strictly binary.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320)),
    ])
    pil_image = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(pil_image).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        scores = model(batch)[0][0].squeeze().cpu().numpy()

    # Hard 0/255 mask at the network's working resolution.
    hard_mask = np.where(scores > 0.5, 255, 0).astype(np.uint8)

    # Nearest-neighbour keeps the mask two-valued while scaling it back up.
    height, width = image_np.shape[:2]
    hard_mask = cv2.resize(hard_mask, (width, height), interpolation=cv2.INTER_NEAREST)

    # Close pinholes, grow the region outward, then soften the boundary.
    structuring = np.ones((5, 5), dtype=np.uint8)
    refined = cv2.morphologyEx(hard_mask, cv2.MORPH_CLOSE, structuring)
    refined = cv2.dilate(refined, structuring, iterations=2)
    return cv2.GaussianBlur(refined, (5, 5), 0)

def change_dress_color(image_path, color):
    """Recolor the dress in an image by replacing its hue.

    Saturation and value (brightness) are left untouched, so fabric texture
    and shading survive; only the hue of confidently-segmented dress pixels
    changes. The recolored image is then blended back onto the original with
    OpenCV's mixed-gradient seamless cloning.

    Args:
        image_path: Path to the input image, or ``None`` when nothing was
            uploaded.
        color: One of ``"Red"``, ``"Blue"``, ``"Green"``, ``"Yellow"`` or
            ``"Purple"``. Any other value (including ``None``) falls back to
            the red hue.

    Returns:
        ``None`` if ``image_path`` is ``None``; the unmodified RGB
        ``PIL.Image.Image`` if no dress pixels are found; otherwise a new
        ``PIL.Image.Image`` with the recolored dress.
    """
    if image_path is None:
        return None

    img = Image.open(image_path).convert("RGB")
    img_np = np.array(img)
    mask = segment_dress(img_np)

    if mask is None:
        return img  # No dress detected

    # Mask handed to seamlessClone. OpenCV 4.x ignores the outermost pixel
    # ring of the clone mask, so clear it here as well; the bounding box
    # measured below then matches exactly the region OpenCV clones.
    clone_mask = mask.copy()
    clone_mask[[0, -1], :] = 0
    clone_mask[:, [0, -1]] = 0
    if not clone_mask.any():
        return img  # Segmentation found no dress pixels: nothing to recolor.

    # Convert image to HSV for color modification
    img_hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)

    # Target hues on OpenCV's 8-bit hue scale, which spans 0-179
    # (degrees / 2): e.g. blue at 240 degrees is stored as 120.
    color_map = {
        "Red": 0,
        "Blue": 120,
        "Green": 60,
        "Yellow": 30,
        "Purple": 150,
    }
    new_hue = color_map.get(color, 0)

    # Replace only the Hue channel inside the confident (>128) part of the
    # soft mask; the blurred fringe keeps its original hue.
    img_hsv[..., 0] = np.where(mask > 128, new_hue, img_hsv[..., 0])

    # Convert back to RGB
    img_recolored = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)

    # Poisson (mixed-gradient) blending. seamlessClone pastes the masked
    # region of the source centred on `center`. Because source and
    # destination are the same image geometry, `center` must be the centre of
    # the mask's own bounding box; using the image centre would shift the
    # recolored dress whenever it is not centred in the frame.
    x, y, w, h = cv2.boundingRect(clone_mask)
    center = (x + w // 2, y + h // 2)
    img_recolored = cv2.seamlessClone(img_recolored, img_np, clone_mask, center, cv2.MIXED_CLONE)

    return Image.fromarray(img_recolored)

# Gradio Interface: an image upload plus a color choice, fed to
# change_dress_color; the recolored result is displayed as a PIL image.
demo = gr.Interface(
    fn=change_dress_color,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        # Pre-select "Red". With no selection the radio would send None, which
        # change_dress_color maps to the red hue anyway, so show that choice
        # explicitly instead of applying it silently.
        gr.Radio(
            ["Red", "Blue", "Green", "Yellow", "Purple"],
            value="Red",
            label="Choose New Dress Color",
        ),
    ],
    outputs=gr.Image(type="pil", label="Color Changed Dress"),
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally."
)

if __name__ == "__main__":
    demo.launch()