import os
import random
import warnings

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import kl_div as scipy_kl_div
from skimage.metrics import structural_similarity as ssim

try:
    from tensorflow.keras.models import Model
    from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
except ImportError:
    try:
        from keras.models import Model
        from keras.applications.vgg19 import VGG19, preprocess_input
    except ImportError:
        # Neither TensorFlow nor standalone Keras is available; the model-loading
        # block below then fails and the "Perceptual" metric is disabled.
        pass

# --- Configuration ---
# Default dataset task. Sample pairs are read from datasets/<task>/real and the
# file with the same name in datasets/<task>/fake.
TASK = "facades"
PATH = os.path.join("datasets", TASK, "real")
images = []  # Cached paths of the real images for the current TASK.

# Defaults shared by the UI widgets and the comparison run at startup.
DEFAULT_MEASUREMENT = "Cosine Similarity"
DEFAULT_BLOCK_SIZE = 8
VALID_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
PLACEHOLDER_SHAPE = (256, 256, 3)

# --- Model Loading ---
# Attempt to load the VGG19 model for the perceptual loss metric. This is best
# effort: the app still works without it, so failures are recorded, not raised.
perceptual_model = None
perceptual_model_error = None
try:
    vgg = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    vgg.trainable = False
    perceptual_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block5_conv4').output,
                             name="perceptual_model")
except Exception as e:  # also covers NameError when no Keras import succeeded
    perceptual_model = None
    perceptual_model_error = e


# --- Utility Functions ---
def safe_normalize_heatmap(heatmap):
    """Safely normalize a heatmap to a uint8 0-255 range for visualization.

    Args:
        heatmap: 2-D numeric array, or None.

    Returns:
        A uint8 array of the same shape, linearly rescaled so the minimum maps
        to 0 and the maximum to 255. A constant map becomes all zeros, and a
        None/empty input becomes a 64x64 zero image. NaN values are replaced
        by 0, and +/-inf by the largest/smallest finite value.
    """
    if heatmap is None or heatmap.size == 0:
        return np.zeros((64, 64), dtype=np.uint8)
    heatmap = heatmap.astype(np.float32)
    finite_mask = np.isfinite(heatmap)
    if not finite_mask.all():
        if finite_mask.any():
            low, high = np.min(heatmap[finite_mask]), np.max(heatmap[finite_mask])
        else:
            low = high = 0.0
        heatmap = np.nan_to_num(heatmap, nan=0.0, posinf=high, neginf=low)
    min_val = np.min(heatmap)
    value_range = np.max(heatmap) - min_val
    if value_range <= 1e-9:
        return np.zeros(heatmap.shape, dtype=np.uint8)
    normalized = (heatmap - min_val) / value_range * 255.0
    return np.clip(normalized, 0, 255).astype(np.uint8)


# --- Image Comparison Metrics ---
# Every metric below except perceptual_loss returns a dict of the form
#   {"R": {"SUM": float, "HEATMAP": HxW float32}, "G": ..., "B": ...,
#    "SUM": {"SUM": float, "HEATMAP": HxW float32}}
# The "SUM" entry is filled only when sum_channels=True. Its heatmap is the
# mean of the three channel heatmaps; its SUM is the total of the channel sums.
_CHANNEL_KEYS = ("R", "G", "B")


def _empty_result(height, width):
    """Return a zero-initialized per-channel result dict for an HxW image."""
    return {key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
            for key in (*_CHANNEL_KEYS, "SUM")}


