File size: 3,689 Bytes
984b1c3
 
 
 
 
 
ba17c5b
984b1c3
d00e30a
d3f9ca8
984b1c3
1c2f991
ba17c5b
1c2f991
984b1c3
 
 
ba17c5b
984b1c3
 
 
 
d3f9ca8
984b1c3
 
d3f9ca8
984b1c3
 
d00e30a
 
 
 
 
d3f9ca8
ba17c5b
 
bb4e136
ba17c5b
d3f9ca8
d00e30a
ba17c5b
 
 
 
 
 
d00e30a
ba17c5b
 
 
d3f9ca8
ba17c5b
d3f9ca8
ba17c5b
 
 
 
d3f9ca8
95d0b08
490bf43
ba17c5b
984b1c3
 
 
 
 
9cebca9
d3f9ca8
d00e30a
d3f9ca8
d00e30a
490bf43
d00e30a
490bf43
984b1c3
d3f9ca8
 
 
984b1c3
ba17c5b
 
 
 
95d0b08
1f7ad23
984b1c3
af8c4a2
984b1c3
490bf43
984b1c3
 
d3f9ca8
984b1c3
490bf43
ba17c5b
 
984b1c3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  

# Load the U²-Net segmentation model once at import time; weights are mapped to
# the CPU and the model is switched to inference mode.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Checkpoints saved from a wrapper such as nn.DataParallel prefix every key with
# "module.". Strip only that *leading* prefix; str.replace would also rewrite the
# substring if it happened to appear elsewhere inside a parameter name.
state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

def segment_dress(image_np):
    """Predict a binary dress mask for an image using the global U²-Net ``model``.

    The image is converted to RGB and then to a float tensor, resized to the
    320x320 network input, and run through ``model`` without gradient
    tracking. The first map of the first network output is thresholded at 0.5.
    The result is resized back, with nearest-neighbour interpolation, to the
    input's height and width.

    Args:
        image_np: image array accepted by ``PIL.Image.fromarray``.

    Returns:
        uint8 array of shape (H, W) containing only 0 (background) and 255 (dress).
    """
    height, width = image_np.shape[:2]

    # NOTE(review): no mean/std normalisation is applied before inference;
    # confirm this matches how the checkpoint was trained.
    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Resize((320, 320))])

    pil_image = Image.fromarray(image_np).convert("RGB")
    batch = preprocess(pil_image).unsqueeze(0)

    with torch.no_grad():
        network_outputs = model(batch)
        prob_map = network_outputs[0][0].squeeze().cpu().numpy()

    binary_mask = np.where(prob_map > 0.5, 255, 0).astype(np.uint8)
    return cv2.resize(binary_mask, (width, height), interpolation=cv2.INTER_NEAREST)

def recolor_dress(image_np, dress_mask, target_color, feather=15):
    """Shift the chroma of the masked dress region toward ``target_color``.

    Works in OpenCV's 8-bit LAB space. The A/B (chroma) channels are offset
    so that the mean chroma of the dress pixels matches the target colour.
    The L (lightness) channel is left untouched, so shading, folds and fabric
    texture are preserved. The recoloured image is then alpha-composited over
    the original through a feathered version of the mask, giving smooth edges.

    Args:
        image_np: RGB image, uint8 array of shape (H, W, 3).
        dress_mask: uint8 mask of shape (H, W); pixels > 128 are dress.
        target_color: target colour as a BGR triple of uint8 values.
        feather: Gaussian kernel size in pixels used to soften the mask edge.
            Values <= 1 disable feathering; even sizes are bumped to the next
            odd size.

    Returns:
        Recoloured RGB image (uint8, same shape as ``image_np``). If the mask
        selects no pixels, an unmodified copy of ``image_np`` is returned.
    """
    dress_region = dress_mask > 128
    if not dress_region.any():
        # Nothing to recolour; avoids taking the mean of an empty selection (NaN).
        return image_np.copy()

    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0, 0].astype(np.float32)

    # Do the arithmetic in float so the per-pixel shift cannot overflow uint8.
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB).astype(np.float32)

    dress_pixels = img_lab[dress_region]
    shift_a = target_lab[1] - dress_pixels[:, 1].mean()
    shift_b = target_lab[2] - dress_pixels[:, 2].mean()

    shifted_lab = img_lab.copy()
    shifted_lab[..., 1] += shift_a
    shifted_lab[..., 2] += shift_b
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around and show up as garish colour artefacts.
    shifted_lab = np.clip(shifted_lab, 0, 255).astype(np.uint8)
    recolored_rgb = cv2.cvtColor(shifted_lab, cv2.COLOR_LAB2RGB)

    # Feathered alpha compositing is used instead of cv2.seamlessClone. Poisson
    # cloning keeps the source gradients but takes boundary values from the
    # original image, which largely cancels a uniform chroma shift. It also
    # places the clone by the centre of the mask's bounding box rather than in
    # place.
    alpha = dress_region.astype(np.float32)
    if feather > 1:
        ksize = int(feather) | 1  # GaussianBlur requires an odd kernel size
        alpha = cv2.GaussianBlur(alpha, (ksize, ksize), 0)
    alpha = alpha[..., np.newaxis]

    blended = alpha * recolored_rgb + (1.0 - alpha) * image_np
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

def change_dress_color(image_path, color):
    """Recolour the dress in the image at ``image_path`` to the named ``color``.

    Args:
        image_path: Path to the uploaded image (the Gradio input uses
            ``type="filepath"``), or None when nothing was uploaded.
        color: One of the colour names in ``color_map`` below; unknown names
            fall back to red.

    Returns:
        A PIL RGB image with the dress recoloured. If the segmentation mask is
        empty, the unmodified (RGB-converted) image is returned. If no image
        was given, None is returned.
    """
    if image_path is None:
        return None

    # Close the file handle once the pixels are in memory; convert() returns
    # an independent image, so it stays usable after the file is closed.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img_np = np.array(img)

    # Get dress segmentation mask
    dress_mask = segment_dress(img_np)

    # segment_dress always returns an array, so "no dress" shows up as an
    # all-zero mask rather than None; check both.
    if dress_mask is None or not np.any(dress_mask):
        return img  # No dress detected

    # Colour names -> BGR triples (recolor_dress converts with COLOR_BGR2LAB)
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
        "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
        "White": (255, 255, 255), "Black": (0, 0, 0)
    }
    new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8)

    # Apply recoloring with blending
    img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)

    return Image.fromarray(img_recolored)

# Gradio UI: an image upload (passed to the handler as a file path) and a
# colour radio group, wired to change_dress_color; the result is shown as a
# PIL image.
demo = gr.Interface(
    change_dress_color,
    [
        gr.Image(label="Upload Dress Image", type="filepath"),
        gr.Radio(
            choices=["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"],
            label="Choose New Dress Color",
        ),
    ],
    gr.Image(label="Color Changed Dress", type="pil"),
    title="Realistic Dress Color Changer",
    description="Upload an image of a dress and select a new color. The AI will change the dress color naturally while keeping the fabric texture.",
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()