gaur3009 committed on
Commit
6a85c9c
·
verified ·
1 Parent(s): 751c76a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -13
app.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  from torchvision import transforms
8
  from torchvision.models.segmentation import deeplabv3_resnet101
9
 
 
10
  model = deeplabv3_resnet101(pretrained=True)
11
  model.eval()
12
 
@@ -24,33 +25,47 @@ def segment_clothing(image):
24
  output = model(input_tensor)['out'][0]
25
  output_predictions = output.argmax(0).byte().cpu().numpy()
26
 
27
- mask = cv2.resize(output_predictions, (image.shape[1], image.shape[0]))
 
 
28
  return mask
29
 
30
  def generate_displacement_map(image, mask):
 
31
  gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
32
- blurred = cv2.GaussianBlur(gray, (15, 15), 0)
33
  displacement_map = cv2.normalize(blurred, None, 0, 255, cv2.NORM_MINMAX)
34
- displacement_map[mask != 15] = 0
 
35
  return displacement_map
36
 
37
  def warp_text(image, text_overlay, displacement_map):
 
38
  text_overlay_array = np.array(text_overlay)
39
- displacement_map = cv2.GaussianBlur(displacement_map, (15, 15), 0)
40
 
 
41
  h, w = displacement_map.shape
42
  x, y = np.meshgrid(np.arange(w), np.arange(h))