def _to_unit_rgb_pair(img_real, img_fake):
    """Validate a BGR image pair and return both as float32 RGB scaled to [0, 1].

    Returns None when an image is missing, the shapes differ, or OpenCV cannot
    perform the BGR->RGB conversion (e.g. a 2-D grayscale array).
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    return real_rgb, fake_rgb


def _blockwise_metric(img_real, img_fake, block_metric, block_size, sum_channels, min_block_size=1):
    """Apply a per-block scoring function over non-overlapping tiles of each channel.

    Args:
        img_real, img_fake: BGR images of identical shape (OpenCV convention).
        block_metric: callable(block_gt, block_pred) -> (contribution, heat_value)
            or None to skip the block. Blocks are 2-D float32 tiles in [0, 1].
            `contribution` is added to the channel SUM, and `heat_value` is
            written to every pixel of the tile in the channel heatmap.
        block_size: requested tile side. It is raised to `min_block_size` and
            clamped to the smaller image side.
        sum_channels: also fill the combined "SUM" entry.
        min_block_size: smallest tile side the metric can work with.

    Returns:
        The per-channel result dict, or None for invalid input.
    """
    pair = _to_unit_rgb_pair(img_real, img_fake)
    if pair is None:
        return None
    real_rgb, fake_rgb = pair
    height, width, channels = real_rgb.shape
    result = _empty_result(height, width)
    step = max(1, min(max(min_block_size, int(block_size)), height, width))

    for channel_idx, key in enumerate(_CHANNEL_KEYS):
        channel_heatmap = result[key]["HEATMAP"]
        channel_total = 0.0
        # range(0, dim, step) also visits the trailing partial tile when dim is not
        # a multiple of step, so edge pixels are scored instead of left at zero.
        for top in range(0, height, step):
            rows = slice(top, top + step)
            for left in range(0, width, step):
                cols = slice(left, left + step)
                outcome = block_metric(real_rgb[rows, cols, channel_idx],
                                       fake_rgb[rows, cols, channel_idx])
                if outcome is None:
                    continue
                contribution, heat_value = outcome
                channel_total += contribution
                channel_heatmap[rows, cols] = heat_value
                if sum_channels:
                    result["SUM"]["HEATMAP"][rows, cols] += heat_value
        result[key]["SUM"] = channel_total

    if sum_channels:
        result["SUM"]["SUM"] = sum(result[key]["SUM"] for key in _CHANNEL_KEYS)
        result["SUM"]["HEATMAP"] /= max(1, channels)
    return result


def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate the block-wise Kullback-Leibler divergence between two BGR images.

    Each tile of each channel is shifted by `epsilon` and normalized to sum to
    1. The tile score is the summed elementwise `scipy.special.kl_div`
    (non-finite terms zeroed). The heatmap value is that score divided by the
    tile's pixel count.

    Returns:
        Per-channel result dict, or None for missing/mismatched/unconvertible input.
    """
    def kl_block(block_gt, block_pred):
        p = block_gt.ravel() + epsilon
        q = block_pred.ravel() + epsilon
        p_total, q_total = np.sum(p), np.sum(q)
        if not (p_total > 0 and q_total > 0):
            return None
        kl_values = np.nan_to_num(scipy_kl_div(p / p_total, q / q_total),
                                  nan=0.0, posinf=0.0, neginf=0.0)
        kl_sum = float(np.sum(kl_values))
        if not np.isfinite(kl_sum):
            return None
        return kl_sum, kl_sum / max(1, block_gt.size)

    return _blockwise_metric(img_real, img_fake, kl_block, block_size, sum_channels)


def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate the block-wise L1 (mean absolute error) loss between two BGR images.

    Channel SUM is the total absolute difference over all pixels (in [0, 1]
    units). The heatmap holds each tile's mean absolute difference.
    `epsilon` is unused and kept for interface compatibility.

    Returns:
        Per-channel result dict, or None for missing/mismatched/unconvertible input.
    """
    def l1_block(block_gt, block_pred):
        total = float(np.sum(np.abs(block_pred - block_gt)))
        return total, total / max(1, block_gt.size)

    return _blockwise_metric(img_real, img_fake, l1_block, block_size, sum_channels)


def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate the block-wise MSE (mean squared error) loss between two BGR images.

    Channel SUM is the total squared difference over all pixels (in [0, 1]
    units). The heatmap holds each tile's mean squared difference.
    `epsilon` is unused and kept for interface compatibility.

    Returns:
        Per-channel result dict, or None for missing/mismatched/unconvertible input.
    """
    def mse_block(block_gt, block_pred):
        total = float(np.sum(np.square(block_pred - block_gt)))
        return total, total / max(1, block_gt.size)

    return _blockwise_metric(img_real, img_fake, mse_block, block_size, sum_channels)


