# Alican's picture
# Add application files
# 5548b5c
import os
import cv2
import random
import numpy as np
import gradio as gr
try:
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
except ImportError:
try:
from keras.models import Model
from keras.applications.vgg19 import VGG19, preprocess_input
except ImportError:
pass
import matplotlib.pyplot as plt
from scipy.special import kl_div as scipy_kl_div
from skimage.metrics import structural_similarity as ssim
import warnings
# --- Configuration ---
# Dataset task shown on startup. gather_images() rebinds TASK, PATH and
# images whenever it is called with a different task.
TASK: str = "facades"
# Folder containing the ground-truth ("real") images for TASK.
PATH: str = os.path.join("datasets", TASK, "real")
# Lazily-filled cache of real-image file paths for TASK.
images: list = []
# VGG19 feature extractor for the "Perceptual" metric; stays None unless the
# model-loading step below succeeds.
perceptual_model = None
# --- Model Loading ---
# Attempt to build the VGG19 feature extractor (output of layer
# 'block5_conv4') used by the perceptual loss metric. Loading is best-effort:
# if the Keras imports above were skipped (NameError) or the ImageNet weights
# cannot be loaded, the app still starts, perceptual_model stays None and the
# "Perceptual" option reports that the model is not loaded. The cause is
# surfaced as a warning instead of being discarded.
try:
    vgg = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    vgg.trainable = False
    perceptual_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block5_conv4').output, name="perceptual_model")
except Exception as e:
    perceptual_model = None
    warnings.warn(
        f"VGG19 perceptual model could not be loaded ({type(e).__name__}: {e}); "
        "the 'Perceptual' metric will be unavailable."
    )
# --- Utility Functions ---
def safe_normalize_heatmap(heatmap):
    """Rescale an arbitrary score map to uint8 values in [0, 255] for display.

    Non-finite entries are repaired first: NaN becomes 0, +inf becomes the
    largest finite value and -inf the smallest (both 0 if nothing is finite).
    The map is then min-max scaled to [0, 255]; a (near-)constant map, whose
    value range is not above 1e-9, becomes all zeros.

    Args:
        heatmap: numpy array of scores, or None.

    Returns:
        uint8 array of the same shape, or a 64x64 zero image when ``heatmap``
        is None or empty.
    """
    if heatmap is None or heatmap.size == 0:
        return np.zeros((64, 64), dtype=np.uint8)

    values = heatmap.astype(np.float32)
    finite_mask = np.isfinite(values)
    if not finite_mask.all():
        if finite_mask.any():
            finite_values = values[finite_mask]
            low_fill = np.nanmin(finite_values)
            high_fill = np.nanmax(finite_values)
        else:
            low_fill = high_fill = 0
        values = np.nan_to_num(values, nan=0.0, posinf=high_fill, neginf=low_fill)

    low = np.min(values)
    high = np.max(values)
    span = high - low
    if span > 1e-9:
        scaled = np.clip(((values - low) / span) * 255.0, 0, 255)
    else:
        # Flat map: nothing to contrast, show it as black.
        scaled = np.zeros_like(values, dtype=np.float32)
    return np.uint8(scaled)