43
- x_displacement = x + displacement_map / 50.0
44
- y_displacement = y + displacement_map / 50.0
45
-
46
- warped = cv2.remap(text_overlay_array, x_displacement.astype(np.float32), y_displacement.astype(np.float32), interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
 
 
 
 
 
 
 
47
  return Image.fromarray(warped)
48
 
49
  def overlay_text(image, text, font_size, color, mask):
 
50
  pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert("RGBA")
51
- draw = ImageDraw.Draw(pil_image)
52
 
53
- y_indices, x_indices = np.where(mask == 15)
 
54
  if len(x_indices) == 0 or len(y_indices) == 0:
55
  return None, "No clothing region detected."
56
 
@@ -60,6 +75,11 @@ def overlay_text(image, text, font_size, color, mask):
60
  clothing_width = x_max - x_min
61
  clothing_height = y_max - y_min
62
 
 
 
 
 
 
63
  font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
64
  if not os.path.exists(font_path):
65
  return None, "Font file not found. Please provide a valid font path."
@@ -73,14 +93,15 @@ def overlay_text(image, text, font_size, color, mask):
73
  font = ImageFont.truetype(font_path, font_size)
74
  text_width, text_height = font.getbbox(text)[2:]
75
 
 
76
  text_x = x_min + (clothing_width - text_width) // 2
77
  text_y = y_min + (clothing_height - text_height) // 2
78
 
 
79
  text_overlay = Image.new("RGBA", pil_image.size, (255, 255, 255, 0))
80
  text_draw = ImageDraw.Draw(text_overlay)
81
-
82
  try:
83
- rgba_color = tuple(color) + (255,)
84
  text_draw.text((text_x, text_y), text, font=font, fill=rgba_color)
85
  except Exception as e:
86
  return None, f"Error applying color: {str(e)}"
@@ -89,26 +110,32 @@ def overlay_text(image, text, font_size, color, mask):
89
 
90
  def process_image(image, text, font_size, color):
91
  try:
 
92
  mask = segment_clothing(image)
93
  if mask.sum() == 0:
94
  return "No clothing detected. Try another image."
95
 
 
96
  displacement_map = generate_displacement_map(image, mask)
97
 
 
98
  text_overlay, error = overlay_text(image, text, font_size, color, mask)
99
  if error:
100
  return error
101
 
 
102
  warped_text = warp_text(image, text_overlay, displacement_map)
103
 
 
104
  pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert("RGBA")
105
- final_image = Image.alpha_composite(pil_image, warped_text)
106
 
107
  return final_image
108
  except Exception as e:
109
  print(f"Error processing image: {str(e)}")
110
  return f"Error: {str(e)}"
111
 
 
112
  gr.Interface(
113
  fn=process_image,
114
  inputs=[
 
7
  from torchvision import transforms
8
  from torchvision.models.segmentation import deeplabv3_resnet101
9
 
10
+ # Load Pretrained DeepLabV3 Model
11
  model = deeplabv3_resnet101(pretrained=True)
12
  model.eval()
13
 
 
25
  output = model(input_tensor)['out'][0]
26
  output_predictions = output.argmax(0).byte().cpu().numpy()
27
 
28
+ # Scale back to original size
29
+ mask = cv2.resize(output_predictions, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
30
+ print(f"Mask shape: {mask.shape}, unique values: {np.unique(mask)}") # Debugging
31
  return mask
32
 
def generate_displacement_map(image, mask, target_class=15, blur_kernel=(9, 9)):
    """Build a displacement map from image brightness inside the segmented region.

    The image is converted to grayscale, Gaussian-blurred, stretched to the
    0-255 range with min-max normalisation, and then zeroed at every pixel
    whose mask value differs from ``target_class``.

    Args:
        image: Image array. Colour input is converted with
            ``cv2.COLOR_BGR2GRAY``; a 2-D (already single-channel) array is
            used as-is instead of failing in the colour conversion.
        mask: Per-pixel class-id array with the same height and width as
            ``image``.
        target_class: Class id to keep. Defaults to 15, the id this file
            treats as 'person'.
        blur_kernel: ``(width, height)`` of the Gaussian kernel passed to
            ``cv2.GaussianBlur``.

    Returns:
        A 2-D array with the image's height and width; pixels outside the
        target class are 0.

    Raises:
        ValueError: If ``mask`` and ``image`` differ in height or width.
    """
    # NOTE(review): COLOR_BGR2GRAY assumes BGR channel order -- confirm the
    # caller does not hand over RGB arrays (the grayscale weights differ).
    if image.ndim == 2:
        gray = image
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    if mask.shape != gray.shape:
        raise ValueError(
            f"Mask shape {mask.shape} does not match image shape {gray.shape}."
        )

    blurred = cv2.GaussianBlur(gray, blur_kernel, 0)
    displacement_map = cv2.normalize(blurred, None, 0, 255, cv2.NORM_MINMAX)

    # Only the segmented region may contribute to the warp.
    displacement_map[mask != target_class] = 0
    print(f"Displacement map stats - Min: {np.min(displacement_map)}, Max: {np.max(displacement_map)}")  # Debugging
    return displacement_map
41
 
def warp_text(image, text_overlay, displacement_map, strength=100.0, blur_kernel=(9, 9)):
    """Warp the text overlay according to the displacement map.

    The displacement map is blurred, divided by ``strength`` to obtain a
    per-pixel offset, and that same offset is added to both the x and y
    sampling coordinates handed to ``cv2.remap``.

    Args:
        image: Unused; kept so existing callers keep working.
        text_overlay: Overlay image holding the rendered text (anything
            ``np.array`` can convert, e.g. a PIL image).
        displacement_map: 2-D array with the same height and width as
            ``text_overlay``.
        strength: Divisor applied to the displacement values; larger values
            give a subtler warp. Must be non-zero.
        blur_kernel: ``(width, height)`` of the Gaussian kernel used to
            smooth the displacement map before warping.

    Returns:
        A PIL image built from the remapped overlay.

    Raises:
        ValueError: If ``strength`` is zero or the overlay and displacement
            map differ in height or width.
    """
    if strength == 0:
        raise ValueError("strength must be non-zero.")

    overlay_pixels = np.array(text_overlay)
    smoothed = cv2.GaussianBlur(displacement_map, blur_kernel, 0)

    height, width = smoothed.shape
    # cv2.remap sizes its output after the coordinate maps, so a mismatch
    # would silently yield an overlay of the wrong size; fail loudly instead.
    if overlay_pixels.shape[:2] != (height, width):
        raise ValueError(
            f"Overlay size {overlay_pixels.shape[:2]} does not match "
            f"displacement map size {(height, width)}."
        )

    # Identity sampling grid plus the scaled displacement on both axes.
    grid_x, grid_y = np.meshgrid(np.arange(width), np.arange(height))
    offset = smoothed / strength
    map_x = (grid_x + offset).astype(np.float32)
    map_y = (grid_y + offset).astype(np.float32)

    warped = cv2.remap(
        overlay_pixels,
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
    )
    return Image.fromarray(warped)
62
 
63
  def overlay_text(image, text, font_size, color, mask):
64
+ """Overlay text onto the detected clothing region."""
65
  pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert("RGBA")
 
66
 
67
+ # Find the bounding box of the mask (clothing area)
68
+ y_indices, x_indices = np.where(mask == 15) # Class 15 corresponds to 'person' in DeepLabV3
69
  if len(x_indices) == 0 or len(y_indices) == 0:
70
  return None, "No clothing region detected."
71
 
 
75
  clothing_width = x_max - x_min
76
  clothing_height = y_max - y_min
77
 
78
+ # Ensure the color is correctly formatted
79
+ color = color.lstrip('#')
80
+ color_tuple = tuple(int(color[i:i+2], 16) for i in (0, 2, 4))
81
+
82
+ # Load font and adjust size dynamically
83
  font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
84
  if not os.path.exists(font_path):
85
  return None, "Font file not found. Please provide a valid font path."
 
93
  font = ImageFont.truetype(font_path, font_size)
94
  text_width, text_height = font.getbbox(text)[2:]
95
 
96
+ # Calculate position to center the text
97
  text_x = x_min + (clothing_width - text_width) // 2
98
  text_y = y_min + (clothing_height - text_height) // 2
99
 
100
+ # Draw the text on a transparent overlay
101
  text_overlay = Image.new("RGBA", pil_image.size, (255, 255, 255, 0))
102
  text_draw = ImageDraw.Draw(text_overlay)
 
103
  try:
104
+ rgba_color = color_tuple + (255,) # Add alpha channel
105
  text_draw.text((text_x, text_y), text, font=font, fill=rgba_color)
106
  except Exception as e:
107
  return None, f"Error applying color: {str(e)}"
 
110
 
def process_image(image, text, font_size, color):
    """Run the full pipeline: segment, render the text, warp it, and blend.

    Args:
        image: Input image array; forwarded to ``segment_clothing``,
            ``generate_displacement_map``, ``overlay_text`` and ``warp_text``.
        text: Text to draw on the detected region.
        font_size: Font size forwarded to ``overlay_text``.
        color: Text colour forwarded to ``overlay_text``.

    Returns:
        The composited image as an RGB PIL image on success. On failure a
        message string is returned instead (no detection, an error reported
        by ``overlay_text``, a missing image, or any exception raised while
        processing).
    """
    # Without this guard a missing image only surfaces as an opaque
    # attribute error from deep inside the pipeline.
    if image is None:
        return "Error: no image provided."

    try:
        # Segment the image using DeepLabV3
        mask = segment_clothing(image)
        if not mask.any():
            return "No clothing detected. Try another image."

        # Brightness-based displacement restricted to the segmented region
        displacement_map = generate_displacement_map(image, mask)

        # Render the text on a transparent overlay
        text_overlay, error = overlay_text(image, text, font_size, color, mask)
        if error:
            return error

        # Bend the text with the displacement map
        warped_text = warp_text(image, text_overlay, displacement_map)

        # Blend the warped text back onto the original image
        # NOTE(review): COLOR_BGR2RGB assumes the input is BGR -- confirm
        # against what the UI actually passes in.
        base = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert("RGBA")
        return Image.alpha_composite(base, warped_text).convert("RGB")
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        print(f"Error processing image: {str(e)}")
        return f"Error: {str(e)}"
137
 
138
+ # Gradio Interface
139
  gr.Interface(
140
  fn=process_image,
141
  inputs=[