gaur3009 commited on
Commit
9582f0e
·
verified ·
1 Parent(s): c8390c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -12,16 +12,12 @@ model = U2NET(3, 1)
12
 
13
  # Load the state dictionary
14
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
15
-
16
- # Remove the 'module.' prefix from the keys
17
- state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
18
-
19
- # Load the modified state dictionary into the model
20
  model.load_state_dict(state_dict)
21
  model.eval()
22
 
23
  def segment_dress(image_np):
24
- """Segment the dress from the image using U²-Net."""
25
  transform_pipeline = transforms.Compose([
26
  transforms.ToTensor(),
27
  transforms.Resize((320, 320))
@@ -32,8 +28,16 @@ def segment_dress(image_np):
32
  with torch.no_grad():
33
  output = model(input_tensor)[0][0].squeeze().cpu().numpy()
34
 
35
- mask = (output > 0.5).astype(np.uint8) * 255 # Thresholding for binary mask
36
- mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0])) # Resize mask to original
 
 
 
 
 
 
 
 
37
  return mask
38
 
39
  def change_dress_color(image_path, color):
@@ -44,7 +48,7 @@ def change_dress_color(image_path, color):
44
  img = Image.open(image_path).convert("RGB")
45
  img_np = np.array(img)
46
  mask = segment_dress(img_np)
47
-
48
  if mask is None:
49
  return img # No dress detected
50
 
@@ -58,19 +62,17 @@ def change_dress_color(image_path, color):
58
  }
59
  new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8) # Default to Red
60
 
61
- # Convert image to HSV
62
- img_hsv = cv2.cvtColor(img_np, cv2.COLOR_RGB2HSV)
63
-
64
- # Convert new color to HSV
65
- new_color_hsv = cv2.cvtColor(np.uint8([[new_color_bgr]]), cv2.COLOR_BGR2HSV)[0][0]
66
 
67
- # Apply the new color only to the masked dress area
68
- img_hsv[..., 0] = np.where(mask == 255, new_color_hsv[0], img_hsv[..., 0]) # Hue
69
- img_hsv[..., 1] = np.where(mask == 255, new_color_hsv[1], img_hsv[..., 1]) # Saturation
70
 
71
  # Convert back to RGB
72
- img_recolored = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB)
73
-
74
  return Image.fromarray(img_recolored)
75
 
76
  # Gradio Interface
 
12
 
13
  # Load the state dictionary
14
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
15
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} # Remove 'module.' prefix
 
 
 
 
16
  model.load_state_dict(state_dict)
17
  model.eval()
18
 
19
  def segment_dress(image_np):
20
+ """Segment the dress from the image using U²-Net and refine the mask."""
21
  transform_pipeline = transforms.Compose([
22
  transforms.ToTensor(),
23
  transforms.Resize((320, 320))
 
28
  with torch.no_grad():
29
  output = model(input_tensor)[0][0].squeeze().cpu().numpy()
30
 
31
+ mask = (output > 0.5).astype(np.uint8) * 255 # Binary mask
32
+
33
+ # Resize mask to original image size
34
+ mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
35
+
36
+ # Apply morphological operations for better segmentation
37
+ kernel = np.ones((7, 7), np.uint8)
38
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) # Close small gaps
39
+ mask = cv2.dilate(mask, kernel, iterations=2) # Expand the detected dress area
40
+
41
  return mask
42
 
43
  def change_dress_color(image_path, color):
 
48
  img = Image.open(image_path).convert("RGB")
49
  img_np = np.array(img)
50
  mask = segment_dress(img_np)
51
+
52
  if mask is None:
53
  return img # No dress detected
54
 
 
62
  }
63
  new_color_bgr = np.array(color_map.get(color, (0, 0, 255)), dtype=np.uint8) # Default to Red
64
 
65
+ # Convert image to LAB color space for better blending
66
+ img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
67
+ new_color_lab = cv2.cvtColor(np.uint8([[new_color_bgr]]), cv2.COLOR_BGR2LAB)[0][0]
 
 
68
 
69
+ # Apply the new color while preserving texture
70
+ img_lab[..., 1] = np.where(mask == 255, new_color_lab[1], img_lab[..., 1]) # Modify A-channel
71
+ img_lab[..., 2] = np.where(mask == 255, new_color_lab[2], img_lab[..., 2]) # Modify B-channel
72
 
73
  # Convert back to RGB
74
+ img_recolored = cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
75
+
76
  return Image.fromarray(img_recolored)
77
 
78
  # Gradio Interface