import os
import cv2
import random
import numpy as np
import gradio as gr
try:
    from tensorflow.keras.models import Model
    from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
except ImportError:
    try:
        from keras.models import Model
        from keras.applications.vgg19 import VGG19, preprocess_input
    except ImportError:
        # Silently fail if Keras/TensorFlow is not installed; the UI will handle the error.
        pass
import matplotlib.pyplot as plt
from scipy.special import kl_div as scipy_kl_div
from skimage.metrics import structural_similarity as ssim
import warnings
# --- Global Variables ---
TASK = "nodules"
PATH = os.path.join("datasets", TASK, "real")
images = []
perceptual_model = None
try:
    # Initialize the VGG19 model for the perceptual loss metric.
    vgg = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    vgg.trainable = False
    perceptual_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block5_conv4').output, name="perceptual_model")
except Exception:
    # This will be handled gracefully in the UI if the model fails to load.
    perceptual_model = None
# --- Utility Functions ---
def safe_normalize_heatmap(heatmap):
    """Safely normalizes a heatmap to a 0-255 range for visualization, handling non-finite values."""
    if heatmap is None or heatmap.size == 0:
        return np.zeros((64, 64), dtype=np.uint8)
    heatmap = heatmap.astype(np.float32)
    # Replace non-finite values (NaN, inf) with numerical ones for safe processing.
    if not np.all(np.isfinite(heatmap)):
        min_val_safe = np.nanmin(heatmap[np.isfinite(heatmap)]) if np.any(np.isfinite(heatmap)) else 0
        max_val_safe = np.nanmax(heatmap[np.isfinite(heatmap)]) if np.any(np.isfinite(heatmap)) else 0
        heatmap = np.nan_to_num(heatmap, nan=0.0, posinf=max_val_safe, neginf=min_val_safe)
    min_val = np.min(heatmap)
    max_val = np.max(heatmap)
    range_val = max_val - min_val
    # Normalize the heatmap to the 0-255 range.
    normalized_heatmap = np.zeros_like(heatmap, dtype=np.float32)
    if range_val > 1e-9:
        normalized_heatmap = ((heatmap - min_val) / range_val) * 255.0
    normalized_heatmap = np.clip(normalized_heatmap, 0, 255)
    return np.uint8(normalized_heatmap)
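

# Illustrative usage sketch (a minimal example, not called by the app): a heatmap
# containing NaN and +inf is sanitized before being scaled to 0-255, so the finite
# minimum maps to 0 and the finite maximum to 255.
def _normalize_heatmap_example():
    raw = np.array([[0.0, 1.0],
                    [np.nan, np.inf]], dtype=np.float32)
    return safe_normalize_heatmap(raw)  # uint8 array: [[0, 255], [0, 255]]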
# --- Comparison Metric Functions ---
def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates the Kullback-Leibler Divergence between two images on a block-by-block basis."""
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        "R": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "G": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "B": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "SUM": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
    }
    channel_keys = ["R", "G", "B"]
    current_block_size = max(1, int(block_size))
    if current_block_size > min(height, width):
        current_block_size = min(height, width)
    for channel_idx, key in enumerate(channel_keys):
        channel_sum = 0.0
        for i in range(0, height - current_block_size + 1, current_block_size):
            for j in range(0, width - current_block_size + 1, current_block_size):
                block_gt = img_real_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx].flatten() + epsilon
                block_pred = img_fake_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx].flatten() + epsilon
                # Normalize distributions within the block
                if np.sum(block_gt) > 0 and np.sum(block_pred) > 0:
                    block_gt_norm = block_gt / np.sum(block_gt)
                    block_pred_norm = block_pred / np.sum(block_pred)
                    kl_values = scipy_kl_div(block_gt_norm, block_pred_norm)
                    kl_values = np.nan_to_num(kl_values, nan=0.0, posinf=0.0, neginf=0.0)
                    kl_sum_block = np.sum(kl_values)
                    if np.isfinite(kl_sum_block):
                        channel_sum += kl_sum_block
                        mean_kl_block = kl_sum_block / max(1, current_block_size * current_block_size)
                        img_dict[key]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] = mean_kl_block
                        if sum_channels:
                            img_dict["SUM"]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] += mean_kl_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
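

# Note on the block metric above (illustrative sketch, not called by the app): after
# epsilon-shifting and normalizing each block, scipy's kl_div(x, y) = x*log(x/y) - x + y
# sums to the standard discrete KL divergence D_KL(p || q) = sum_k p_k * log(p_k / q_k),
# because the "- x + y" terms cancel once both distributions sum to 1. A tiny check:
def _kl_block_sanity_check():
    p = np.array([0.2, 0.3, 0.5])
    q = np.array([0.25, 0.25, 0.5])
    manual = np.sum(p * np.log(p / q))      # standard KL divergence
    via_scipy = np.sum(scipy_kl_div(p, q))  # same value for normalized p and q
    return manual, via_scipy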
def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates the L1 (Mean Absolute Error) loss between two images."""
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        "R": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "G": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "B": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "SUM": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
    }
    channel_keys = ["R", "G", "B"]
    current_block_size = max(1, int(block_size))
    if current_block_size > min(height, width):
        current_block_size = min(height, width)
    for channel_idx, key in enumerate(channel_keys):
        channel_sum = 0.0
        for i in range(0, height - current_block_size + 1, current_block_size):
            for j in range(0, width - current_block_size + 1, current_block_size):
                block_pred = img_fake_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx]
                block_gt = img_real_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx]
                result_block = np.abs(block_pred - block_gt)
                sum_result_block = np.sum(result_block)
                channel_sum += sum_result_block
                mean_l1_block = sum_result_block / max(1, current_block_size * current_block_size)
                img_dict[key]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] = mean_l1_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] += mean_l1_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates the L2 (Mean Squared Error) loss between two images."""
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        "R": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "G": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "B": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "SUM": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
    }
    channel_keys = ["R", "G", "B"]
    current_block_size = max(1, int(block_size))
    if current_block_size > min(height, width):
        current_block_size = min(height, width)
    for channel_idx, key in enumerate(channel_keys):
        channel_sum = 0.0
        for i in range(0, height - current_block_size + 1, current_block_size):
            for j in range(0, width - current_block_size + 1, current_block_size):
                block_pred = img_fake_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx]
                block_gt = img_real_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx]
                result_block = np.square(block_pred - block_gt)
                sum_result_block = np.sum(result_block)
                channel_sum += sum_result_block
                mean_mse_block = sum_result_block / max(1, current_block_size * current_block_size)
                img_dict[key]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] = mean_mse_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] += mean_mse_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
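

# Illustrative note (sketch, not called by the app): L1_loss and MSE_loss share the same
# block loop and differ only in the per-pixel penalty, |d| versus d**2. A tiny single-block
# example on made-up values:
def _l1_vs_mse_block_example():
    block_gt = np.array([[0.0, 0.5],
                         [0.5, 1.0]])
    block_pred = np.array([[0.1, 0.5],
                           [0.5, 0.8]])
    diff = block_pred - block_gt
    mean_l1 = np.sum(np.abs(diff)) / diff.size      # ~0.075
    mean_mse = np.sum(np.square(diff)) / diff.size  # ~0.0125
    return mean_l1, mean_mse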
def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
    """Calculates the Structural Similarity Index Measure (SSIM) loss between two images."""
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB)
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB)
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        "R": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "G": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "B": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "SUM": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
    }
    channel_keys = ["R", "G", "B"]
    for channel_idx, key in enumerate(channel_keys):
        win_size = int(block_size)
        if win_size % 2 == 0:
            win_size += 1
        win_size = max(3, min(win_size, height, width))
        try:
            _, ssim_map = ssim(img_real_rgb[:, :, channel_idx], img_fake_rgb[:, :, channel_idx],
                               win_size=win_size, data_range=255, full=True, gaussian_weights=True)
            ssim_loss_map = np.maximum(0.0, 1.0 - ssim_map)
            img_dict[key]["SUM"] = np.sum(ssim_loss_map)
            img_dict[key]["HEATMAP"] = ssim_loss_map
            if sum_channels:
                img_dict["SUM"]["HEATMAP"] += ssim_loss_map
        except ValueError:
            img_dict[key]["SUM"] = 0.0
            img_dict[key]["HEATMAP"] = np.zeros((height, width), dtype=np.float32)
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
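

# Illustrative check (sketch, not called by the app): for identical channels the SSIM map
# is 1 everywhere, so the stored loss map max(0, 1 - SSIM) is all zeros. The input below
# is synthetic.
def _ssim_identity_example():
    rng = np.random.default_rng(0)
    channel = rng.integers(0, 256, size=(32, 32), dtype=np.uint8)
    score, ssim_map = ssim(channel, channel, win_size=7, data_range=255,
                           full=True, gaussian_weights=True)
    return score, float(np.max(np.maximum(0.0, 1.0 - ssim_map)))  # (1.0, 0.0)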
def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates the Cosine Similarity loss between two images on a block-by-block basis."""
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        "R": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "G": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "B": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "SUM": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
    }
    channel_keys = ["R", "G", "B"]
    current_block_size = max(1, int(block_size))
    if current_block_size > min(height, width):
        current_block_size = min(height, width)
    for channel_idx, key in enumerate(channel_keys):
        channel_sum = 0.0
        for i in range(0, height - current_block_size + 1, current_block_size):
            for j in range(0, width - current_block_size + 1, current_block_size):
                block_pred = img_fake_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx].flatten()
                block_gt = img_real_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx].flatten()
                dot_product = np.dot(block_pred, block_gt)
                norm_pred = np.linalg.norm(block_pred)
                norm_gt = np.linalg.norm(block_gt)
                cosine_sim = 0.0
                if norm_pred * norm_gt > epsilon:
                    cosine_sim = dot_product / (norm_pred * norm_gt)
                elif norm_pred < epsilon and norm_gt < epsilon:
                    cosine_sim = 1.0  # If both vectors are near-zero, they are identical.
                result_block = 1.0 - np.clip(cosine_sim, -1.0, 1.0)
                channel_sum += result_block
                img_dict[key]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] = result_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] += result_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
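

# Note (illustrative sketch, not called by the app): each block's heatmap value is
# 1 - cos(theta) = 1 - (a . b) / (||a|| * ||b||), so parallel blocks score 0 and
# orthogonal blocks score 1. A tiny example with made-up vectors:
def _cosine_loss_example():
    a = np.array([1.0, 2.0, 3.0])
    b = np.array([2.0, 4.0, 6.0])  # parallel to a
    cosine_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return 1.0 - np.clip(cosine_sim, -1.0, 1.0)  # ~0.0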
def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates the Total Variation (TV) loss between two images."""
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        "R": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "G": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "B": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)},
        "SUM": {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
    }
    channel_keys = ["R", "G", "B"]
    current_block_size = max(2, int(block_size))  # TV needs at least 2x2 blocks
    if current_block_size > min(height, width):
        current_block_size = min(height, width)
    for channel_idx, key in enumerate(channel_keys):
        channel_sum = 0.0
        for i in range(0, height - current_block_size + 1, current_block_size):
            for j in range(0, width - current_block_size + 1, current_block_size):
                block_pred = img_fake_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx]
                block_gt = img_real_rgb[i:i+current_block_size, j:j+current_block_size, channel_idx]
                tv_pred = np.sum(np.abs(block_pred[:, 1:] - block_pred[:, :-1])) + np.sum(np.abs(block_pred[1:, :] - block_pred[:-1, :]))
                tv_gt = np.sum(np.abs(block_gt[:, 1:] - block_gt[:, :-1])) + np.sum(np.abs(block_gt[1:, :] - block_gt[:-1, :]))
                result_block = np.abs(tv_pred - tv_gt)
                channel_sum += result_block
                img_dict[key]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] = result_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i+current_block_size, j:j+current_block_size] += result_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
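

# Note (illustrative sketch, not called by the app): the anisotropic total variation of a
# block B is TV(B) = sum |B[i, j+1] - B[i, j]| + sum |B[i+1, j] - B[i, j]|, and the heatmap
# stores |TV(fake block) - TV(real block)|. A tiny example on a made-up 2x2 block:
def _tv_example():
    block = np.array([[0.0, 1.0],
                      [0.0, 1.0]])
    tv = (np.sum(np.abs(block[:, 1:] - block[:, :-1])) +
          np.sum(np.abs(block[1:, :] - block[:-1, :])))
    return tv  # 2.0: two horizontal jumps of 1.0, no vertical variation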
def perceptual_loss(img_real, img_fake, model, block_size=4):
    """Calculates the perceptual loss using a pre-trained VGG19 model."""
    if img_real is None or img_fake is None or model is None or img_real.shape != img_fake.shape:
        return None
    original_height, original_width, _ = img_real.shape
    try:
        # Determine the target input size from the model
        target_size = (model.input_shape[1], model.input_shape[2])
        cv2_target_size = (target_size[1], target_size[0])
        # Resize, convert to RGB, and preprocess images for the model
        img_real_resized = cv2.resize(img_real, cv2_target_size, interpolation=cv2.INTER_AREA)
        img_fake_resized = cv2.resize(img_fake, cv2_target_size, interpolation=cv2.INTER_AREA)
        img_real_processed = preprocess_input(np.expand_dims(cv2.cvtColor(img_real_resized, cv2.COLOR_BGR2RGB), axis=0))
        img_fake_processed = preprocess_input(np.expand_dims(cv2.cvtColor(img_fake_resized, cv2.COLOR_BGR2RGB), axis=0))
    except Exception:
        return None
    try:
        # Get feature maps from the model
        img_real_vgg = model.predict(img_real_processed)
        img_fake_vgg = model.predict(img_fake_processed)
    except Exception:
        return None
    # Calculate MSE between feature maps
    feature_mse = np.square(img_real_vgg - img_fake_vgg)
    total_loss = np.sum(feature_mse)
    heatmap_features = np.mean(feature_mse[0, :, :, :], axis=-1)
    # Resize heatmap back to original image dimensions
    heatmap_original_size = cv2.resize(heatmap_features, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
    return {"SUM": {"SUM": total_loss, "HEATMAP": heatmap_original_size.astype(np.float32)}}
# --- Gradio Core Logic ---
def gather_images(task):
    """Loads a random pair of real and fake images from the selected task directory."""
    global TASK, PATH, images
    new_path = os.path.join("datasets", task, "real")
    if TASK != task or not images:
        PATH = new_path
        TASK = task
        images = []
        if not os.path.isdir(PATH):
            error_msg = f"Error: Directory for task '{task}' not found: {PATH}"
            placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
            return placeholder, placeholder, error_msg
        try:
            valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
            images = [os.path.join(PATH, f) for f in os.listdir(PATH) if f.lower().endswith(valid_extensions)]
            if not images:
                error_msg = f"Error: No valid image files found in: {PATH}"
                placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
                return placeholder, placeholder, error_msg
        except Exception as e:
            error_msg = f"Error reading directory {PATH}: {e}"
            placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
            return placeholder, placeholder, error_msg
    if not images:
        error_msg = f"Error: No images available for task '{task}'."
        placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
        return placeholder, placeholder, error_msg
    try:
        real_img_path = random.choice(images)
        img_filename = os.path.basename(real_img_path)
        fake_img_path = os.path.join("datasets", task, "fake", img_filename)
        real_img = cv2.imread(real_img_path)
        fake_img = cv2.imread(fake_img_path)
        # OpenCV loads images as BGR; convert to RGB so the Gradio previews show correct colors.
        if real_img is not None:
            real_img = cv2.cvtColor(real_img, cv2.COLOR_BGR2RGB)
        if fake_img is not None:
            fake_img = cv2.cvtColor(fake_img, cv2.COLOR_BGR2RGB)
        placeholder_shape = (256, 256, 3)
        if real_img is None:
            return np.zeros(placeholder_shape, dtype=np.uint8), fake_img if fake_img is not None else np.zeros(placeholder_shape, dtype=np.uint8), f"Error: Failed to load real image: {real_img_path}"
        if fake_img is None:
            return real_img, np.zeros(real_img.shape, dtype=np.uint8), f"Error: Failed to load fake image: {fake_img_path}"
        # Ensure images have the same dimensions for comparison
        if real_img.shape != fake_img.shape:
            target_dims = (real_img.shape[1], real_img.shape[0])
            fake_img = cv2.resize(fake_img, target_dims, interpolation=cv2.INTER_AREA)
        return real_img, fake_img, f"Sample pair for '{task}' loaded successfully."
    except Exception as e:
        error_msg = f"An unexpected error occurred during image loading: {e}"
        placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
        return placeholder, placeholder, error_msg
def run_comparison(real, fake, measurement, block_size_val):
    """Runs the selected comparison and returns the heatmap and a status message."""
    placeholder_heatmap = np.zeros((64, 64, 3), dtype=np.uint8)
    if real is None or fake is None or not isinstance(real, np.ndarray) or not isinstance(fake, np.ndarray):
        return placeholder_heatmap, "Error: Input image(s) missing or invalid. Please get a new sample pair."
    if real.shape != fake.shape:
        return placeholder_heatmap, f"Error: Input images have different shapes ({real.shape} vs {fake.shape})."
    result = None
    block_size_int = int(block_size_val)
    try:
        # The Gradio image components hold RGB arrays, while the metric functions
        # expect OpenCV-style BGR input, so convert before comparing.
        real_bgr = cv2.cvtColor(real, cv2.COLOR_RGB2BGR)
        fake_bgr = cv2.cvtColor(fake, cv2.COLOR_RGB2BGR)
        if measurement == "Kullback-Leibler Divergence":
            result = KL_divergence(real_bgr, fake_bgr, block_size=block_size_int, sum_channels=True)
        elif measurement == "L1-Loss":
            result = L1_loss(real_bgr, fake_bgr, block_size=block_size_int, sum_channels=True)
        elif measurement == "MSE":
            result = MSE_loss(real_bgr, fake_bgr, block_size=block_size_int, sum_channels=True)
        elif measurement == "SSIM":
            result = SSIM_loss(real_bgr, fake_bgr, block_size=block_size_int, sum_channels=True)
        elif measurement == "Cosine Similarity":
            result = cosine_similarity_loss(real_bgr, fake_bgr, block_size=block_size_int, sum_channels=True)
        elif measurement == "TV":
            result = TV_loss(real_bgr, fake_bgr, block_size=block_size_int, sum_channels=True)
        elif measurement == "Perceptual":
            if perceptual_model is None:
                return placeholder_heatmap, "Error: Perceptual model not loaded. Cannot calculate Perceptual loss."
            result = perceptual_loss(real_bgr, fake_bgr, model=perceptual_model, block_size=block_size_int)
        else:
            return placeholder_heatmap, f"Error: Unknown measurement '{measurement}'."
    except Exception as e:
        return placeholder_heatmap, f"Error during {measurement} calculation: {e}"
    if result is None or "SUM" not in result or "HEATMAP" not in result["SUM"]:
        return placeholder_heatmap, f"{measurement} calculation failed or returned an invalid result structure."
    heatmap_raw = result["SUM"]["HEATMAP"]
    if not isinstance(heatmap_raw, np.ndarray) or heatmap_raw.size == 0:
        return placeholder_heatmap, f"Generated heatmap is invalid or empty for {measurement}."
    try:
        heatmap_normalized = safe_normalize_heatmap(heatmap_raw)
        heatmap_color = cv2.applyColorMap(heatmap_normalized, cv2.COLORMAP_HOT)
        heatmap_rgb = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    except Exception as e:
        return placeholder_heatmap, f"Error during heatmap coloring: {e}"
    status_msg = f"{measurement} comparison successful."
    return heatmap_rgb, status_msg
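

# Minimal headless sketch (illustrative, not wired into the UI): the same two functions
# can be driven without Gradio, e.g. for a quick smoke test of a metric.
def _headless_comparison_example(task="nodules", metric="MSE", block_size=8):
    real, fake, load_status = gather_images(task)
    heatmap_rgb, run_status = run_comparison(real, fake, metric, block_size)
    print(load_status)
    print(run_status)
    return heatmap_rgb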
# --- Gradio UI Definition ---
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="orange")
with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1200px !important; margin: auto; }") as demo:
    gr.Markdown("# GAN vs Ground Truth Image Comparison")
    gr.Markdown("Select a dataset task, load a random sample pair (Real vs Fake), choose a comparison metric and parameters, then run the analysis to see the difference heatmap.")
    with gr.Row():
        status_message = gr.Textbox(label="Status / Errors", lines=2, interactive=False, show_copy_button=True, scale=1)
    with gr.Row(equal_height=False):
        with gr.Column(scale=2, min_width=300):
            real_img_display = gr.Image(type="numpy", label="Real Image (Ground Truth)", height=350, interactive=False)
            task_dropdown = gr.Dropdown(
                ["nodules", "facades"], value=TASK,
                info="Select the dataset task (must match directory name)",
                label="Dataset Task"
            )
            sample_button = gr.Button("Get New Sample Pair", variant="secondary")
        with gr.Column(scale=2, min_width=300):
            fake_img_display = gr.Image(type="numpy", label="Fake Image (Generated by GAN)", height=350, interactive=False)
            measurement_dropdown = gr.Dropdown(
                ["Kullback-Leibler Divergence", "L1-Loss", "MSE", "SSIM", "Cosine Similarity", "TV", "Perceptual"],
                value="Kullback-Leibler Divergence",
                info="Select the comparison metric",
                label="Comparison Metric"
            )
            block_size_slider = gr.Slider(
                minimum=2, maximum=64, value=8, step=2,
                info="Size of the block/window for comparison (e.g., 8x8). Affects granularity. Note: SSIM uses this as 'win_size', Perceptual ignores it.",
                label="Block/Window Size"
            )
            run_button = gr.Button("Run Comparison", variant="primary")
        with gr.Column(scale=2, min_width=300):
            heatmap_display = gr.Image(type="numpy", label="Comparison Heatmap (Difference)", height=350, interactive=False)
    # --- Event Handlers ---
    # When the "Get New Sample Pair" button is clicked
    sample_button.click(
        fn=gather_images,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, status_message]
    )
    # When the "Run Comparison" button is clicked
    run_button.click(
        fn=run_comparison,
        inputs=[real_img_display, fake_img_display, measurement_dropdown, block_size_slider],
        outputs=[heatmap_display, status_message]  # The status message box also receives the result string
    )
if __name__ == "__main__": | |
print("-------------------------------------------------------------") | |
print("Verifying VGG19 model status...") | |
if perceptual_model is None: | |
print("WARNING: VGG19 model failed to load. 'Perceptual' metric will be unavailable.") | |
else: | |
print("VGG19 model loaded successfully.") | |
print("-------------------------------------------------------------") | |
print(f"Checking initial dataset path: {PATH}") | |
if not os.path.isdir(PATH): | |
print(f"WARNING: Initial dataset path not found: {PATH}") | |
print(f" Please ensure the directory '{os.path.join('datasets', TASK, 'real')}' exists.") | |
else: | |
print("Initial dataset path seems valid.") | |
print("-------------------------------------------------------------") | |
print("Launching Gradio App...") | |
print("Access the app in your browser, usually at: http://127.0.0.1:7860") | |
print("-------------------------------------------------------------") | |
demo.launch(share=False, debug=False) | |