gaur3009 commited on
Commit
af15f73
·
verified ·
1 Parent(s): b17fae5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -100
app.py CHANGED
@@ -1,122 +1,96 @@
1
  import gradio as gr
2
  import numpy as np
3
- import torch
4
  import cv2
5
  from PIL import Image
6
- from torchvision import transforms
7
- from cloth_segmentation.networks.u2net import U2NET # Import U²-Net
 
 
 
 
8
 
9
- # Load U²-Net model
10
- model_path = "cloth_segmentation/networks/u2net.pth"
11
  model = U2NET(3, 1)
12
- state_dict = torch.load(model_path, map_location=torch.device('cpu'))
13
- state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} # Remove 'module.' prefix
14
- model.load_state_dict(state_dict)
15
  model.eval()
16
 
17
- def segment_dress(image_np):
18
- """Segment the dress using U²-Net & refine with Lab color space."""
19
-
20
- # Convert to Lab space
21
- img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
22
- L, A, B = cv2.split(img_lab)
23
-
24
- # Use K-means clustering to detect dominant dress region
25
- pixel_values = img_lab.reshape((-1, 3)).astype(np.float32)
26
- k = 3 # Three clusters: background, skin, dress
27
- _, labels, centers = cv2.kmeans(pixel_values, k, None, (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0), 10, cv2.KMEANS_RANDOM_CENTERS)
28
- labels = labels.reshape(image_np.shape[:2])
29
-
30
- # Assume dress is the largest non-background cluster
31
- unique_labels, counts = np.unique(labels, return_counts=True)
32
- dress_label = unique_labels[np.argmax(counts[1:]) + 1] # Avoid background
33
-
34
- # Create dress mask
35
- mask = (labels == dress_label).astype(np.uint8) * 255
36
 
37
- # Use U²-Net prediction to refine segmentation
38
- transform_pipeline = transforms.Compose([
39
- transforms.ToTensor(),
40
- transforms.Resize((320, 320))
41
- ])
42
-
43
- image = Image.fromarray(image_np).convert("RGB")
44
- input_tensor = transform_pipeline(image).unsqueeze(0)
 
45
 
46
- with torch.no_grad():
47
- output = model(input_tensor)[0][0].squeeze().cpu().numpy()
 
 
 
 
 
 
 
48
 
49
- u2net_mask = (output > 0.5).astype(np.uint8) * 255
50
- u2net_mask = cv2.resize(u2net_mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
51
-
52
- # Combine K-means and U²-Net masks
53
- refined_mask = cv2.bitwise_and(mask, u2net_mask)
54
 
55
- return refined_mask
56
 
57
- def detect_design(image_np, dress_mask):
58
- """Detect the design part of the dress and separate it from fabric."""
59
- gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
60
- edges = cv2.Canny(gray, 50, 150)
61
-
62
- # Expand detected edges to mask the design area
63
- kernel = np.ones((5, 5), np.uint8)
64
- design_mask = cv2.dilate(edges, kernel, iterations=2)
65
-
66
- # Keep only the design within the dress area
67
- design_mask = cv2.bitwise_and(design_mask, dress_mask)
68
- return design_mask
69
 
70
- def recolor_dress(image_np, mask, design_mask, target_color):
71
- """Change dress color while preserving texture and design."""
72
- img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
73
- target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
74
-
75
- # Preserve lightness (L) and change only chromatic channels (A & B)
76
- blend_factor = 0.7
77
- img_lab[..., 1] = np.where((mask > 128) & (design_mask == 0), img_lab[..., 1] * (1 - blend_factor) + target_color_lab[1] * blend_factor, img_lab[..., 1])
78
- img_lab[..., 2] = np.where((mask > 128) & (design_mask == 0), img_lab[..., 2] * (1 - blend_factor) + target_color_lab[2] * blend_factor, img_lab[..., 2])
79
-
80
- img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
81
- return img_recolored
 
 
82
 
83
- def change_dress_color(image_path, color):
84
- """Change the dress color naturally while keeping textures and design."""
85
- if image_path is None:
86
- return None
 
 
87
 
88
- img = Image.open(image_path).convert("RGB")
89
- img_np = np.array(img)
90
- dress_mask = segment_dress(img_np)
91
- design_mask = detect_design(img_np, dress_mask)
92
-
93
- if dress_mask is None:
94
- return img # No dress detected
95
-
96
- # Convert the selected color to BGR
97
- color_map = {
98
- "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0), "Yellow": (0, 255, 255),
99
- "Purple": (128, 0, 128), "Orange": (0, 165, 255), "Cyan": (255, 255, 0), "Magenta": (255, 0, 255),
100
- "White": (255, 255, 255), "Black": (0, 0, 0)
101
- }
102
- new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8) # Default to Red
103
-
104
- # Recolor the dress naturally
105
- img_recolored = recolor_dress(img_np, dress_mask, design_mask, new_color_bgr)
106
 
