Alican commited on
Commit
5548b5c
·
1 Parent(s): 08af810

Add application files

Browse files
Files changed (1) hide show
  1. app.py +49 -4
app.py CHANGED
@@ -19,11 +19,15 @@ from scipy.special import kl_div as scipy_kl_div
19
  from skimage.metrics import structural_similarity as ssim
20
  import warnings
21
 
22
- TASK = "nodules"
 
 
23
  PATH = os.path.join("datasets", TASK, "real")
24
  images = []
25
  perceptual_model = None
26
 
 
 
27
  try:
28
  vgg = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
29
  vgg.trainable = False
@@ -31,7 +35,9 @@ try:
31
  except Exception as e:
32
  perceptual_model = None
33
 
 
34
  def safe_normalize_heatmap(heatmap):
 
35
  if heatmap is None or heatmap.size == 0:
36
  return np.zeros((64, 64), dtype=np.uint8)
37
  heatmap = heatmap.astype(np.float32)
@@ -48,7 +54,10 @@ def safe_normalize_heatmap(heatmap):
48
  normalized_heatmap = np.clip(normalized_heatmap, 0, 255)
49
  return np.uint8(normalized_heatmap)
50
 
 
 
51
  def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
 
52
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
53
  return None
54
  try:
@@ -92,6 +101,7 @@ def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=
92
  return img_dict
93
 
94
  def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
 
95
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
96
  try:
97
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
@@ -121,6 +131,7 @@ def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False)
121
  return img_dict
122
 
123
  def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
 
124
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
125
  try:
126
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
@@ -150,6 +161,7 @@ def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False
150
  return img_dict
151
 
152
  def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
 
153
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
154
  try:
155
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB)
@@ -177,6 +189,7 @@ def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
177
  return img_dict
178
 
179
  def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
 
180
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
181
  try:
182
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
@@ -212,6 +225,7 @@ def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_
212
  return img_dict
213
 
214
  def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
 
215
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
216
  try:
217
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
@@ -241,6 +255,7 @@ def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False)
241
  return img_dict
242
 
243
  def perceptual_loss(img_real, img_fake, model, block_size=4):
 
244
  if img_real is None or img_fake is None or model is None or img_real.shape != img_fake.shape:
245
  return None
246
  original_height, original_width, _ = img_real.shape
@@ -264,7 +279,10 @@ def perceptual_loss(img_real, img_fake, model, block_size=4):
264
  heatmap_original_size = cv2.resize(heatmap_features, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
265
  return {"SUM": {"SUM": total_loss, "HEATMAP": heatmap_original_size.astype(np.float32)}}
266
 
 
 
267
  def gather_images(task):
 
268
  global TASK, PATH, images
269
  new_path = os.path.join("datasets", task, "real")
270
  if TASK != task or not images:
@@ -311,6 +329,7 @@ def gather_images(task):
311
  return placeholder, placeholder, error_msg
312
 
313
  def run_comparison(real, fake, measurement, block_size_val):
 
314
  placeholder_heatmap = np.zeros((64, 64, 3), dtype=np.uint8)
315
  if real is None or fake is None or not isinstance(real, np.ndarray) or not isinstance(fake, np.ndarray):
316
  return placeholder_heatmap, "Error: Input image(s) missing or invalid. Please load or upload a pair of images."
@@ -353,8 +372,25 @@ def run_comparison(real, fake, measurement, block_size_val):
353
  return heatmap_rgb, status_msg
354
 
355
  def clear_uploads(msg):
 
356
  return None, None, msg
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  theme = gr.themes.Soft(primary_hue="blue", secondary_hue="orange")
359
  with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !important; margin: auto; }") as demo:
360
  gr.Markdown("# GAN vs Ground Truth Image Comparison")
@@ -368,7 +404,7 @@ with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !importan
368
  with gr.Tabs():
369
  with gr.TabItem("Load from Dataset"):
370
  task_dropdown = gr.Dropdown(
371
- ["nodules", "facades"], value=TASK,
372
  info="Select the dataset task.",
373
  label="Dataset Task"
374
  )
@@ -387,7 +423,7 @@ with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !importan
387
  with gr.Row():
388
  measurement_dropdown = gr.Dropdown(
389
  ["Kullback-Leibler Divergence", "L1-Loss", "MSE", "SSIM", "Cosine Similarity", "TV", "Perceptual"],
390
- value="Kullback-Leibler Divergence",
391
  info="Select the comparison metric.",
392
  label="Comparison Metric",
393
  scale=2
@@ -404,6 +440,15 @@ with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !importan
404
  gr.Markdown("### 3. See Result")
405
  heatmap_display = gr.Image(type="numpy", label="Comparison Heatmap (Difference)", height=350, interactive=False)
406
 
 
 
 
 
 
 
 
 
 
407
  sample_button.click(
408
  fn=gather_images,
409
  inputs=[task_dropdown],
@@ -434,7 +479,7 @@ with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !importan
434
  outputs=[real_img_display, fake_img_display, status_message]
435
  )
436
 
437
-
438
  if __name__ == "__main__":
439
  print("-------------------------------------------------------------")
440
  print("Verifying VGG19 model status...")
 
19
  from skimage.metrics import structural_similarity as ssim
20
  import warnings
21
 
22
+ # --- Configuration ---
23
+ # Set the default task.
24
+ TASK = "facades"
25
  PATH = os.path.join("datasets", TASK, "real")
26
  images = []
27
  perceptual_model = None
28
 
29
+ # --- Model Loading ---
30
+ # Attempt to load the VGG19 model for the perceptual loss metric.
31
  try:
32
  vgg = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
33
  vgg.trainable = False
 
35
  except Exception as e:
36
  perceptual_model = None
37
 
38
+ # --- Utility Functions ---
39
  def safe_normalize_heatmap(heatmap):
40
+ """Safely normalizes a heatmap to a 0-255 range for visualization."""
41
  if heatmap is None or heatmap.size == 0:
42
  return np.zeros((64, 64), dtype=np.uint8)
43
  heatmap = heatmap.astype(np.float32)
 
54
  normalized_heatmap = np.clip(normalized_heatmap, 0, 255)
55
  return np.uint8(normalized_heatmap)
56
 
57
+ # --- Image Comparison Metrics ---
58
+
59
  def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
60
+ """Calculates Kullback-Leibler Divergence between two images."""
61
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
62
  return None
63
  try:
 
101
  return img_dict
102
 
103
  def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
104
+ """Calculates L1 (Mean Absolute Error) loss between two images."""
105
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
106
  try:
107
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
 
131
  return img_dict
132
 
133
  def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
134
+ """Calculates MSE (Mean Squared Error) loss between two images."""
135
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
136
  try:
137
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
 
161
  return img_dict
162
 
163
  def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
164
+ """Calculates SSIM (Structural Similarity Index) loss between two images."""
165
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
166
  try:
167
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB)
 
