sivakum4 commited on
Commit
3814f03
·
1 Parent(s): 0d8a13d

Droping Metrics

Browse files
Files changed (1) hide show
  1. app.py +15 -42
app.py CHANGED
@@ -13,11 +13,8 @@ try:
13
  except Exception:
14
  kc = None
15
 
16
- from skimage.metrics import peak_signal_noise_ratio as psnr_metric
17
- from skimage.metrics import structural_similarity as ssim_metric
18
  from model import ViTUNetColorizer
19
 
20
-
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
 
23
  CKPT = "checkpoints/checkpoint_epoch_017_20250810_193435.pt"
@@ -57,49 +54,34 @@ def pad_to_multiple(img_np, m=16):
57
  ph, pw = math.ceil(h/m)*m, math.ceil(w/m)*m
58
  return cv2.copyMakeBorder(img_np,0,ph-h,0,pw-w,cv2.BORDER_CONSTANT,value=(0,0,0)), (h,w)
59
 
60
- def compute_metrics(pred, gt):
61
- p = pred.astype(np.float32)/255.; g = gt.astype(np.float32)/255.
62
- mae = float(np.mean(np.abs(p-g)))
63
- psnr = float(psnr_metric(g, p, data_range=1.0))
64
- try:
65
- ssim = float(ssim_metric(g, p, channel_axis=2, data_range=1.0, win_size=7))
66
- except TypeError:
67
- ssim = float(ssim_metric(g, p, multichannel=True, data_range=1.0, win_size=7))
68
- return round(mae,4), round(psnr,2), round(ssim,4)
69
-
70
  def to_grayscale(image):
71
  if image is None:
72
  return None
73
  return image.convert("L").convert("RGB")
74
 
75
- def infer(image: Image.Image, want_metrics: bool):
76
  if image is None:
77
- return None, None, None, None, None
78
  if model is None:
79
- return None, None, None, None, "<div>Checkpoint not found.</div>"
80
 
81
  pil = image.convert("RGB")
82
  rgb = np.array(pil)
83
-
84
  proc, (oh, ow) = pad_to_multiple(rgb, 16); back = (ow, oh)
85
 
86
  L = to_L(proc)
87
  with torch.no_grad():
88
  ab = model(L)
89
  out = lab_to_rgb(L, ab)
90
-
91
  out = out[:back[1], :back[0]]
92
 
93
- mae = psnr = ssim = None
94
- if want_metrics:
95
- mae, psnr, ssim = compute_metrics(out, np.array(pil))
96
-
97
  gray_pil = pil.convert("L").convert("RGB")
98
  _, bo = cv2.imencode(".jpg", cv2.cvtColor(np.array(gray_pil), cv2.COLOR_RGB2BGR))
99
  _, bc = cv2.imencode(".jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
100
  so = "data:image/jpeg;base64," + base64.b64encode(bo).decode()
101
  sc = "data:image/jpeg;base64," + base64.b64encode(bc).decode()
102
-
103
  compare_html = f"""
104
  <div style="margin:auto; border-radius:14px; overflow:hidden;">
105
  <img-comparison-slider>
@@ -109,7 +91,7 @@ def infer(image: Image.Image, want_metrics: bool):
109
  </div>
110
  """
111
 
112
- return out, mae, psnr, ssim, compare_html
113
 
114
  def save_for_download(image_array):
115
  """Saves a NumPy array to a temporary file and returns the path."""
@@ -159,7 +141,6 @@ with gr.Blocks(theme=THEME, title="Image Colorizer", head=HEAD) as demo:
159
  sources=["upload", "webcam", "clipboard"]
160
  )
161
  img_in.upload(fn=to_grayscale, inputs=img_in, outputs=img_in)
162
- show_m = gr.Checkbox(label="Show metrics", value=True)
163
  with gr.Row():
164
  run = gr.Button("Colorize")
165
  clr = gr.Button("Clear")
@@ -174,33 +155,25 @@ with gr.Blocks(theme=THEME, title="Image Colorizer", head=HEAD) as demo:
174
 
175
  with gr.Column(scale=7):
176
  out_html = gr.HTML(label="Result", value=PLACEHOLDER_HTML)
177
- with gr.Row():
178
- mae_box = gr.Number(label="MAE", interactive=False, precision=4)
179
- psnr_box = gr.Number(label="PSNR (dB)", interactive=False, precision=2)
180
- ssim_box = gr.Number(label="SSIM", interactive=False, precision=4)
181
-
182
- def _go(image, want_metrics):
183
- out_image, mae, psnr, ssim, cmp_html = infer(image, want_metrics)
184
- if not want_metrics:
185
- mae = psnr = ssim = None
186
-
187
  download_button_update = gr.update(visible=True) if out_image is not None else gr.update(visible=False)
