Spaces:
Sleeping
Sleeping
import os | |
import cv2 | |
import random | |
import numpy as np | |
import gradio as gr | |
try: | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input | |
except ImportError: | |
try: | |
from keras.models import Model | |
from keras.applications.vgg19 import VGG19, preprocess_input | |
except ImportError: | |
pass | |
import matplotlib.pyplot as plt | |
from scipy.special import kl_div as scipy_kl_div | |
from skimage.metrics import structural_similarity as ssim | |
import warnings | |
# --- Configuration ---
# Default dataset task. Real images are listed from datasets/<TASK>/real; the
# generated counterpart of each is looked up in datasets/<TASK>/fake.
TASK = "facades"
PATH = os.path.join("datasets", TASK, "real")
# Cached list of real-image paths for the current TASK (filled by gather_images).
images = []
perceptual_model = None

# --- Model Loading ---
# Build a frozen VGG19 (ImageNet weights, no classifier head) that outputs the
# 'block5_conv4' activations, used by the "Perceptual" metric. Any failure
# (Keras not importable, weight download error, ...) leaves perceptual_model as
# None so the rest of the app still works; the reason is reported as a warning
# instead of being discarded.
try:
    vgg = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    vgg.trainable = False
    perceptual_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block5_conv4').output, name="perceptual_model")
except Exception as e:
    perceptual_model = None
    warnings.warn(
        f"VGG19 perceptual model could not be loaded ({type(e).__name__}: {e}); "
        "the 'Perceptual' metric will be unavailable."
    )
# --- Utility Functions ---
def safe_normalize_heatmap(heatmap):
    """Map an arbitrary numeric heatmap onto a 0-255 uint8 image for display.

    ``None`` or an empty array yields a 64x64 black image. NaN entries become 0,
    +inf/-inf are replaced by the largest/smallest finite value (0 when nothing
    is finite). A (near-)constant map (range <= 1e-9) becomes all zeros;
    otherwise values are min-max scaled to [0, 255] and truncated to uint8.
    """
    if heatmap is None or heatmap.size == 0:
        return np.zeros((64, 64), dtype=np.uint8)

    values = heatmap.astype(np.float32)
    finite_mask = np.isfinite(values)
    if not finite_mask.all():
        finite_values = values[finite_mask]
        if finite_values.size:
            low, high = finite_values.min(), finite_values.max()
        else:
            low = high = 0
        values = np.nan_to_num(values, nan=0.0, posinf=high, neginf=low)

    lo = np.min(values)
    span = np.max(values) - lo
    if span > 1e-9:
        scaled = np.clip(((values - lo) / span) * 255.0, 0, 255)
    else:
        # Nothing to stretch: a flat map is shown as black.
        scaled = np.zeros_like(values, dtype=np.float32)
    return np.uint8(scaled)