107
- return Image.fromarray(img_recolored)
 
 
108
 
109
- # Gradio Interface
110
- demo = gr.Interface(
111
- fn=change_dress_color,
112
  inputs=[
113
- gr.Image(type="filepath", label="Upload Dress Image"),
114
- gr.Radio(["Red", "Blue", "Green", "Yellow", "Purple", "Orange", "Cyan", "Magenta", "White", "Black"], label="Choose New Dress Color")
115
  ],
116
- outputs=gr.Image(type="pil", label="Color Changed Dress"),
117
- title="Dress Color Changer",
118
- description="Upload an image of a dress and select a new color to change its appearance naturally while preserving the design."
 
119
  )
120
 
121
- if __name__ == "__main__":
122
- demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
 
3
  import cv2
4
  from PIL import Image
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms as T
8
+ from skimage import color
9
+ from sklearn.cluster import KMeans
10
+ from cloth_segmentation.networks.u2net import U2NET
11
 
 
 
12
  model = U2NET(3, 1)
13
+ model.load_state_dict(torch.load("u2net.pth", map_location=torch.device('cpu')))
 
 
14
  model.eval()
15
 
16
+ # Preprocessing
17
+ transform = T.Compose([
18
+ T.Resize((320, 320)),
19
+ T.ToTensor(),
20
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
21
+ ])
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Segmentation mask
24
+ @torch.no_grad()
25
+ def get_dress_mask(image_pil):
26
+ img = transform(image_pil).unsqueeze(0)
27
+ pred = model(img)[0]
28
+ pred = pred.squeeze().cpu().numpy()
29
+ mask = (pred > 0.5).astype(np.uint8)
30
+ mask = cv2.resize(mask, image_pil.size[::-1])
31
+ return mask
32
 
33
+ # Color parsing (extract target color from prompt)
34
+ def extract_target_color(prompt):
35
+ # Basic keyword matching (can be replaced with NLP-based color detection)
36
+ import re
37
+ colors = ["red", "blue", "green", "yellow", "pink", "black", "white", "sky blue", "purple"]
38
+ for c in colors:
39
+ if re.search(c, prompt.lower()):
40
+ return c
41
+ return "red" # default fallback
42
 
43
+ # Recoloring function
44
+ def recolor_dress(image_pil, prompt):
45
+ image_np = np.array(image_pil.convert("RGB")) / 255.0
46
+ lab = color.rgb2lab(image_np)
 
47
 
48
+ mask = get_dress_mask(image_pil)
49
 
50
+ # Get mean a, b values in masked region
51
+ a_mean = lab[:, :, 1][mask == 1].mean()
52
+ b_mean = lab[:, :, 2][mask == 1].mean()
 
 
 
 
 
 
 
 
 
53
 
54
+ # Target a, b (from a small predefined palette)
55
+ target_color_map = {
56
+ "red": [60, 40],
57
+ "blue": [20, -60],
58
+ "green": [-60, 60],
59
+ "yellow": [10, 70],
60
+ "pink": [50, 10],
61
+ "purple": [40, -40],
62
+ "black": [0, 0],
63
+ "white": [0, 0],
64
+ "sky blue": [0, -50],
65
+ }
66
+ target = extract_target_color(prompt)
67
+ target_a, target_b = target_color_map.get(target, [60, 40])
68
 
69
+ # Apply color shift only to dress region
70
+ lab_new = lab.copy()
71
+ delta_a = target_a - a_mean
72
+ delta_b = target_b - b_mean
73
+ lab_new[:, :, 1][mask == 1] += delta_a
74
+ lab_new[:, :, 2][mask == 1] += delta_b
75
 
76
+ rgb_new = color.lab2rgb(lab_new)
77
+ rgb_new = (rgb_new * 255).astype(np.uint8)
78
+ return Image.fromarray(rgb_new)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ # Gradio UI
81
+ def interface_fn(image, prompt):
82
+ return recolor_dress(image, prompt)
83
 
84
+ interface = gr.Interface(
85
+ fn=interface_fn,
 
86
  inputs=[
87
+ gr.Image(label="Upload Image", type="pil"),
88
+ gr.Textbox(label="Describe Dress Color Change", placeholder="e.g. Make the dress sky blue")
89
  ],
90
+ outputs=gr.Image(label="Recolored Dress Output"),
91
+ title="Natural Dress Recoloring with Design Preservation (No GPU)",
92
+ description="Upload a fashion photo and use prompts like “Make the dress blue” or “Change to red silk”. This CPU-based app changes the color of dresses while preserving textures, embroidery, and shadows.",
93
+ allow_flagging="never"
94
  )
95
 
96
+ interface.launch()