189
  return img_dict
190
 
191
  def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
192
+ """Calculates Cosine Similarity loss between two images."""
193
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
194
  try:
195
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
 
225
  return img_dict
226
 
227
  def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
228
+ """Calculates Total Variation (TV) loss between two images."""
229
  if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
230
  try:
231
  img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
 
255
  return img_dict
256
 
257
  def perceptual_loss(img_real, img_fake, model, block_size=4):
258
+ """Calculates Perceptual loss using a pre-trained VGG19 model."""
259
  if img_real is None or img_fake is None or model is None or img_real.shape != img_fake.shape:
260
  return None
261
  original_height, original_width, _ = img_real.shape
 
279
  heatmap_original_size = cv2.resize(heatmap_features, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
280
  return {"SUM": {"SUM": total_loss, "HEATMAP": heatmap_original_size.astype(np.float32)}}
281
 
282
+ # --- Gradio Core Functions ---
283
+
284
  def gather_images(task):
285
+ """Loads a random pair of real and fake images from the selected dataset."""
286
  global TASK, PATH, images
287
  new_path = os.path.join("datasets", task, "real")
288
  if TASK != task or not images:
 
329
  return placeholder, placeholder, error_msg
330
 
331
  def run_comparison(real, fake, measurement, block_size_val):
332
+ """Runs the selected comparison metric and generates a heatmap."""
333
  placeholder_heatmap = np.zeros((64, 64, 3), dtype=np.uint8)
334
  if real is None or fake is None or not isinstance(real, np.ndarray) or not isinstance(fake, np.ndarray):
335
  return placeholder_heatmap, "Error: Input image(s) missing or invalid. Please load or upload a pair of images."
 
372
  return heatmap_rgb, status_msg
373
 
374
  def clear_uploads(msg):
375
+ """Clears the image displays and updates the status message."""
376
  return None, None, msg
377
 
378
+ def load_and_compare_initial(task):
379
+ """Gathers initial images and runs a comparison on them at startup."""
380
+ # Step 1: Get the initial images
381
+ real_img, fake_img, gather_status = gather_images(task)
382
+
383
+ # Step 2: Run the default comparison
384
+ # We use the default values from the UI definition
385
+ default_measurement = "Cosine Similarity"
386
+ default_block_size = 8
387
+ heatmap, compare_status = run_comparison(real_img, fake_img, default_measurement, default_block_size)
388
+
389
+ # Step 3: Combine status messages and return all initial values
390
+ final_status = f"{gather_status}\n{compare_status}"
391
+ return real_img, fake_img, heatmap, final_status
392
+
393
+ # --- Gradio UI Definition ---
394
  theme = gr.themes.Soft(primary_hue="blue", secondary_hue="orange")
395
  with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !important; margin: auto; }") as demo:
396
  gr.Markdown("# GAN vs Ground Truth Image Comparison")
 
404
  with gr.Tabs():
405
  with gr.TabItem("Load from Dataset"):
406
  task_dropdown = gr.Dropdown(
407
+ ["facades"], value=TASK,
408
  info="Select the dataset task.",
409
  label="Dataset Task"
410
  )
 
423
  with gr.Row():
424
  measurement_dropdown = gr.Dropdown(
425
  ["Kullback-Leibler Divergence", "L1-Loss", "MSE", "SSIM", "Cosine Similarity", "TV", "Perceptual"],
426
+ value="Cosine Similarity",
427
  info="Select the comparison metric.",
428
  label="Comparison Metric",
429
  scale=2
 
440
  gr.Markdown("### 3. See Result")
441
  heatmap_display = gr.Image(type="numpy", label="Comparison Heatmap (Difference)", height=350, interactive=False)
442
 
443
+ # --- Event Listeners ---
444
+
445
+ # Load initial sample and run comparison when the app starts
446
+ demo.load(
447
+ fn=load_and_compare_initial,
448
+ inputs=[task_dropdown],
449
+ outputs=[real_img_display, fake_img_display, heatmap_display, status_message]
450
+ )
451
+
452
  sample_button.click(
453
  fn=gather_images,
454
  inputs=[task_dropdown],
 
479
  outputs=[real_img_display, fake_img_display, status_message]
480
  )
481
 
482
+ # --- Application Entry Point ---
483
  if __name__ == "__main__":
484
  print("-------------------------------------------------------------")
485
  print("Verifying VGG19 model status...")