# --- Image Comparison Metrics ---
def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates block-wise Kullback-Leibler divergence between two images.

    Both images go through ``cv2.COLOR_BGR2RGB`` and are scaled to [0, 1].
    Each channel is tiled into ``block_size`` x ``block_size`` blocks. Inside a
    block the pixel values (plus ``epsilon``) of each image are normalised to
    sum to 1, and the element-wise ``scipy.special.kl_div`` terms are summed,
    giving KL(real || fake) for that block. When the image size is not a
    multiple of the block size, the last row/column of blocks is smaller but
    still evaluated, so every pixel contributes to the sums and the heatmap.

    Args:
        img_real: Reference colour image, (H, W, C), as accepted by cv2.cvtColor.
        img_fake: Image to compare; must have exactly the same shape.
        epsilon: Added to every pixel so no block has zero probability mass.
        block_size: Block side length; clamped to [1, min(H, W)].
        sum_channels: If True, also fill the "SUM" entry (heatmap averaged
            over channels, scalar summed over channels).

    Returns:
        dict keyed "R", "G", "B", "SUM"; each value is
        ``{"SUM": total, "HEATMAP": (H, W) float32 array}``. A channel's
        heatmap stores, for each pixel, the mean per-pixel KL term of its
        block. Returns None if an input is missing, the shapes differ, or
        OpenCV cannot convert the images.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        real = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        fake = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = real.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    size = min(max(1, int(block_size)), height, width)
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        # Step across the whole image: slicing clips the final blocks to the
        # border instead of skipping the leftover pixels.
        for i in range(0, height, size):
            for j in range(0, width, size):
                block_gt = real[i:i + size, j:j + size, channel_idx].ravel() + epsilon
                block_pred = fake[i:i + size, j:j + size, channel_idx].ravel() + epsilon
                gt_mass = np.sum(block_gt)
                pred_mass = np.sum(block_pred)
                if not (gt_mass > 0 and pred_mass > 0):
                    continue
                kl_values = scipy_kl_div(block_gt / gt_mass, block_pred / pred_mass)
                kl_values = np.nan_to_num(kl_values, nan=0.0, posinf=0.0, neginf=0.0)
                kl_sum_block = np.sum(kl_values)
                if not np.isfinite(kl_sum_block):
                    continue
                channel_sum += kl_sum_block
                # Average over the pixels actually in this (possibly partial) block.
                mean_kl_block = kl_sum_block / block_gt.size
                img_dict[key]["HEATMAP"][i:i + size, j:j + size] = mean_kl_block
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i + size, j:j + size] += mean_kl_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates block-wise L1 (mean absolute error) loss between two images.

    Both images go through ``cv2.COLOR_BGR2RGB`` and are scaled to [0, 1].
    Each channel is tiled into ``block_size`` x ``block_size`` blocks. Blocks on
    the right/bottom border may be smaller when the image size is not a
    multiple of the block size; they are still evaluated, so no pixels are
    left out of the sums or the heatmap.

    Args:
        img_real: Reference colour image, (H, W, C), as accepted by cv2.cvtColor.
        img_fake: Image to compare; must have exactly the same shape.
        epsilon: Unused; kept so all block metrics share one signature.
        block_size: Block side length; clamped to [1, min(H, W)].
        sum_channels: If True, also fill the "SUM" entry (heatmap averaged
            over channels, scalar summed over channels).

    Returns:
        dict keyed "R", "G", "B", "SUM"; each value is
        ``{"SUM": total absolute error, "HEATMAP": (H, W) float32}`` where the
        heatmap holds each block's mean absolute error. None if an input is
        missing, the shapes differ, or OpenCV cannot convert the images.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        real = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        fake = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = real.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    size = min(max(1, int(block_size)), height, width)
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        # Step across the whole image; the last blocks are clipped to the border.
        for i in range(0, height, size):
            for j in range(0, width, size):
                abs_diff = np.abs(
                    fake[i:i + size, j:j + size, channel_idx] - real[i:i + size, j:j + size, channel_idx]
                )
                block_total = np.sum(abs_diff)
                channel_sum += block_total
                # Mean over the pixels actually in this (possibly partial) block.
                block_mean = block_total / abs_diff.size
                img_dict[key]["HEATMAP"][i:i + size, j:j + size] = block_mean
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i + size, j:j + size] += block_mean
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates block-wise MSE (mean squared error) loss between two images.

    Both images go through ``cv2.COLOR_BGR2RGB`` and are scaled to [0, 1].
    Each channel is tiled into ``block_size`` x ``block_size`` blocks. Blocks on
    the right/bottom border may be smaller when the image size is not a
    multiple of the block size; they are still evaluated, so no pixels are
    left out of the sums or the heatmap.

    Args:
        img_real: Reference colour image, (H, W, C), as accepted by cv2.cvtColor.
        img_fake: Image to compare; must have exactly the same shape.
        epsilon: Unused; kept so all block metrics share one signature.
        block_size: Block side length; clamped to [1, min(H, W)].
        sum_channels: If True, also fill the "SUM" entry (heatmap averaged
            over channels, scalar summed over channels).

    Returns:
        dict keyed "R", "G", "B", "SUM"; each value is
        ``{"SUM": total squared error, "HEATMAP": (H, W) float32}`` where the
        heatmap holds each block's mean squared error. None if an input is
        missing, the shapes differ, or OpenCV cannot convert the images.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        real = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        fake = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = real.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    size = min(max(1, int(block_size)), height, width)
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        # Step across the whole image; the last blocks are clipped to the border.
        for i in range(0, height, size):
            for j in range(0, width, size):
                sq_diff = np.square(
                    fake[i:i + size, j:j + size, channel_idx] - real[i:i + size, j:j + size, channel_idx]
                )
                block_total = np.sum(sq_diff)
                channel_sum += block_total
                # Mean over the pixels actually in this (possibly partial) block.
                block_mean = block_total / sq_diff.size
                img_dict[key]["HEATMAP"][i:i + size, j:j + size] = block_mean
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i + size, j:j + size] += block_mean
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
    """Calculates SSIM (Structural Similarity Index) loss between two images.

    Each channel (after ``cv2.COLOR_BGR2RGB``) is compared with
    ``skimage.metrics.structural_similarity`` using ``data_range=255`` and
    Gaussian weighting. The per-pixel loss is ``max(0, 1 - ssim_map)``.

    Args:
        img_real: Reference colour image, (H, W, C), as accepted by cv2.cvtColor.
        img_fake: Image to compare; must have exactly the same shape.
        block_size: Requested SSIM window; forced odd, clamped to the image
            and to a minimum of 3.
        sum_channels: If True, also fill the "SUM" entry (heatmap averaged
            over channels, scalar summed over channels).

    Returns:
        dict keyed "R", "G", "B", "SUM"; each value is
        ``{"SUM": total loss, "HEATMAP": (H, W) loss map}``. A channel that
        skimage rejects (ValueError) is left at zero. None if an input is
        missing, the shapes differ, or OpenCV cannot convert the images.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        real = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB)
        fake = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB)
    except cv2.error:
        return None
    height, width, channels = real.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }

    # skimage requires an odd window that fits inside the image. Round up to
    # odd, clamp to the image, then round down to odd again: clamping to an
    # even dimension (e.g. a 32x32 image) used to produce an even window, make
    # ssim() raise, and leave an all-zero ("perfect match") heatmap.
    # NOTE(review): with gaussian_weights=True skimage sizes its filter from
    # sigma (default 1.5), so win_size mainly affects validation/cropping, not
    # the returned map; consider gaussian_weights=False if the UI's window
    # size should control the comparison.
    win_size = int(block_size)
    if win_size % 2 == 0:
        win_size += 1
    win_size = min(win_size, height, width)
    if win_size % 2 == 0:
        win_size -= 1
    win_size = max(3, win_size)

    for channel_idx, key in enumerate(("R", "G", "B")):
        try:
            _, ssim_map = ssim(
                real[:, :, channel_idx], fake[:, :, channel_idx],
                win_size=win_size, data_range=255, full=True, gaussian_weights=True,
            )
        except ValueError:
            # e.g. image smaller than the minimum 3x3 window: channel stays at zero.
            continue
        ssim_loss_map = np.maximum(0.0, 1.0 - ssim_map)
        img_dict[key]["SUM"] = np.sum(ssim_loss_map)
        img_dict[key]["HEATMAP"] = ssim_loss_map
        if sum_channels:
            img_dict["SUM"]["HEATMAP"] += ssim_loss_map
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates block-wise cosine-similarity loss between two images.

    Both images go through ``cv2.COLOR_BGR2RGB`` and are scaled to [0, 1].
    Each channel is tiled into ``block_size`` x ``block_size`` blocks; each
    block is flattened to a vector and its loss is ``1 - cos(pred, gt)``, with
    the cosine clipped to [-1, 1]. Two all-black blocks count as identical
    (loss 0); one black and one non-black block count as orthogonal (loss 1).
    Blocks on the right/bottom border may be smaller when the image size is
    not a multiple of the block size; they are still evaluated.

    Args:
        img_real: Reference colour image, (H, W, C), as accepted by cv2.cvtColor.
        img_fake: Image to compare; must have exactly the same shape.
        epsilon: Threshold below which a norm (product) is treated as zero.
        block_size: Block side length; clamped to [1, min(H, W)].
        sum_channels: If True, also fill the "SUM" entry (heatmap averaged
            over channels, scalar summed over channels).

    Returns:
        dict keyed "R", "G", "B", "SUM"; each value is
        ``{"SUM": sum of block losses, "HEATMAP": (H, W) float32}`` where every
        pixel holds its block's loss. None if an input is missing, the shapes
        differ, or OpenCV cannot convert the images.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        real = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        fake = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = real.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    size = min(max(1, int(block_size)), height, width)
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        # Step across the whole image; the last blocks are clipped to the border.
        for i in range(0, height, size):
            for j in range(0, width, size):
                block_pred = fake[i:i + size, j:j + size, channel_idx].ravel()
                block_gt = real[i:i + size, j:j + size, channel_idx].ravel()
                norm_pred = np.linalg.norm(block_pred)
                norm_gt = np.linalg.norm(block_gt)
                if norm_pred * norm_gt > epsilon:
                    cosine_sim = np.dot(block_pred, block_gt) / (norm_pred * norm_gt)
                elif norm_pred < epsilon and norm_gt < epsilon:
                    # Both blocks are (numerically) black: treat as identical.
                    cosine_sim = 1.0
                else:
                    cosine_sim = 0.0
                block_loss = 1.0 - np.clip(cosine_sim, -1.0, 1.0)
                channel_sum += block_loss
                img_dict[key]["HEATMAP"][i:i + size, j:j + size] = block_loss
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i + size, j:j + size] += block_loss
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculates block-wise Total Variation (TV) difference between two images.

    Both images go through ``cv2.COLOR_BGR2RGB`` and are scaled to [0, 1].
    Each channel is tiled into blocks. For each block the anisotropic TV
    (sum of absolute horizontal plus vertical neighbour differences) of the
    fake and the real block is computed, and the block score is their absolute
    difference. Blocks on the right/bottom border may be smaller when the
    image size is not a multiple of the block size; they are still evaluated.
    Because the score is a raw sum, smaller edge blocks contribute over fewer
    neighbour pairs.

    Args:
        img_real: Reference colour image, (H, W, C), as accepted by cv2.cvtColor.
        img_fake: Image to compare; must have exactly the same shape.
        epsilon: Unused; kept so all block metrics share one signature.
        block_size: Block side length; at least 2, then clamped to min(H, W).
        sum_channels: If True, also fill the "SUM" entry (heatmap averaged
            over channels, scalar summed over channels).

    Returns:
        dict keyed "R", "G", "B", "SUM"; each value is
        ``{"SUM": sum of block scores, "HEATMAP": (H, W) float32}`` where every
        pixel holds its block's score. None if an input is missing, the shapes
        differ, or OpenCV cannot convert the images.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        real = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        fake = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None

    def _total_variation(block):
        # Anisotropic TV: |horizontal diffs| + |vertical diffs|.
        return np.sum(np.abs(np.diff(block, axis=1))) + np.sum(np.abs(np.diff(block, axis=0)))

    height, width, channels = real.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    size = min(max(2, int(block_size)), height, width)
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        # Step across the whole image; the last blocks are clipped to the border.
        for i in range(0, height, size):
            for j in range(0, width, size):
                tv_pred = _total_variation(fake[i:i + size, j:j + size, channel_idx])
                tv_gt = _total_variation(real[i:i + size, j:j + size, channel_idx])
                block_score = np.abs(tv_pred - tv_gt)
                channel_sum += block_score
                img_dict[key]["HEATMAP"][i:i + size, j:j + size] = block_score
                if sum_channels:
                    img_dict["SUM"]["HEATMAP"][i:i + size, j:j + size] += block_score
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def perceptual_loss(img_real, img_fake, model, block_size=4):
    """Calculates Perceptual loss using a pre-trained VGG19 feature model.

    Both images are resized to the model's input size (``model.input_shape``
    rows/cols), converted with ``cv2.COLOR_BGR2RGB``, passed through
    ``preprocess_input`` and fed to the model in a single batch. The squared
    difference of the two feature maps is summed for the scalar loss and
    averaged over the feature channels for the heatmap, which is resized back
    to the original image size.

    Args:
        img_real: Reference colour image (H, W, C).
        img_fake: Image to compare; must have exactly the same shape.
        model: Keras model; assumes ``input_shape`` is (batch, rows, cols, 3)
            and outputs (batch, h, w, features), as built at module load.
        block_size: Unused; kept so the signature matches the other metrics.

    Returns:
        ``{"SUM": {"SUM": total_loss, "HEATMAP": (H, W) float32}}``, or None if
        an input/model is missing, the shapes differ, or preprocessing or
        prediction fails.
    """
    if img_real is None or img_fake is None or model is None or img_real.shape != img_fake.shape:
        return None
    original_height, original_width = img_real.shape[:2]
    try:
        target_rows, target_cols = model.input_shape[1], model.input_shape[2]
        # cv2.resize takes (width, height).
        batch = np.stack([
            cv2.cvtColor(
                cv2.resize(img, (target_cols, target_rows), interpolation=cv2.INTER_AREA),
                cv2.COLOR_BGR2RGB,
            )
            for img in (img_real, img_fake)
        ])
        batch = preprocess_input(batch)
    except Exception:
        return None
    try:
        # One forward pass for both images; verbose=0 stops Keras printing a
        # progress bar on every comparison.
        features = model.predict(batch, verbose=0)
    except Exception:
        return None
    feature_mse = np.square(features[0:1] - features[1:2])
    total_loss = np.sum(feature_mse)
    heatmap_features = np.mean(feature_mse[0], axis=-1)
    heatmap_original_size = cv2.resize(heatmap_features, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
    return {"SUM": {"SUM": total_loss, "HEATMAP": heatmap_original_size.astype(np.float32)}}
# --- Gradio Core Functions ---
def gather_images(task):
    """Loads a random pair of real and fake images from the selected dataset.

    Real images are listed from ``datasets/<task>/real``. The matching fake
    image is read from ``datasets/<task>/fake`` under the same file name. The
    list of real-image paths is cached in the module-level ``images`` global;
    it is rebuilt only when ``task`` differs from ``TASK`` or the cache is
    empty. A fake image whose size differs from the real one is resized to
    match.

    Images are returned in RGB channel order. ``cv2.imread`` yields BGR, but
    ``gr.Image(type="numpy")`` displays arrays as RGB and uploaded images
    arrive as RGB. Converting here shows dataset samples with correct colours
    and gives dataset and uploaded images the same channel order.

    Args:
        task: Dataset task name (sub-folder of ``datasets``).

    Returns:
        ``(real_img, fake_img, status_message)``. On failure, 256x256x3 black
        placeholders (or a black image shaped like the real one) are returned
        together with an error message.
    """
    global TASK, PATH, images

    def _blank(shape=(256, 256, 3)):
        # Black placeholder for an image pane that could not be filled.
        return np.zeros(shape, dtype=np.uint8)

    def _to_rgb(img):
        # cv2.imread returns BGR; Gradio treats numpy images as RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if TASK != task or not images:
        PATH = os.path.join("datasets", task, "real")
        TASK = task
        images = []
        if not os.path.isdir(PATH):
            return _blank(), _blank(), f"Error: Directory for task '{task}' not found: {PATH}"
        valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
        try:
            file_names = os.listdir(PATH)
        except OSError as e:
            return _blank(), _blank(), f"Error reading directory {PATH}: {e}"
        images = [os.path.join(PATH, f) for f in file_names if f.lower().endswith(valid_extensions)]
        if not images:
            return _blank(), _blank(), f"Error: No valid image files found in: {PATH}"

    # Broad catch: this is a Gradio event handler, so any failure is reported
    # in the status box rather than raised into the UI.
    try:
        real_img_path = random.choice(images)
        fake_img_path = os.path.join("datasets", task, "fake", os.path.basename(real_img_path))
        real_img = cv2.imread(real_img_path)
        fake_img = cv2.imread(fake_img_path)
        if real_img is None:
            fake_out = _to_rgb(fake_img) if fake_img is not None else _blank()
            return _blank(), fake_out, f"Error: Failed to load real image: {real_img_path}"
        real_img = _to_rgb(real_img)
        if fake_img is None:
            return real_img, _blank(real_img.shape), f"Error: Failed to load fake image: {fake_img_path}"
        fake_img = _to_rgb(fake_img)
        if real_img.shape != fake_img.shape:
            # cv2.resize takes (width, height).
            target_dims = (real_img.shape[1], real_img.shape[0])
            fake_img = cv2.resize(fake_img, target_dims, interpolation=cv2.INTER_AREA)
        return real_img, fake_img, f"Sample pair for '{task}' loaded successfully."
    except Exception as e:
        return _blank(), _blank(), f"An unexpected error occurred during image loading: {e}"
def run_comparison(real, fake, measurement, block_size_val):
    """Compute the chosen metric for an image pair and render its heatmap.

    If the images differ in shape, the fake one is resized to the real one's
    size and a warning is prepended to the status message. The metric's
    "SUM" heatmap is normalised to 0-255, coloured with COLORMAP_HOT and
    returned as RGB.

    Args:
        real, fake: numpy images (anything else is rejected).
        measurement: Metric label as shown in the UI dropdown.
        block_size_val: Block/window size; converted with ``int()``.

    Returns:
        ``(heatmap_rgb, status_message)``. On any failure the heatmap is a
        64x64x3 black image and the message describes the problem.
    """
    blank = np.zeros((64, 64, 3), dtype=np.uint8)
    if not (isinstance(real, np.ndarray) and isinstance(fake, np.ndarray)):
        return blank, "Error: Input image(s) missing or invalid. Please load or upload a pair of images."

    prefix = ""
    if real.shape != fake.shape:
        prefix = f"Warning: Input images have different shapes ({real.shape} vs {fake.shape}). Resizing fake image to match real. "
        fake = cv2.resize(fake, (real.shape[1], real.shape[0]), interpolation=cv2.INTER_AREA)

    size = int(block_size_val)
    # (label, callable) pairs, matched with == in dropdown order. The lambdas
    # defer looking up the metric functions until one is actually chosen.
    block_metrics = (
        ("Kullback-Leibler Divergence", lambda a, b, n: KL_divergence(a, b, block_size=n, sum_channels=True)),
        ("L1-Loss", lambda a, b, n: L1_loss(a, b, block_size=n, sum_channels=True)),
        ("MSE", lambda a, b, n: MSE_loss(a, b, block_size=n, sum_channels=True)),
        ("SSIM", lambda a, b, n: SSIM_loss(a, b, block_size=n, sum_channels=True)),
        ("Cosine Similarity", lambda a, b, n: cosine_similarity_loss(a, b, block_size=n, sum_channels=True)),
        ("TV", lambda a, b, n: TV_loss(a, b, block_size=n, sum_channels=True)),
    )
    try:
        compute = next((fn for label, fn in block_metrics if measurement == label), None)
        if compute is not None:
            result = compute(real, fake, size)
        elif measurement == "Perceptual":
            if perceptual_model is None:
                return blank, "Error: Perceptual model not loaded. Cannot calculate Perceptual loss."
            result = perceptual_loss(real, fake, model=perceptual_model, block_size=size)
        else:
            return blank, f"Error: Unknown measurement '{measurement}'."
    except Exception as exc:
        return blank, f"Error during {measurement} calculation: {exc}"

    if result is None or "SUM" not in result or "HEATMAP" not in result["SUM"]:
        return blank, f"{measurement} calculation failed or returned an invalid result structure."
    raw_map = result["SUM"]["HEATMAP"]
    if not isinstance(raw_map, np.ndarray) or raw_map.size == 0:
        return blank, f"Generated heatmap is invalid or empty for {measurement}."

    try:
        colored_bgr = cv2.applyColorMap(safe_normalize_heatmap(raw_map), cv2.COLORMAP_HOT)
        colored_rgb = cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB)
    except Exception as exc:
        return blank, f"Error during heatmap coloring: {exc}"
    return colored_rgb, prefix + f"{measurement} comparison successful."
def clear_uploads(msg):
    """Empty both image panes and show ``msg`` in the status box.

    Returns:
        ``(None, None, msg)`` for the (real image, fake image, status) outputs.
    """
    cleared = None
    return cleared, cleared, msg
def load_and_compare_initial(task):
    """Fill the UI on page load: fetch a sample pair and compare it.

    The comparison uses "Cosine Similarity" with block size 8, the same
    defaults the metric dropdown and slider start with.

    Args:
        task: Dataset task passed to gather_images().

    Returns:
        ``(real_img, fake_img, heatmap, status)``; the status joins the loading
        and comparison messages with a newline.
    """
    real_img, fake_img, load_msg = gather_images(task)
    heatmap, compare_msg = run_comparison(real_img, fake_img, "Cosine Similarity", 8)
    return real_img, fake_img, heatmap, "\n".join((load_msg, compare_msg))
# --- Gradio UI Definition ---
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="orange")
with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !important; margin: auto; }") as demo:
    gr.Markdown("# GAN vs Ground Truth Image Comparison")
    gr.Markdown("Compare images by loading a sample pair from a dataset or by uploading your own. Choose a comparison metric and run the analysis to see the difference heatmap.")
    status_message = gr.Textbox(label="Status / Errors", lines=2, interactive=False, show_copy_button=True)
    with gr.Row(equal_height=False):
        # Column 1: pick a dataset sample or upload a custom pair.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Get Images")
            with gr.Tabs():
                with gr.TabItem("Load from Dataset"):
                    task_dropdown = gr.Dropdown(
                        ["facades"], value=TASK,
                        info="Select the dataset task.",
                        label="Dataset Task"
                    )
                    # Label emoji was mojibake ("πŸ”„", UTF-8 read as cp1253).
                    sample_button = gr.Button("🔄 Get New Sample Pair", variant="secondary")
                with gr.TabItem("Upload Images"):
                    gr.Markdown("Upload your own images to compare.")
                    upload_real_img = gr.Image(type="numpy", label="Upload Real/Reference Image")
                    upload_fake_img = gr.Image(type="numpy", label="Upload Fake/Comparison Image")
        # Column 2: the pair being compared plus the metric controls.
        with gr.Column(scale=2, min_width=600):
            gr.Markdown("### 2. View Images & Run Comparison")
            with gr.Row():
                real_img_display = gr.Image(type="numpy", label="Real Image (Ground Truth)", height=350, interactive=False)
                fake_img_display = gr.Image(type="numpy", label="Fake Image (Generated by GAN)", height=350, interactive=False)
            with gr.Row():
                measurement_dropdown = gr.Dropdown(
                    ["Kullback-Leibler Divergence", "L1-Loss", "MSE", "SSIM", "Cosine Similarity", "TV", "Perceptual"],
                    value="Cosine Similarity",
                    info="Select the comparison metric.",
                    label="Comparison Metric",
                    scale=2
                )
                block_size_slider = gr.Slider(
                    minimum=2, maximum=64, value=8, step=2,
                    info="Size of the block/window for comparison.",
                    label="Block/Window Size",
                    scale=1
                )
            # Label emoji was mojibake ("πŸ“Š", UTF-8 read as cp1253).
            run_button = gr.Button("📊 Run Comparison", variant="primary")
        # Column 3: the resulting difference heatmap.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 3. See Result")
            heatmap_display = gr.Image(type="numpy", label="Comparison Heatmap (Difference)", height=350, interactive=False)

    # --- Event Listeners ---
    # Load an initial sample and run the default comparison when the app starts.
    demo.load(
        fn=load_and_compare_initial,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, heatmap_display, status_message]
    )
    sample_button.click(
        fn=gather_images,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, status_message]
    )
    # Uploaded images are copied straight into the comparison panes.
    upload_real_img.upload(
        fn=lambda x: x,
        inputs=[upload_real_img],
        outputs=[real_img_display]
    )
    upload_fake_img.upload(
        fn=lambda x: x,
        inputs=[upload_fake_img],
        outputs=[fake_img_display]
    )
    run_button.click(
        fn=run_comparison,
        inputs=[real_img_display, fake_img_display, measurement_dropdown, block_size_slider],
        outputs=[heatmap_display, status_message]
    )
    # Changing the task clears the panes. The constant message is bound in a
    # zero-input lambda instead of creating a hidden Textbox just to carry it.
    task_dropdown.change(
        fn=lambda: clear_uploads("Task changed. Please get a new sample."),
        inputs=None,
        outputs=[real_img_display, fake_img_display, status_message]
    )
# --- Application Entry Point ---
if __name__ == "__main__":
    # Print start-up diagnostics, then start the Gradio server.
    separator = "-------------------------------------------------------------"
    print(separator)
    print("Verifying VGG19 model status...")
    if perceptual_model is not None:
        print("VGG19 model loaded successfully.")
    else:
        print("WARNING: VGG19 model failed to load. 'Perceptual' metric will be unavailable.")
    print(separator)
    print(f"Checking initial dataset path: {PATH}")
    if os.path.isdir(PATH):
        print("Initial dataset path seems valid.")
    else:
        print(f"WARNING: Initial dataset path not found: {PATH}")
        print(f" Please ensure the directory '{os.path.join('datasets', TASK, 'real')}' exists.")
    print(separator)
    print("Launching Gradio App...")
    print("Access the app in your browser, usually at: http://127.0.0.1:7860")
    print(separator)
    demo.launch(share=False, debug=False)