Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,122 +1,96 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
-
import torch
|
4 |
import cv2
|
5 |
from PIL import Image
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
# Load U²-Net model
|
10 |
-
model_path = "cloth_segmentation/networks/u2net.pth"
|
11 |
model = U2NET(3, 1)
|
12 |
-
|
13 |
-
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} # Remove 'module.' prefix
|
14 |
-
model.load_state_dict(state_dict)
|
15 |
model.eval()
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
# Use K-means clustering to detect dominant dress region
|
25 |
-
pixel_values = img_lab.reshape((-1, 3)).astype(np.float32)
|
26 |
-
k = 3 # Three clusters: background, skin, dress
|
27 |
-
_, labels, centers = cv2.kmeans(pixel_values, k, None, (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0), 10, cv2.KMEANS_RANDOM_CENTERS)
|
28 |
-
labels = labels.reshape(image_np.shape[:2])
|
29 |
-
|
30 |
-
# Assume dress is the largest non-background cluster
|
31 |
-
unique_labels, counts = np.unique(labels, return_counts=True)
|
32 |
-
dress_label = unique_labels[np.argmax(counts[1:]) + 1] # Avoid background
|
33 |
-
|
34 |
-
# Create dress mask
|
35 |
-
mask = (labels == dress_label).astype(np.uint8) * 255
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
]
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
45 |
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
refined_mask = cv2.bitwise_and(mask, u2net_mask)
|
54 |
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
edges = cv2.Canny(gray, 50, 150)
|
61 |
-
|
62 |
-
# Expand detected edges to mask the design area
|
63 |
-
kernel = np.ones((5, 5), np.uint8)
|
64 |
-
design_mask = cv2.dilate(edges, kernel, iterations=2)
|
65 |
-
|
66 |
-
# Keep only the design within the dress area
|
67 |
-
design_mask = cv2.bitwise_and(design_mask, dress_mask)
|
68 |
-
return design_mask
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
design_mask = detect_design(img_np, dress_mask)
|
92 |
-
|
93 |
-
if dress_mask is None:
|
94 |
-
return img # No dress detected
|
95 |
-
|
96 |
-
# Convert the selected color to BGR
|
97 |
-
color_map = {
|
98 |
-
"Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
|
99 |
-
"Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
|
100 |
-
"White": (255, 255, 255), "Black": (0, 0, 0)
|
101 |
-
}
|
102 |
-
new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8) # Default to Red
|
103 |
-
|
104 |
-
# Recolor the dress naturally
|
105 |
-
img_recolored = recolor_dress(img_np, dress_mask, design_mask, new_color_bgr)
|
106 |
|
107 |
-
|
|
|
|
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
fn=change_dress_color,
|
112 |
inputs=[
|
113 |
-
gr.Image(
|
114 |
-
gr.
|
115 |
],
|
116 |
-
outputs=gr.Image(
|
117 |
-
title="Dress
|
118 |
-
description="Upload
|
|
|
119 |
)
|
120 |
|
121 |
-
|
122 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
|
|
3 |
import cv2
|
4 |
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torchvision.transforms as T
|
8 |
+
from skimage import color
|
9 |
+
from sklearn.cluster import KMeans
|
10 |
+
from cloth_segmentation.networks.u2net import U2NET
|
11 |
|
|
|
|
|
12 |
model = U2NET(3, 1)
|
13 |
+
model.load_state_dict(torch.load("u2net.pth", map_location=torch.device('cpu')))
|
|
|
|
|
14 |
model.eval()
|
15 |
|
16 |
+
# Preprocessing
|
17 |
+
transform = T.Compose([
|
18 |
+
T.Resize((320, 320)),
|
19 |
+
T.ToTensor(),
|
20 |
+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
21 |
+
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# Segmentation mask
|
24 |
+
@torch.no_grad()
|
25 |
+
def get_dress_mask(image_pil):
|
26 |
+
img = transform(image_pil).unsqueeze(0)
|
27 |
+
pred = model(img)[0]
|
28 |
+
pred = pred.squeeze().cpu().numpy()
|
29 |
+
mask = (pred > 0.5).astype(np.uint8)
|
30 |
+
mask = cv2.resize(mask, image_pil.size[::-1])
|
31 |
+
return mask
|
32 |
|
33 |
+
# Color parsing (extract target color from prompt)
|
34 |
+
def extract_target_color(prompt):
|
35 |
+
# Basic keyword matching (can be replaced with NLP-based color detection)
|
36 |
+
import re
|
37 |
+
colors = ["red", "blue", "green", "yellow", "pink", "black", "white", "sky blue", "purple"]
|
38 |
+
for c in colors:
|
39 |
+
if re.search(c, prompt.lower()):
|
40 |
+
return c
|
41 |
+
return "red" # default fallback
|
42 |
|
43 |
+
# Recoloring function
|
44 |
+
def recolor_dress(image_pil, prompt):
|
45 |
+
image_np = np.array(image_pil.convert("RGB")) / 255.0
|
46 |
+
lab = color.rgb2lab(image_np)
|
|
|
47 |
|
48 |
+
mask = get_dress_mask(image_pil)
|
49 |
|
50 |
+
# Get mean a, b values in masked region
|
51 |
+
a_mean = lab[:, :, 1][mask == 1].mean()
|
52 |
+
b_mean = lab[:, :, 2][mask == 1].mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
# Target a, b (from a small predefined palette)
|
55 |
+
target_color_map = {
|
56 |
+
"red": [60, 40],
|
57 |
+
"blue": [20, -60],
|
58 |
+
"green": [-60, 60],
|
59 |
+
"yellow": [10, 70],
|
60 |
+
"pink": [50, 10],
|
61 |
+
"purple": [40, -40],
|
62 |
+
"black": [0, 0],
|
63 |
+
"white": [0, 0],
|
64 |
+
"sky blue": [0, -50],
|
65 |
+
}
|
66 |
+
target = extract_target_color(prompt)
|
67 |
+
target_a, target_b = target_color_map.get(target, [60, 40])
|
68 |
|
69 |
+
# Apply color shift only to dress region
|
70 |
+
lab_new = lab.copy()
|
71 |
+
delta_a = target_a - a_mean
|
72 |
+
delta_b = target_b - b_mean
|
73 |
+
lab_new[:, :, 1][mask == 1] += delta_a
|
74 |
+
lab_new[:, :, 2][mask == 1] += delta_b
|
75 |
|
76 |
+
rgb_new = color.lab2rgb(lab_new)
|
77 |
+
rgb_new = (rgb_new * 255).astype(np.uint8)
|
78 |
+
return Image.fromarray(rgb_new)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
# Gradio UI
|
81 |
+
def interface_fn(image, prompt):
|
82 |
+
return recolor_dress(image, prompt)
|
83 |
|
84 |
+
interface = gr.Interface(
|
85 |
+
fn=interface_fn,
|
|
|
86 |
inputs=[
|
87 |
+
gr.Image(label="Upload Image", type="pil"),
|
88 |
+
gr.Textbox(label="Describe Dress Color Change", placeholder="e.g. Make the dress sky blue")
|
89 |
],
|
90 |
+
outputs=gr.Image(label="Recolored Dress Output"),
|
91 |
+
title="Natural Dress Recoloring with Design Preservation (No GPU)",
|
92 |
+
description="Upload a fashion photo and use prompts like “Make the dress blue” or “Change to red silk”. This CPU-based app changes the color of dresses while preserving textures, embroidery, and shadows.",
|
93 |
+
allow_flagging="never"
|
94 |
)
|
95 |
|
96 |
+
interface.launch()
|
|