# --- Image Comparison Metrics ---
def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise Kullback-Leibler divergence between two images.

    Each colour channel is tiled into ``block_size`` x ``block_size`` blocks.
    Blocks on the bottom/right border are smaller when the image size is not a
    multiple of the block size, so every pixel is covered. Inside a block the
    intensities (scaled to [0, 1], plus ``epsilon``) are normalised to sum to 1
    and compared with ``scipy.special.kl_div``, i.e. KL(real || fake).

    Args:
        img_real: reference image, H x W x 3 in OpenCV BGR channel order.
        img_fake: image to compare, same shape as ``img_real``.
        epsilon: added to every intensity so no block sums to zero.
        block_size: block side length, clamped to [1, min(H, W)].
        sum_channels: also fill the ``"SUM"`` entry (sum of the channel totals and
            the channel-averaged heatmap).

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": total, "HEATMAP": H x W float32}}`` where
        ``total`` is the sum of per-block divergences and each heatmap pixel holds
        its block's divergence divided by the block's pixel count; ``None`` for
        missing, mismatched, empty or non-3-channel inputs.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    # Only non-empty 3-channel images are supported; reversing the last axis is
    # the BGR -> RGB conversion.
    if img_real.ndim != 3 or img_real.shape[2] != 3 or img_real.size == 0:
        return None
    img_real_rgb = img_real[:, :, ::-1].astype(np.float32) / 255.0
    img_fake_rgb = img_fake[:, :, ::-1].astype(np.float32) / 255.0
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = min(max(1, int(block_size)), height, width)
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        for i in range(0, height, step):
            for j in range(0, width, step):
                # NumPy truncates slices at the border, producing the smaller
                # edge blocks.
                rows, cols = slice(i, i + step), slice(j, j + step)
                block_gt = img_real_rgb[rows, cols, channel_idx].ravel() + epsilon
                block_pred = img_fake_rgb[rows, cols, channel_idx].ravel() + epsilon
                gt_total = np.sum(block_gt)
                pred_total = np.sum(block_pred)
                if not (gt_total > 0 and pred_total > 0):
                    continue
                kl_values = scipy_kl_div(block_gt / gt_total, block_pred / pred_total)
                kl_values = np.nan_to_num(kl_values, nan=0.0, posinf=0.0, neginf=0.0)
                kl_sum_block = np.sum(kl_values)
                if not np.isfinite(kl_sum_block):
                    continue
                channel_sum += kl_sum_block
                mean_kl_block = kl_sum_block / block_gt.size
                img_dict[key]["HEATMAP"][rows, cols] = mean_kl_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][rows, cols] += mean_kl_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise L1 (mean absolute error) map between two images.

    Pixel values are scaled to [0, 1]. Each channel is tiled into
    ``block_size`` x ``block_size`` blocks (smaller blocks on the bottom/right
    border, so every pixel is covered), and every heatmap pixel holds its block's
    mean absolute difference.

    Args:
        img_real: reference image, H x W x 3 in OpenCV BGR channel order.
        img_fake: image to compare, same shape as ``img_real``.
        epsilon: unused; kept so all block metrics share one signature.
        block_size: block side length, clamped to [1, min(H, W)].
        sum_channels: also fill the ``"SUM"`` entry (sum of the channel totals and
            the channel-averaged heatmap).

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": total, "HEATMAP": H x W float32}}`` with
        ``total`` the summed absolute difference of the channel, or ``None`` for
        missing, mismatched, empty or non-3-channel inputs.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    # Only non-empty 3-channel images are supported; reversing the last axis is
    # the BGR -> RGB conversion.
    if img_real.ndim != 3 or img_real.shape[2] != 3 or img_real.size == 0:
        return None
    img_real_rgb = img_real[:, :, ::-1].astype(np.float32) / 255.0
    img_fake_rgb = img_fake[:, :, ::-1].astype(np.float32) / 255.0
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = min(max(1, int(block_size)), height, width)
    # Per-pixel error is loop-invariant: compute it once for all channels.
    abs_diff = np.abs(img_fake_rgb - img_real_rgb)
    for channel_idx, key in enumerate(("R", "G", "B")):
        diff = abs_diff[:, :, channel_idx]
        for i in range(0, height, step):
            for j in range(0, width, step):
                # NumPy truncates slices at the border, producing the smaller
                # edge blocks.
                rows, cols = slice(i, i + step), slice(j, j + step)
                mean_l1_block = np.mean(diff[rows, cols])
                img_dict[key]["HEATMAP"][rows, cols] = mean_l1_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][rows, cols] += mean_l1_block
        # The blocks tile the whole channel, so the total is the channel sum.
        img_dict[key]["SUM"] = float(np.sum(diff, dtype=np.float64))
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise MSE (mean squared error) map between two images.

    Pixel values are scaled to [0, 1]. Each channel is tiled into
    ``block_size`` x ``block_size`` blocks (smaller blocks on the bottom/right
    border, so every pixel is covered), and every heatmap pixel holds its block's
    mean squared difference.

    Args:
        img_real: reference image, H x W x 3 in OpenCV BGR channel order.
        img_fake: image to compare, same shape as ``img_real``.
        epsilon: unused; kept so all block metrics share one signature.
        block_size: block side length, clamped to [1, min(H, W)].
        sum_channels: also fill the ``"SUM"`` entry (sum of the channel totals and
            the channel-averaged heatmap).

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": total, "HEATMAP": H x W float32}}`` with
        ``total`` the summed squared difference of the channel, or ``None`` for
        missing, mismatched, empty or non-3-channel inputs.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    # Only non-empty 3-channel images are supported; reversing the last axis is
    # the BGR -> RGB conversion.
    if img_real.ndim != 3 or img_real.shape[2] != 3 or img_real.size == 0:
        return None
    img_real_rgb = img_real[:, :, ::-1].astype(np.float32) / 255.0
    img_fake_rgb = img_fake[:, :, ::-1].astype(np.float32) / 255.0
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = min(max(1, int(block_size)), height, width)
    # Per-pixel error is loop-invariant: compute it once for all channels.
    sq_diff = np.square(img_fake_rgb - img_real_rgb)
    for channel_idx, key in enumerate(("R", "G", "B")):
        diff = sq_diff[:, :, channel_idx]
        for i in range(0, height, step):
            for j in range(0, width, step):
                # NumPy truncates slices at the border, producing the smaller
                # edge blocks.
                rows, cols = slice(i, i + step), slice(j, j + step)
                mean_mse_block = np.mean(diff[rows, cols])
                img_dict[key]["HEATMAP"][rows, cols] = mean_mse_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][rows, cols] += mean_mse_block
        # The blocks tile the whole channel, so the total is the channel sum.
        img_dict[key]["SUM"] = float(np.sum(diff, dtype=np.float64))
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
    """Calculate a per-pixel SSIM loss map (1 - SSIM, floored at 0) per channel.

    Uses ``skimage.metrics.structural_similarity`` with Gaussian weighting on the
    raw 0-255 channel values (``data_range=255``). With ``gaussian_weights=True``
    skimage filters with a Gaussian of std ``sigma`` truncated at 3.5 sigma and
    uses ``win_size`` only for cropping and the size check. The Gaussian's sigma
    is therefore derived from the window so that ``block_size`` controls the
    effective window.

    Args:
        img_real: reference image, H x W x 3 in OpenCV BGR channel order.
        img_fake: image to compare, same shape as ``img_real``.
        block_size: requested window size; made odd, clamped to the image, and at
            least 3.
        sum_channels: also fill the ``"SUM"`` entry (sum of the channel totals and
            the channel-averaged heatmap).

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": total, "HEATMAP": H x W map}}``. A channel
        whose SSIM raises ``ValueError`` (e.g. an image smaller than 3 pixels)
        keeps a zero total and an all-zero heatmap. Returns ``None`` for missing,
        mismatched, empty or non-3-channel inputs.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    # Only non-empty 3-channel images are supported; reversing the last axis is
    # the BGR -> RGB conversion.
    if img_real.ndim != 3 or img_real.shape[2] != 3 or img_real.size == 0:
        return None
    img_real_rgb = img_real[:, :, ::-1]
    img_fake_rgb = img_fake[:, :, ::-1]
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    # structural_similarity needs an odd window no larger than either side;
    # clamping to an even side must step down to stay odd.
    win_size = int(block_size)
    if win_size % 2 == 0:
        win_size += 1
    win_size = min(win_size, height, width)
    if win_size % 2 == 0:
        win_size -= 1
    win_size = max(3, win_size)
    # skimage's Gaussian radius is int(3.5 * sigma + 0.5); this sigma makes the
    # filter support exactly win_size (win_size=11 gives ~1.43, close to the
    # standard 1.5).
    sigma = (win_size - 1) / 7.0
    for channel_idx, key in enumerate(("R", "G", "B")):
        try:
            _, ssim_map = ssim(
                img_real_rgb[:, :, channel_idx],
                img_fake_rgb[:, :, channel_idx],
                win_size=win_size,
                data_range=255,
                full=True,
                gaussian_weights=True,
                sigma=sigma,
            )
        except ValueError:
            # Window does not fit the image: leave this channel's zeros in place.
            continue
        ssim_loss_map = np.maximum(0.0, 1.0 - ssim_map)
        img_dict[key]["SUM"] = np.sum(ssim_loss_map)
        img_dict[key]["HEATMAP"] = ssim_loss_map
        if sum_channels:
            img_dict["SUM"]["HEATMAP"] += ssim_loss_map
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise cosine-distance loss (1 - cosine similarity).

    Pixel values are scaled to [0, 1]. Each channel is tiled into
    ``block_size`` x ``block_size`` blocks (smaller blocks on the bottom/right
    border, so every pixel is covered). Each block is flattened to a vector, and
    the loss ``1 - clip(cos, -1, 1)`` fills the block in the heatmap.

    Args:
        img_real: reference image, H x W x 3 in OpenCV BGR channel order.
        img_fake: image to compare, same shape as ``img_real``.
        epsilon: threshold below which norms (or their product) count as zero.
        block_size: block side length, clamped to [1, min(H, W)].
        sum_channels: also fill the ``"SUM"`` entry (sum of the channel totals and
            the channel-averaged heatmap).

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": total, "HEATMAP": H x W float32}}`` with
        ``total`` the sum of per-block losses, or ``None`` for missing,
        mismatched, empty or non-3-channel inputs.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    # Only non-empty 3-channel images are supported; reversing the last axis is
    # the BGR -> RGB conversion.
    if img_real.ndim != 3 or img_real.shape[2] != 3 or img_real.size == 0:
        return None
    img_real_rgb = img_real[:, :, ::-1].astype(np.float32) / 255.0
    img_fake_rgb = img_fake[:, :, ::-1].astype(np.float32) / 255.0
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = min(max(1, int(block_size)), height, width)
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        for i in range(0, height, step):
            for j in range(0, width, step):
                # NumPy truncates slices at the border, producing the smaller
                # edge blocks.
                rows, cols = slice(i, i + step), slice(j, j + step)
                block_pred = img_fake_rgb[rows, cols, channel_idx].ravel()
                block_gt = img_real_rgb[rows, cols, channel_idx].ravel()
                norm_pred = np.linalg.norm(block_pred)
                norm_gt = np.linalg.norm(block_gt)
                if norm_pred * norm_gt > epsilon:
                    cosine_sim = np.dot(block_pred, block_gt) / (norm_pred * norm_gt)
                elif norm_pred < epsilon and norm_gt < epsilon:
                    # Two all-black blocks count as identical.
                    cosine_sim = 1.0
                else:
                    # Otherwise a (near-)zero vector shares no direction.
                    cosine_sim = 0.0
                result_block = 1.0 - np.clip(cosine_sim, -1.0, 1.0)
                channel_sum += result_block
                img_dict[key]["HEATMAP"][rows, cols] = result_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][rows, cols] += result_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise Total Variation (TV) difference between two images.

    Pixel values are scaled to [0, 1]. Each channel is tiled into
    ``block_size`` x ``block_size`` blocks (smaller blocks on the bottom/right
    border, so every pixel is covered). For each block, the anisotropic TV
    (sum of absolute horizontal and vertical neighbour differences) of both
    images is computed, and the block's loss is ``|TV(fake) - TV(real)|``.

    Args:
        img_real: reference image, H x W x 3 in OpenCV BGR channel order.
        img_fake: image to compare, same shape as ``img_real``.
        epsilon: unused; kept so all block metrics share one signature.
        block_size: block side length, at least 2 and clamped to min(H, W).
        sum_channels: also fill the ``"SUM"`` entry (sum of the channel totals and
            the channel-averaged heatmap).

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": total, "HEATMAP": H x W float32}}`` where
        ``total`` is the sum of block losses and each heatmap pixel holds its
        block's loss divided by the block's pixel count (so smaller edge blocks
        are not under-weighted); ``None`` for missing, mismatched, empty or
        non-3-channel inputs.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    # Only non-empty 3-channel images are supported; reversing the last axis is
    # the BGR -> RGB conversion.
    if img_real.ndim != 3 or img_real.shape[2] != 3 or img_real.size == 0:
        return None
    img_real_rgb = img_real[:, :, ::-1].astype(np.float32) / 255.0
    img_fake_rgb = img_fake[:, :, ::-1].astype(np.float32) / 255.0
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    # TV needs neighbouring pixels, hence the minimum of 2.
    step = min(max(2, int(block_size)), height, width)

    def _block_tv(block):
        # Anisotropic total variation: horizontal plus vertical absolute steps.
        return np.sum(np.abs(block[:, 1:] - block[:, :-1])) + np.sum(np.abs(block[1:, :] - block[:-1, :]))

    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        for i in range(0, height, step):
            for j in range(0, width, step):
                # NumPy truncates slices at the border, producing the smaller
                # edge blocks.
                rows, cols = slice(i, i + step), slice(j, j + step)
                block_pred = img_fake_rgb[rows, cols, channel_idx]
                block_gt = img_real_rgb[rows, cols, channel_idx]
                result_block = np.abs(_block_tv(block_pred) - _block_tv(block_gt))
                channel_sum += result_block
                per_pixel = result_block / block_pred.size
                img_dict[key]["HEATMAP"][rows, cols] = per_pixel
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][rows, cols] += per_pixel
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def perceptual_loss(img_real, img_fake, model, block_size=4):
    """Calculate a perceptual loss from the feature maps of a pre-trained model.

    Both images are resized to the model's input size (``INTER_AREA``) and
    converted BGR -> RGB. They are then passed through ``preprocess_input`` and
    run through ``model`` in a single two-image batch. The result is compared
    with the squared difference of the resulting feature maps.

    Args:
        img_real: reference image, H x W x 3 in OpenCV BGR channel order.
        img_fake: image to compare, same shape as ``img_real``.
        model: Keras model whose ``input_shape`` is (batch, height, width, channels).
        block_size: accepted for signature compatibility with the other metrics;
            unused.

    Returns:
        ``{"SUM": {"SUM": total_loss, "HEATMAP": H x W float32}}`` where the
        heatmap is the channel-averaged squared feature difference resized
        (bilinear) to the input size. Returns ``None`` if an input is missing or
        mismatched, or if preprocessing or inference fails.
    """
    if img_real is None or img_fake is None or model is None or img_real.shape != img_fake.shape:
        return None
    original_height, original_width = img_real.shape[:2]
    try:
        model_height, model_width = model.input_shape[1], model.input_shape[2]
        # cv2.resize takes (width, height).
        batch = np.stack([
            cv2.cvtColor(
                cv2.resize(img, (model_width, model_height), interpolation=cv2.INTER_AREA),
                cv2.COLOR_BGR2RGB,
            )
            for img in (img_real, img_fake)
        ])
        batch = preprocess_input(batch)
    except Exception:
        return None
    try:
        # One batched call instead of two; verbose=0 keeps Keras from printing a
        # progress bar on every comparison.
        features = model.predict(batch, verbose=0)
    except Exception:
        return None
    feature_mse = np.square(features[0:1] - features[1:2])
    total_loss = np.sum(feature_mse)
    heatmap_features = np.mean(feature_mse[0, :, :, :], axis=-1)
    heatmap_original_size = cv2.resize(heatmap_features, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
    return {"SUM": {"SUM": total_loss, "HEATMAP": heatmap_original_size.astype(np.float32)}}
# --- Gradio Core Functions ---
def gather_images(task):
    """Load a random (real, fake) image pair for *task* as RGB arrays for display.

    Real images are listed from ``datasets/<task>/real``. The fake counterpart is
    the file with the same name in ``datasets/<task>/fake``. The list of real
    paths is cached in the module-level ``images`` and rebuilt when the task
    changes or the cache is empty. ``cv2.imread`` yields BGR, so both images are
    converted to RGB, which is what Gradio's numpy image components display. A
    fake image of a different size is resized to the real image's size.

    Returns:
        ``(real_rgb, fake_rgb, status_message)``. On failure, black placeholder
        images are returned together with an error message.
    """
    global TASK, PATH, images
    placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
    if TASK != task or not images:
        PATH = os.path.join("datasets", task, "real")
        TASK = task
        images = []
        if not os.path.isdir(PATH):
            return placeholder, placeholder, f"Error: Directory for task '{task}' not found: {PATH}"
        try:
            valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
            images = [os.path.join(PATH, f) for f in os.listdir(PATH) if f.lower().endswith(valid_extensions)]
        except Exception as e:
            return placeholder, placeholder, f"Error reading directory {PATH}: {e}"
        if not images:
            return placeholder, placeholder, f"Error: No valid image files found in: {PATH}"
    try:
        real_img_path = random.choice(images)
        fake_img_path = os.path.join("datasets", task, "fake", os.path.basename(real_img_path))
        real_img = cv2.imread(real_img_path)
        fake_img = cv2.imread(fake_img_path)
        if fake_img is not None:
            fake_img = cv2.cvtColor(fake_img, cv2.COLOR_BGR2RGB)
        if real_img is None:
            fake_out = fake_img if fake_img is not None else placeholder
            return placeholder, fake_out, f"Error: Failed to load real image: {real_img_path}"
        real_img = cv2.cvtColor(real_img, cv2.COLOR_BGR2RGB)
        if fake_img is None:
            return real_img, np.zeros(real_img.shape, dtype=np.uint8), f"Error: Failed to load fake image: {fake_img_path}"
        if real_img.shape != fake_img.shape:
            # cv2.resize takes (width, height).
            fake_img = cv2.resize(fake_img, (real_img.shape[1], real_img.shape[0]), interpolation=cv2.INTER_AREA)
        return real_img, fake_img, f"Sample pair for '{task}' loaded successfully."
    except Exception as e:
        return placeholder, placeholder, f"An unexpected error occurred during image loading: {e}"


def _rgb_to_bgr(img):
    """Convert a Gradio image array to a contiguous 3-channel BGR array.

    Accepts H x W or H x W x 1 (grayscale, replicated to 3 channels), H x W x 3
    (RGB) and H x W x 4 (RGBA; alpha dropped). Raises ValueError otherwise.
    """
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.ndim == 2:
        return np.ascontiguousarray(np.repeat(img[:, :, np.newaxis], 3, axis=2))
    if img.ndim == 3 and img.shape[2] in (3, 4):
        # Channels 2, 1, 0: reverses RGB -> BGR and drops alpha for RGBA.
        return np.ascontiguousarray(img[:, :, 2::-1])
    raise ValueError(f"unsupported image shape {img.shape}")


def run_comparison(real, fake, measurement, block_size_val):
    """Run the selected comparison metric on an image pair and colour its heatmap.

    ``real`` and ``fake`` come from Gradio image components in RGB(A)/grayscale
    order. They are converted to 3-channel OpenCV BGR, which the metric functions
    expect (they convert BGR -> RGB internally). If the shapes then differ,
    ``fake`` is resized to ``real`` and a warning is prefixed to the status.

    Args:
        real, fake: numpy image arrays (anything else is reported as an error).
        measurement: one of the metric names offered in the UI dropdown.
        block_size_val: block/window size; converted with ``int``.

    Returns:
        ``(heatmap_rgb, status_message)``. ``heatmap_rgb`` is the min-max
        normalised ``["SUM"]["HEATMAP"]`` rendered with ``COLORMAP_HOT`` in RGB
        order, or a 64x64 black placeholder when anything fails.
    """
    placeholder_heatmap = np.zeros((64, 64, 3), dtype=np.uint8)
    if real is None or fake is None or not isinstance(real, np.ndarray) or not isinstance(fake, np.ndarray):
        return placeholder_heatmap, "Error: Input image(s) missing or invalid. Please load or upload a pair of images."
    try:
        block_size_int = int(block_size_val)
    except (TypeError, ValueError):
        return placeholder_heatmap, f"Error: Invalid block size '{block_size_val}'."
    try:
        real = _rgb_to_bgr(real)
        fake = _rgb_to_bgr(fake)
    except ValueError as e:
        return placeholder_heatmap, f"Error: Unsupported image format ({e})."
    status_msg_prefix = ""
    if real.shape != fake.shape:
        status_msg_prefix = f"Warning: Input images have different shapes ({real.shape} vs {fake.shape}). Resizing fake image to match real. "
        fake = cv2.resize(fake, (real.shape[1], real.shape[0]), interpolation=cv2.INTER_AREA)
    # Block/window metrics share one call signature; Perceptual needs the model.
    block_metrics = {
        "Kullback-Leibler Divergence": KL_divergence,
        "L1-Loss": L1_loss,
        "MSE": MSE_loss,
        "SSIM": SSIM_loss,
        "Cosine Similarity": cosine_similarity_loss,
        "TV": TV_loss,
    }
    try:
        if measurement in block_metrics:
            result = block_metrics[measurement](real, fake, block_size=block_size_int, sum_channels=True)
        elif measurement == "Perceptual":
            if perceptual_model is None:
                return placeholder_heatmap, "Error: Perceptual model not loaded. Cannot calculate Perceptual loss."
            result = perceptual_loss(real, fake, model=perceptual_model, block_size=block_size_int)
        else:
            return placeholder_heatmap, f"Error: Unknown measurement '{measurement}'."
    except Exception as e:
        return placeholder_heatmap, f"Error during {measurement} calculation: {e}"
    if result is None or "SUM" not in result or "HEATMAP" not in result["SUM"]:
        return placeholder_heatmap, f"{measurement} calculation failed or returned an invalid result structure."
    heatmap_raw = result["SUM"]["HEATMAP"]
    if not isinstance(heatmap_raw, np.ndarray) or heatmap_raw.size == 0:
        return placeholder_heatmap, f"Generated heatmap is invalid or empty for {measurement}."
    try:
        heatmap_normalized = safe_normalize_heatmap(heatmap_raw)
        heatmap_color = cv2.applyColorMap(heatmap_normalized, cv2.COLORMAP_HOT)
        # applyColorMap produces BGR; Gradio displays RGB.
        heatmap_rgb = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    except Exception as e:
        return placeholder_heatmap, f"Error during heatmap coloring: {e}"
    return heatmap_rgb, status_msg_prefix + f"{measurement} comparison successful."
def clear_uploads(msg):
    """Return ``(None, None, msg)``: blank both image outputs and show *msg* as status."""
    cleared = None
    return cleared, cleared, msg
def load_and_compare_initial(task):
    """Fetch a first sample pair for *task* and compare it with the default metric.

    Used as the app's on-load handler. Returns ``(real, fake, heatmap, status)``,
    where ``status`` is the loading message and the comparison message joined by
    a newline.
    """
    # Mirror the initial values of the metric dropdown and block-size slider.
    initial_metric, initial_block_size = "Cosine Similarity", 8
    real_img, fake_img, load_msg = gather_images(task)
    heatmap, compare_msg = run_comparison(real_img, fake_img, initial_metric, initial_block_size)
    return real_img, fake_img, heatmap, f"{load_msg}\n{compare_msg}"
# --- Gradio UI Definition ---
# Three-column layout: (1) image source, (2) image pair + metric controls,
# (3) resulting heatmap, with a shared status box on top.
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="orange")
with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !important; margin: auto; }") as demo:
    gr.Markdown("# GAN vs Ground Truth Image Comparison")
    gr.Markdown("Compare images by loading a sample pair from a dataset or by uploading your own. Choose a comparison metric and run the analysis to see the difference heatmap.")
    status_message = gr.Textbox(label="Status / Errors", lines=2, interactive=False, show_copy_button=True)
    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Get Images")
            with gr.Tabs():
                with gr.TabItem("Load from Dataset"):
                    task_dropdown = gr.Dropdown(
                        ["facades"], value=TASK,
                        info="Select the dataset task.",
                        label="Dataset Task",
                    )
                    sample_button = gr.Button("Get New Sample Pair", variant="secondary")
                with gr.TabItem("Upload Images"):
                    gr.Markdown("Upload your own images to compare.")
                    upload_real_img = gr.Image(type="numpy", label="Upload Real/Reference Image")
                    upload_fake_img = gr.Image(type="numpy", label="Upload Fake/Comparison Image")
        with gr.Column(scale=2, min_width=600):
            gr.Markdown("### 2. View Images & Run Comparison")
            with gr.Row():
                real_img_display = gr.Image(type="numpy", label="Real Image (Ground Truth)", height=350, interactive=False)
                fake_img_display = gr.Image(type="numpy", label="Fake Image (Generated by GAN)", height=350, interactive=False)
            with gr.Row():
                measurement_dropdown = gr.Dropdown(
                    ["Kullback-Leibler Divergence", "L1-Loss", "MSE", "SSIM", "Cosine Similarity", "TV", "Perceptual"],
                    value="Cosine Similarity",
                    info="Select the comparison metric.",
                    label="Comparison Metric",
                    scale=2,
                )
                block_size_slider = gr.Slider(
                    minimum=2, maximum=64, value=8, step=2,
                    info="Size of the block/window for comparison.",
                    label="Block/Window Size",
                    scale=1,
                )
            run_button = gr.Button("Run Comparison", variant="primary")
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 3. See Result")
            heatmap_display = gr.Image(type="numpy", label="Comparison Heatmap (Difference)", height=350, interactive=False)

    # --- Event Listeners ---
    # On page load: fetch a first sample pair and compare it with the default metric.
    demo.load(
        fn=load_and_compare_initial,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, heatmap_display, status_message],
    )
    sample_button.click(
        fn=gather_images,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, status_message],
    )
    # Uploaded images are copied unchanged into the display components that the
    # comparison reads from.
    upload_real_img.upload(
        fn=lambda x: x,
        inputs=[upload_real_img],
        outputs=[real_img_display],
    )
    upload_fake_img.upload(
        fn=lambda x: x,
        inputs=[upload_fake_img],
        outputs=[fake_img_display],
    )
    run_button.click(
        fn=run_comparison,
        inputs=[real_img_display, fake_img_display, measurement_dropdown, block_size_slider],
        outputs=[heatmap_display, status_message],
    )
    # Changing the task clears both displays. The fixed status message is bound in
    # a closure, so no hidden input component is needed.
    task_dropdown.change(
        fn=lambda: clear_uploads("Task changed. Please get a new sample."),
        inputs=None,
        outputs=[real_img_display, fake_img_display, status_message],
    )
# --- Application Entry Point ---
# Print start-up diagnostics (model availability, dataset path), then serve the app.
if __name__ == "__main__":
    rule = "-------------------------------------------------------------"
    print(rule)
    print("Verifying VGG19 model status...")
    print(
        "VGG19 model loaded successfully."
        if perceptual_model is not None
        else "WARNING: VGG19 model failed to load. 'Perceptual' metric will be unavailable."
    )
    print(rule)
    print(f"Checking initial dataset path: {PATH}")
    if os.path.isdir(PATH):
        print("Initial dataset path seems valid.")
    else:
        print(f"WARNING: Initial dataset path not found: {PATH}")
        print(f" Please ensure the directory '{os.path.join('datasets', TASK, 'real')}' exists.")
    for line in (
        rule,
        "Launching Gradio App...",
        "Access the app in your browser, usually at: http://127.0.0.1:7860",
        rule,
    ):
        print(line)
    demo.launch(share=False, debug=False)