def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
    """Calculate a per-pixel SSIM loss map, max(0, 1 - SSIM), between two BGR images.

    The window side is `block_size` rounded up to odd, clamped to [3, smaller
    image side], and kept odd after clamping. skimage ignores `win_size` for
    filtering when gaussian_weights=True, so the Gaussian sigma is derived from
    the window. With truncate=3.5, sigma = (win_size - 1) / 7 makes the filter
    radius equal (win_size - 1) / 2. A channel that skimage rejects (e.g. an
    image smaller than 3 px) is left as zeros.

    Returns:
        Per-channel result dict, or None for missing/mismatched/unconvertible input.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB)
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB)
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    result = _empty_result(height, width)

    win_size = int(block_size)
    if win_size % 2 == 0:
        win_size += 1
    win_size = max(3, min(win_size, height, width))
    if win_size % 2 == 0:
        # Clamping to an even image side would make skimage reject the window.
        win_size -= 1
    sigma = (win_size - 1) / 7.0

    for channel_idx, key in enumerate(_CHANNEL_KEYS):
        try:
            _, ssim_map = ssim(img_real_rgb[:, :, channel_idx], img_fake_rgb[:, :, channel_idx],
                               win_size=win_size, data_range=255, full=True,
                               gaussian_weights=True, sigma=sigma)
        except ValueError:
            continue  # heatmap/SUM for this channel stay at zero
        ssim_loss_map = np.maximum(0.0, 1.0 - ssim_map).astype(np.float32)
        result[key]["SUM"] = float(np.sum(ssim_loss_map))
        result[key]["HEATMAP"] = ssim_loss_map
        if sum_channels:
            result["SUM"]["HEATMAP"] += ssim_loss_map

    if sum_channels:
        result["SUM"]["SUM"] = sum(result[key]["SUM"] for key in _CHANNEL_KEYS)
        result["SUM"]["HEATMAP"] /= max(1, channels)
    return result


def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate the block-wise cosine-similarity loss (1 - cos) between two BGR images.

    Each tile is flattened to a vector. Two near-zero tiles count as identical
    (loss 0); if only one tile is near zero, the loss is 1. Channel SUM adds
    one loss value per tile, and each tile's heatmap holds that loss.

    Returns:
        Per-channel result dict, or None for missing/mismatched/unconvertible input.
    """
    def cosine_block(block_gt, block_pred):
        vec_gt = block_gt.ravel()
        vec_pred = block_pred.ravel()
        norm_gt = np.linalg.norm(vec_gt)
        norm_pred = np.linalg.norm(vec_pred)
        if norm_pred * norm_gt > epsilon:
            similarity = np.dot(vec_pred, vec_gt) / (norm_pred * norm_gt)
        elif norm_pred < epsilon and norm_gt < epsilon:
            similarity = 1.0
        else:
            similarity = 0.0
        loss = 1.0 - float(np.clip(similarity, -1.0, 1.0))
        return loss, loss

    return _blockwise_metric(img_real, img_fake, cosine_block, block_size, sum_channels)


def _total_variation(block):
    """Anisotropic total variation of a 2-D tile: sum of |horizontal| + |vertical| differences."""
    return float(np.sum(np.abs(np.diff(block, axis=1))) + np.sum(np.abs(np.diff(block, axis=0))))


