File size: 4,233 Bytes
984b1c3
 
 
 
 
 
ac99583
984b1c3
 
e49358d
984b1c3
1c2f991
 
 
e49358d
1c2f991
984b1c3
 
 
9582f0e
984b1c3
 
 
 
 
 
 
 
 
 
9582f0e
 
 
 
 
1f7ad23
95d0b08
9582f0e
95d0b08
9582f0e
984b1c3
 
95d0b08
 
 
 
 
 
984b1c3
95d0b08
984b1c3
 
 
 
 
 
1f7ad23
984b1c3
 
 
e49358d
984b1c3
1f7ad23
 
95d0b08
984b1c3
e49358d
 
 
 
 
 
95d0b08
 
 
 
 
1f7ad23
 
 
65014ef
1f7ad23
 
 
95d0b08
 
 
1f7ad23
984b1c3
af8c4a2
984b1c3
 
 
 
95d0b08
984b1c3
 
 
65014ef
984b1c3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  # Import U²-Net

# Load the U²-Net segmentation network once at import time; segment_dress()
# reads this module-level `model`. Weights are mapped onto the CPU.
model_path = "cloth_segmentation/networks/u2net.pth"  # Ensure this path is correct
model = U2NET(3, 1)

# Checkpoints saved from a wrapped model (presumably nn.DataParallel) prefix
# every key with "module.". Strip only that *leading* prefix: str.replace would
# also rewrite any key that merely contains "module." somewhere in the middle.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
_DP_PREFIX = "module."
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
state_dict = {
    (key[len(_DP_PREFIX):] if key.startswith(_DP_PREFIX) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # evaluation mode for inference

def segment_dress(image_np):
    """Predict a smoothed 8-bit dress mask for an image with the U²-Net model.

    The image is resized to 320x320 for the network. The first element of the
    model's output is thresholded at 0.5, scaled back to the input's height and
    width, morphologically closed with a 5x5 kernel and Gaussian-blurred
    (21x21, sigma 10) so its edges are soft.

    Args:
        image_np: Image as a NumPy array accepted by ``PIL.Image.fromarray``.

    Returns:
        uint8 mask with the same height/width as ``image_np`` and values in
        0..255 (255 = confidently dress, intermediate values along blurred
        edges). Assumes the model output squeezes to a 2-D map -- TODO confirm
        against the checkpoint.
    """
    rgb_image = Image.fromarray(image_np).convert("RGB")
    preprocess = transforms.Compose([transforms.ToTensor(), transforms.Resize((320, 320))])
    batch = preprocess(rgb_image).unsqueeze(0)

    with torch.no_grad():
        prediction = model(batch)[0][0]
        probability_map = prediction.squeeze().cpu().numpy()

    binary = np.where(probability_map > 0.5, 255, 0).astype(np.uint8)

    # Back to the caller's resolution; nearest-neighbour keeps the mask binary.
    height, width = image_np.shape[:2]
    binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)

    # Fill small holes, then feather the outline.
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
    return cv2.GaussianBlur(closed, (21, 21), 10)

def get_ambient_light(img_np, mask=None):
    """Estimate scene lighting as the median LAB lightness of an RGB image.

    Args:
        img_np: RGB image array of shape (H, W, 3) accepted by
            ``cv2.cvtColor(..., cv2.COLOR_RGB2LAB)``.
        mask: Optional dress mask with the same height/width as ``img_np``.
            When given, pixels with ``mask > 128`` (the dress) are excluded,
            so the estimate comes from the non-dress surroundings only. If that
            leaves no pixels, the whole image is used instead. When omitted,
            the median is taken over the whole image (the original behaviour).

    Returns:
        The median of the L (lightness) channel of the LAB conversion.
    """
    lightness = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)[:, :, 0]
    if mask is not None:
        surroundings = lightness[np.asarray(mask) <= 128]
        if surroundings.size:
            return np.median(surroundings)
    return np.median(lightness)

# Colours offered in the UI, as BGR triplets (the order cv2.COLOR_BGR2LAB expects).
_DRESS_COLORS_BGR = {
    "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
    "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
    "White": (255, 255, 255), "Black": (0, 0, 0),
}
_DEFAULT_COLOR_BGR = (0, 0, 255)  # Red: used for unknown or missing colour names
_DRESS_THRESHOLD = 128  # mask values above this are treated as dress pixels