188
-
189
- return out_image, cmp_html, mae, psnr, ssim, download_button_update
190
 
191
  run.click(
192
  _go,
193
- inputs=[img_in, show_m],
194
- outputs=[result_state, out_html, mae_box, psnr_box, ssim_box, download_btn]
195
  )
196
 
197
  def _clear():
198
- return None, None, PLACEHOLDER_HTML, None, None, None, gr.update(visible=False)
199
 
200
  clr.click(
201
  _clear,
202
  inputs=None,
203
- outputs=[img_in, result_state, out_html, mae_box, psnr_box, ssim_box, download_btn]
204
  )
205
 
206
  download_btn.click(
@@ -213,4 +186,4 @@ if __name__ == "__main__":
213
  try:
214
  demo.launch(show_api=False)
215
  except TypeError:
216
- demo.launch()
 
13
  except Exception:
14
  kc = None
15
 
 
 
16
  from model import ViTUNetColorizer
17
 
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
  CKPT = "checkpoints/checkpoint_epoch_017_20250810_193435.pt"
 
54
  ph, pw = math.ceil(h/m)*m, math.ceil(w/m)*m
55
  return cv2.copyMakeBorder(img_np,0,ph-h,0,pw-w,cv2.BORDER_CONSTANT,value=(0,0,0)), (h,w)
56
 
 
 
 
 
 
 
 
 
 
 
57
  def to_grayscale(image):
58
  if image is None:
59
  return None
60
  return image.convert("L").convert("RGB")
61
 
62
+ def infer(image: Image.Image):
63
  if image is None:
64
+ return None, None
65
  if model is None:
66
+ return None, "<div>Checkpoint not found.</div>"
67
 
68
  pil = image.convert("RGB")
69
  rgb = np.array(pil)
70
+
71
  proc, (oh, ow) = pad_to_multiple(rgb, 16); back = (ow, oh)
72
 
73
  L = to_L(proc)
74
  with torch.no_grad():
75
  ab = model(L)
76
  out = lab_to_rgb(L, ab)
 
77
  out = out[:back[1], :back[0]]
78
 
 
 
 
 
79
  gray_pil = pil.convert("L").convert("RGB")
80
  _, bo = cv2.imencode(".jpg", cv2.cvtColor(np.array(gray_pil), cv2.COLOR_RGB2BGR))
81
  _, bc = cv2.imencode(".jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
82
  so = "data:image/jpeg;base64," + base64.b64encode(bo).decode()
83
  sc = "data:image/jpeg;base64," + base64.b64encode(bc).decode()
84
+
85
  compare_html = f"""
86
  <div style="margin:auto; border-radius:14px; overflow:hidden;">
87
  <img-comparison-slider>
 
91
  </div>
92
  """
93
 
94
+ return out, compare_html
95
 
96
  def save_for_download(image_array):
97
  """Saves a NumPy array to a temporary file and returns the path."""
 
141
  sources=["upload", "webcam", "clipboard"]
142
  )
143
  img_in.upload(fn=to_grayscale, inputs=img_in, outputs=img_in)
 
144
  with gr.Row():
145
  run = gr.Button("Colorize")
146
  clr = gr.Button("Clear")
 
155
 
156
  with gr.Column(scale=7):
157
  out_html = gr.HTML(label="Result", value=PLACEHOLDER_HTML)
158
+
159
+ def _go(image):
160
+ out_image, cmp_html = infer(image)
 
 
 
 
 
 
 
161
  download_button_update = gr.update(visible=True) if out_image is not None else gr.update(visible=False)
162
+ return out_image, cmp_html, download_button_update
 
163
 
164
  run.click(
165
  _go,
166
+ inputs=[img_in],
167
+ outputs=[result_state, out_html, download_btn]
168
  )
169
 
170
  def _clear():
171
+ return None, None, PLACEHOLDER_HTML, gr.update(visible=False)
172
 
173
  clr.click(
174
  _clear,
175
  inputs=None,
176
+ outputs=[img_in, result_state, out_html, download_btn]
177
  )
178
 
179
  download_btn.click(
 
186
  try:
187
  demo.launch(show_api=False)
188
  except TypeError:
189
+ demo.launch()