def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate the block-wise Total Variation difference between two BGR images.

    Each tile scores |TV(fake) - TV(real)|. The tile side is at least 2 (unless
    the image itself is smaller), because TV needs neighbouring pixels.
    `epsilon` is unused and kept for interface compatibility.

    Returns:
        Per-channel result dict, or None for missing/mismatched/unconvertible input.
    """
    def tv_block(block_gt, block_pred):
        difference = abs(_total_variation(block_pred) - _total_variation(block_gt))
        return difference, difference

    return _blockwise_metric(img_real, img_fake, tv_block, block_size, sum_channels,
                             min_block_size=2)


def perceptual_loss(img_real, img_fake, model, block_size=4):
    """Calculate a perceptual loss from the feature maps of a pre-trained VGG19.

    Both BGR images are resized to the model's input size, converted to RGB,
    preprocessed with Keras `preprocess_input`, and run through `model` as one
    batch. The loss is the summed squared feature difference. The heatmap is
    the channel-mean squared difference, resized back to the original image
    size. `block_size` is unused and kept for interface compatibility.

    Returns:
        {"SUM": {"SUM": float, "HEATMAP": HxW float32}}, or None when the input
        is invalid, the model is missing, or preprocessing/inference fails.
    """
    if img_real is None or img_fake is None or model is None or img_real.shape != img_fake.shape:
        return None
    original_height, original_width = img_real.shape[:2]
    try:
        target_height, target_width = model.input_shape[1], model.input_shape[2]
        batch = np.stack([
            cv2.cvtColor(cv2.resize(img, (target_width, target_height), interpolation=cv2.INTER_AREA),
                         cv2.COLOR_BGR2RGB)
            for img in (img_real, img_fake)
        ]).astype(np.float32)
        batch = preprocess_input(batch)
    except Exception:
        return None
    try:
        # One batched, silent call instead of two predict() calls with progress bars.
        features = model.predict(batch, verbose=0)
    except Exception:
        return None
    feature_mse = np.square(features[0] - features[1])
    total_loss = float(np.sum(feature_mse))
    heatmap_features = np.mean(feature_mse, axis=-1).astype(np.float32)
    heatmap_original_size = cv2.resize(heatmap_features, (original_width, original_height),
                                       interpolation=cv2.INTER_LINEAR)
    return {"SUM": {"SUM": total_loss, "HEATMAP": heatmap_original_size.astype(np.float32)}}


# Metrics sharing the (img_real, img_fake, block_size=..., sum_channels=...) call shape.
_METRIC_FUNCTIONS = {
    "Kullback-Leibler Divergence": KL_divergence,
    "L1-Loss": L1_loss,
    "MSE": MSE_loss,
    "SSIM": SSIM_loss,
    "Cosine Similarity": cosine_similarity_loss,
    "TV": TV_loss,
}
METRIC_CHOICES = [*_METRIC_FUNCTIONS, "Perceptual"]


# --- Gradio Core Functions ---
def gather_images(task):
    """Load a random (real, fake) pair from datasets/<task>/{real,fake}.

    The list of real-image paths is cached in the module-level `images` and is
    rebuilt only when `task` changes or the cache is empty. Images are returned
    as RGB arrays, the channel order gr.Image expects for numpy data.

    Returns:
        (real_rgb, fake_rgb, status_message). On failure, 256x256 black
        placeholders are returned with an error message.
    """
    global TASK, PATH, images
    placeholder = np.zeros(PLACEHOLDER_SHAPE, dtype=np.uint8)

    if TASK != task or not images:
        PATH = os.path.join("datasets", task, "real")
        TASK = task
        images = []
        if not os.path.isdir(PATH):
            return placeholder, placeholder, f"Error: Directory for task '{task}' not found: {PATH}"
        try:
            images = [os.path.join(PATH, f) for f in os.listdir(PATH)
                      if f.lower().endswith(VALID_IMAGE_EXTENSIONS)]
        except OSError as e:
            return placeholder, placeholder, f"Error reading directory {PATH}: {e}"
        if not images:
            return placeholder, placeholder, f"Error: No valid image files found in: {PATH}"

    try:
        real_img_path = random.choice(images)
        fake_img_path = os.path.join("datasets", task, "fake", os.path.basename(real_img_path))
        real_img = cv2.imread(real_img_path)
        fake_img = cv2.imread(fake_img_path)

        if real_img is None:
            fake_display = cv2.cvtColor(fake_img, cv2.COLOR_BGR2RGB) if fake_img is not None else placeholder
            return placeholder, fake_display, f"Error: Failed to load real image: {real_img_path}"
        if fake_img is None:
            return (cv2.cvtColor(real_img, cv2.COLOR_BGR2RGB), np.zeros(real_img.shape, dtype=np.uint8),
                    f"Error: Failed to load fake image: {fake_img_path}")
        if real_img.shape != fake_img.shape:
            fake_img = cv2.resize(fake_img, (real_img.shape[1], real_img.shape[0]),
                                  interpolation=cv2.INTER_AREA)
        # cv2.imread yields BGR; gr.Image displays numpy arrays as RGB.
        return (cv2.cvtColor(real_img, cv2.COLOR_BGR2RGB), cv2.cvtColor(fake_img, cv2.COLOR_BGR2RGB),
                f"Sample pair for '{task}' loaded successfully.")
    except Exception as e:  # UI boundary: report instead of crashing the handler
        return placeholder, placeholder, f"An unexpected error occurred during image loading: {e}"


def _gradio_image_to_bgr(img):
    """Convert a Gradio image array (RGB, RGBA or grayscale) to 3-channel uint8 BGR.

    The metric functions follow OpenCV's BGR convention, while every array
    coming from a gr.Image component (uploads and dataset samples alike) is RGB.

    Raises:
        ValueError: for array shapes that are not an image.
    """
    if img.dtype != np.uint8:
        img = np.clip(img, 0, 255).astype(np.uint8)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    img = np.ascontiguousarray(img)
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.ndim == 3 and img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
    if img.ndim == 3 and img.shape[2] == 3:
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    raise ValueError(f"unsupported image shape {img.shape}")


def run_comparison(real, fake, measurement, block_size_val):
    """Run the selected comparison metric and render its heatmap.

    Args:
        real, fake: RGB(A)/grayscale numpy arrays from the Gradio image displays.
        measurement: one of METRIC_CHOICES.
        block_size_val: block/window size from the slider. Falls back to
            DEFAULT_BLOCK_SIZE if it is not a number.

    Returns:
        (heatmap_rgb, status_message). On error, a 64x64 black placeholder is
        returned with the error text.
    """
    placeholder_heatmap = np.zeros((64, 64, 3), dtype=np.uint8)
    if not isinstance(real, np.ndarray) or not isinstance(fake, np.ndarray):
        return placeholder_heatmap, "Error: Input image(s) missing or invalid. Please load or upload a pair of images."
    try:
        real = _gradio_image_to_bgr(real)
        fake = _gradio_image_to_bgr(fake)
    except (ValueError, cv2.error) as e:
        return placeholder_heatmap, f"Error: Unsupported image format: {e}"

    status_msg_prefix = ""
    if real.shape != fake.shape:
        status_msg_prefix = (f"Warning: Input images have different shapes ({real.shape} vs {fake.shape}). "
                             "Resizing fake image to match real. ")
        fake = cv2.resize(fake, (real.shape[1], real.shape[0]), interpolation=cv2.INTER_AREA)

    try:
        block_size_int = int(block_size_val)
    except (TypeError, ValueError):
        block_size_int = DEFAULT_BLOCK_SIZE

    try:
        if measurement == "Perceptual":
            if perceptual_model is None:
                return placeholder_heatmap, "Error: Perceptual model not loaded. Cannot calculate Perceptual loss."
            result = perceptual_loss(real, fake, model=perceptual_model, block_size=block_size_int)
        else:
            metric_fn = _METRIC_FUNCTIONS.get(measurement)
            if metric_fn is None:
                return placeholder_heatmap, f"Error: Unknown measurement '{measurement}'."
            result = metric_fn(real, fake, block_size=block_size_int, sum_channels=True)
    except Exception as e:  # UI boundary: surface metric failures in the status box
        return placeholder_heatmap, f"Error during {measurement} calculation: {e}"

    if result is None or "SUM" not in result or "HEATMAP" not in result["SUM"]:
        return placeholder_heatmap, f"{measurement} calculation failed or returned an invalid result structure."
    heatmap_raw = result["SUM"]["HEATMAP"]
    if not isinstance(heatmap_raw, np.ndarray) or heatmap_raw.size == 0:
        return placeholder_heatmap, f"Generated heatmap is invalid or empty for {measurement}."

    try:
        heatmap_normalized = safe_normalize_heatmap(heatmap_raw)
        heatmap_color = cv2.applyColorMap(heatmap_normalized, cv2.COLORMAP_HOT)
        heatmap_rgb = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)  # applyColorMap emits BGR
    except Exception as e:
        return placeholder_heatmap, f"Error during heatmap coloring: {e}"

    return heatmap_rgb, status_msg_prefix + f"{measurement} comparison successful."


def clear_uploads(msg):
    """Clear both image displays and show `msg` in the status box."""
    return None, None, msg


def load_and_compare_initial(task):
    """Gather an initial sample pair and run the default comparison at startup.

    Returns:
        (real_rgb, fake_rgb, heatmap_rgb, combined_status_message).
    """
    real_img, fake_img, gather_status = gather_images(task)
    heatmap, compare_status = run_comparison(real_img, fake_img, DEFAULT_MEASUREMENT, DEFAULT_BLOCK_SIZE)
    return real_img, fake_img, heatmap, f"{gather_status}\n{compare_status}"


# --- Gradio UI Definition ---
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="orange")

with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !important; margin: auto; }") as demo:
    gr.Markdown("# GAN vs Ground Truth Image Comparison")
    gr.Markdown("Compare images by loading a sample pair from a dataset or by uploading your own. Choose a comparison metric and run the analysis to see the difference heatmap.")

    status_message = gr.Textbox(label="Status / Errors", lines=2, interactive=False, show_copy_button=True)

    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Get Images")
            with gr.Tabs():
                with gr.TabItem("Load from Dataset"):
                    task_dropdown = gr.Dropdown(
                        ["facades"], value=TASK, info="Select the dataset task.", label="Dataset Task"
                    )
                    sample_button = gr.Button("🔄 Get New Sample Pair", variant="secondary")
                with gr.TabItem("Upload Images"):
                    gr.Markdown("Upload your own images to compare.")
                    upload_real_img = gr.Image(type="numpy", label="Upload Real/Reference Image")
                    upload_fake_img = gr.Image(type="numpy", label="Upload Fake/Comparison Image")

        with gr.Column(scale=2, min_width=600):
            gr.Markdown("### 2. View Images & Run Comparison")
            with gr.Row():
                real_img_display = gr.Image(type="numpy", label="Real Image (Ground Truth)", height=350, interactive=False)
                fake_img_display = gr.Image(type="numpy", label="Fake Image (Generated by GAN)", height=350, interactive=False)
            with gr.Row():
                measurement_dropdown = gr.Dropdown(
                    METRIC_CHOICES,
                    value=DEFAULT_MEASUREMENT, info="Select the comparison metric.", label="Comparison Metric", scale=2
                )
                block_size_slider = gr.Slider(
                    minimum=2, maximum=64, value=DEFAULT_BLOCK_SIZE, step=2,
                    info="Size of the block/window for comparison.", label="Block/Window Size", scale=1
                )
            run_button = gr.Button("📊 Run Comparison", variant="primary")

        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 3. See Result")
            heatmap_display = gr.Image(type="numpy", label="Comparison Heatmap (Difference)", height=350, interactive=False)

    # --- Event Listeners ---
    # Load an initial sample and run the default comparison when the app starts.
    demo.load(
        fn=load_and_compare_initial,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, heatmap_display, status_message]
    )
    sample_button.click(
        fn=gather_images,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, status_message]
    )
    # Uploaded images are copied into the display components the comparison reads from.
    upload_real_img.upload(fn=lambda x: x, inputs=[upload_real_img], outputs=[real_img_display])
    upload_fake_img.upload(fn=lambda x: x, inputs=[upload_fake_img], outputs=[fake_img_display])
    run_button.click(
        fn=run_comparison,
        inputs=[real_img_display, fake_img_display, measurement_dropdown, block_size_slider],
        outputs=[heatmap_display, status_message]
    )
    task_dropdown.change(
        fn=lambda: clear_uploads("Task changed. Please get a new sample."),
        inputs=None,
        outputs=[real_img_display, fake_img_display, status_message]
    )

# --- Application Entry Point ---
if __name__ == "__main__":
    print("-------------------------------------------------------------")
    print("Verifying VGG19 model status...")
    if perceptual_model is None:
        print("WARNING: VGG19 model failed to load. 'Perceptual' metric will be unavailable.")
        if perceptual_model_error is not None:
            print(f"         Reason: {perceptual_model_error}")
    else:
        print("VGG19 model loaded successfully.")
    print("-------------------------------------------------------------")
    print(f"Checking initial dataset path: {PATH}")
    if not os.path.isdir(PATH):
        print(f"WARNING: Initial dataset path not found: {PATH}")
        print(f"         Please ensure the directory '{os.path.join('datasets', TASK, 'real')}' exists.")
    else:
        print("Initial dataset path seems valid.")
    print("-------------------------------------------------------------")
    print("Launching Gradio App...")
    print("Access the app in your browser, usually at: http://127.0.0.1:7860")
    print("-------------------------------------------------------------")
    demo.launch(share=False, debug=False)