Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -12,16 +12,12 @@ model = U2NET(3, 1)
|
|
12 |
|
13 |
# Load the state dictionary
|
14 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
15 |
-
|
16 |
-
# Remove the 'module.' prefix from the keys
|
17 |
-
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
18 |
-
|
19 |
-
# Load the modified state dictionary into the model
|
20 |
model.load_state_dict(state_dict)
|
21 |
model.eval()
|
22 |
|
23 |
def segment_dress(image_np):
|
24 |
-
"""Segment the dress from the image using U²-Net."""
|
25 |
transform_pipeline = transforms.Compose([
|
26 |
transforms.ToTensor(),
|
27 |
transforms.Resize((320, 320))
|
@@ -32,8 +28,16 @@ def segment_dress(image_np):
|
|
32 |
with torch.no_grad():
|
33 |
output = model(input_tensor)[0][0].squeeze().cpu().numpy()
|
34 |
|
35 |
-
mask = (output > 0.5).astype(np.uint8) * 255 #
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
return mask
|
38 |
|
39 |
def change_dress_color(image_path, color):
|
@@ -44,7 +48,7 @@ def change_dress_color(image_path, color):
|
|
44 |
img = Image.open(image_path).convert("RGB")
|
45 |
img_np = np.array(img)
|
46 |
mask = segment_dress(img_np)
|
47 |
-
|
48 |
if mask is None:
|
49 |
return img # No dress detected
|
50 |
|
@@ -58,19 +62,17 @@ def change_dress_color(image_path, color):
|
|
58 |
}
|
59 |
new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8) # Default to Red
|
60 |
|
61 |
-
# Convert image to
|
62 |
-
|
63 |
-
|
64 |
-
# Convert new color to HSV
|
65 |
-
new_color_hsv = cv2.cvtColor(np.uint8([[new_color_bgr]]), cv2.COLOR_BGR2HSV)[0][0]
|
66 |
|
67 |
-
# Apply the new color
|
68 |
-
|
69 |
-
|
70 |
|
71 |
# Convert back to RGB
|
72 |
-
img_recolored = cv2.cvtColor(
|
73 |
-
|
74 |
return Image.fromarray(img_recolored)
|
75 |
|
76 |
# Gradio Interface
|
|
|
12 |
|
13 |
# Load the state dictionary
|
14 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
15 |
+
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} # Remove 'module.' prefix
|
|
|
|
|
|
|
|
|
16 |
model.load_state_dict(state_dict)
|
17 |
model.eval()
|
18 |
|
19 |
def segment_dress(image_np):
|
20 |
+
"""Segment the dress from the image using U²-Net and refine the mask."""
|
21 |
transform_pipeline = transforms.Compose([
|
22 |
transforms.ToTensor(),
|
23 |
transforms.Resize((320, 320))
|
|
|
28 |
with torch.no_grad():
|
29 |
output = model(input_tensor)[0][0].squeeze().cpu().numpy()
|
30 |
|
31 |
+
mask = (output > 0.5).astype(np.uint8) * 255 # Binary mask
|
32 |
+
|
33 |
+
# Resize mask to original image size
|
34 |
+
mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
|
35 |
+
|
36 |
+
# Apply morphological operations for better segmentation
|
37 |
+
kernel = np.ones((7, 7), np.uint8)
|
38 |
+
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) # Close small gaps
|
39 |
+
mask = cv2.dilate(mask, kernel, iterations=2) # Expand the detected dress area
|
40 |
+
|
41 |
return mask
|
42 |
|
43 |
def change_dress_color(image_path, color):
|
|
|
48 |
img = Image.open(image_path).convert("RGB")
|
49 |
img_np = np.array(img)
|
50 |
mask = segment_dress(img_np)
|
51 |
+
|
52 |
if mask is None:
|
53 |
return img # No dress detected
|
54 |
|
|
|
62 |
}
|
63 |
new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8) # Default to Red
|
64 |
|
65 |
+
# Convert image to LAB color space for better blending
|
66 |
+
img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
|
67 |
+
new_color_lab = cv2.cvtColor(np.uint8([[new_color_bgr]]), cv2.COLOR_BGR2LAB)[0][0]
|
|
|
|
|
68 |
|
69 |
+
# Apply the new color while preserving texture
|
70 |
+
img_lab[..., 1] = np.where(mask == 255, new_color_lab[1], img_lab[..., 1]) # Modify A-channel
|
71 |
+
img_lab[..., 2] = np.where(mask == 255, new_color_lab[2], img_lab[..., 2]) # Modify B-channel
|
72 |
|
73 |
# Convert back to RGB
|
74 |
+
img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
|
75 |
+
|
76 |
return Image.fromarray(img_recolored)
|
77 |
|
78 |
# Gradio Interface
|