def _clone_region(mask):
    """Prepare ``mask`` for cv2.seamlessClone and find the paste point that keeps it in place.

    seamlessClone takes the bounding box of the mask's non-zero pixels from the
    source and centres it on the given point in the destination. Passing the
    centre of that same box therefore pastes the patch back where it came from.
    The outermost 1-pixel ring is cleared first, mirroring what recent OpenCV
    versions do internally, so the box computed here matches OpenCV's.

    Returns:
        ``(clone_mask, (cx, cy))``, or ``None`` if no non-zero pixel remains.
    """
    clone_mask = mask.copy()
    clone_mask[[0, -1], :] = 0
    clone_mask[:, [0, -1]] = 0
    x, y, w, h = cv2.boundingRect(clone_mask)
    if w == 0 or h == 0:
        return None
    return clone_mask, (x + w // 2, y + h // 2)


def change_dress_color(image_path, color):
    """Recolour the segmented dress in an image while keeping its texture.

    Pipeline:
    1. Segment the dress with ``segment_dress``.
    2. Convert the image to LAB and scale the L channel by (median L / 128).
    3. Move the A/B channels of dress pixels 60% toward the target colour.
    4. Convert back to RGB.
    5. Poisson-blend (``cv2.MIXED_CLONE``) the result into the original.

    Args:
        image_path: Path to the uploaded image file, or ``None``.
        color: A key of ``_DRESS_COLORS_BGR``. Any other value, including
            ``None``, falls back to red.

    Returns:
        A PIL RGB image with the recoloured dress. Returns ``None`` if
        ``image_path`` is ``None``. Returns the unmodified input image when the
        mask has no dress pixels (nothing to recolour) or no usable clone region.
    """
    if image_path is None:
        return None

    # Decode inside `with` so the file handle is closed; convert() returns an
    # independent, fully loaded image.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    img_np = np.array(img)

    mask = segment_dress(img_np)
    if mask is None:
        return img
    dress = mask > _DRESS_THRESHOLD
    if not dress.any():
        return img  # No dress detected: nothing to recolour, and seamlessClone needs a non-empty mask.

    new_color_bgr = np.array(_DRESS_COLORS_BGR.get(color, _DEFAULT_COLOR_BGR), dtype=np.uint8)

    # Work in LAB so colour (A/B) can be changed independently of lightness (L).
    img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    new_color_lab = cv2.cvtColor(np.uint8([[new_color_bgr]]), cv2.COLOR_BGR2LAB)[0][0]

    # NOTE(review): this multiplies L by (median / 128) across the whole image.
    # Images with a median below 128 get darker and brighter ones get lighter,
    # so it is not a normalisation toward 128.
    ambient_light = get_ambient_light(img_np)
    img_lab[..., 0] = np.clip(img_lab[..., 0] * (ambient_light / 128), 0, 255)

    # Only A and B change on dress pixels; L is kept so shading and texture survive.
    blend_factor = 0.6  # Controls intensity of color change
    for channel in (1, 2):
        img_lab[..., channel] = np.where(
            dress,
            img_lab[..., channel] * (1 - blend_factor) + new_color_lab[channel] * blend_factor,
            img_lab[..., channel],
        )

    img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)

    # Paste at the mask's own bounding-box centre (not the image centre), so an
    # off-centre dress is not shifted and its ROI stays inside the image.
    region = _clone_region(mask)
    if region is None:
        return img
    clone_mask, center = region
    blended = cv2.seamlessClone(img_recolored, img_np, clone_mask, center, cv2.MIXED_CLONE)

    return Image.fromarray(blended)

# Gradio UI: upload an image, pick one of the fixed colours, and get back the
# recoloured image produced by change_dress_color().
_COLOR_CHOICES = ["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"]

_image_input = gr.Image(type="filepath", label="Upload Dress Image")
_color_input = gr.Radio(choices=_COLOR_CHOICES, label="Choose New Dress Color")
_image_output = gr.Image(type="pil", label="Color Changed Dress")

demo = gr.Interface(
    fn=change_dress_color,
    inputs=[_image_input, _color_input],
    outputs=_image_output,
    title="Dress Color Changer",
    description="Upload an image of a dress and select a new color to change its appearance naturally.",
)

if __name__ == "__main__":
    demo.launch()