import functools
import cv2
import numpy as np
import gradio as gr
import os
from types import MethodType
from ultralytics import YOLO
from huggingface_hub import hf_hub_download

# Import helper functions from the existing feature-extractor script
from yolov10_RoIFX import (
    _predict_once,
    get_result_with_features_yolov10_simple,
    draw_modern_bbox,
    draw_feature_heatmap,
)

# ---------------------------
# Constants & Setup
# ---------------------------

# Set up model and example paths
REPO_ID = "HugoHE/X-YOLOv10"
MODELS_DIR = "models"
os.makedirs(MODELS_DIR, exist_ok=True)


# Download models from Hugging Face Hub
def download_models():
    for model_file in ["vanilla.pt", "finetune.pt"]:
        if not os.path.exists(os.path.join(MODELS_DIR, model_file)):
            try:
                hf_hub_download(
                    repo_id=REPO_ID,
                    filename=f"models/{model_file}",
                    local_dir=".",
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                print(f"Error downloading {model_file}: {e}")


# Download example images from Hugging Face Hub
def download_examples():
    for img_file in ["1.png", "2.png"]:
        if not os.path.exists(img_file):
            try:
                hf_hub_download(
                    repo_id=REPO_ID,
                    filename=img_file,
                    local_dir=".",
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                print(f"Error downloading {img_file}: {e}")


# Download required files
download_models()
download_examples()

AVAILABLE_MODELS = {
    "Vanilla VOC": "vanilla.pt",
    "Finetune VOC": "finetune.pt",
}

# Example images paired with their default confidence thresholds
EXAMPLES = [
    ["1.png", 0.25],
    ["2.png", 0.25],
]


# ---------------------------
# Model loading & caching
# ---------------------------
@functools.lru_cache(maxsize=2)
def load_model(model_name: str):
    """Load a YOLOv10 model and cache it so subsequent calls are fast."""
    model_path = os.path.join(MODELS_DIR, AVAILABLE_MODELS[model_name])
    model = YOLO(model_path)

    # Monkey-patch the predictor so we can extract feature maps on demand
    model.model._predict_once = MethodType(_predict_once, model.model)

    # Run a dummy inference to initialise internals
    model(np.zeros((640, 640, 3), dtype=np.uint8), verbose=False)

    # Automatically determine which layers to use for feature extraction
    detect_layer_idx = -1
    for i, m in enumerate(model.model.model):
        if "Detect" in type(m).__name__:
            detect_layer_idx = i
            break

    if detect_layer_idx != -1:
        input_layer_idxs = model.model.model[detect_layer_idx].f
        embed_layers = sorted(input_layer_idxs) + [detect_layer_idx]
    else:
        embed_layers = [16, 19, 22, 23]  # fallback

    return model, tuple(embed_layers)


# ---------------------------
# Composite heat-map layout
# ---------------------------
def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
    """Return separate XAI heatmap layouts for vanilla and fine-tuned models."""
    # Convert RGB (Gradio default) ➜ BGR (OpenCV default)
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    # Load both models
    vanilla_model, vanilla_embed_layers = load_model("Vanilla VOC")
    finetune_model, finetune_embed_layers = load_model("Finetune VOC")

    # Run inference on both models
    vanilla_results = get_result_with_features_yolov10_simple(
        vanilla_model, img_bgr, vanilla_embed_layers, conf=conf
    )
    finetune_results = get_result_with_features_yolov10_simple(
        finetune_model, img_bgr, finetune_embed_layers, conf=conf
    )

    # Check if any detections were made
    vanilla_has_detections = (
        vanilla_results
        and len(vanilla_results) > 0
        and hasattr(vanilla_results[0], "boxes")
        and len(vanilla_results[0].boxes) > 0
    )
    finetune_has_detections = (
        finetune_results
        and len(finetune_results) > 0
        and hasattr(finetune_results[0], "boxes")
        and len(finetune_results[0].boxes) > 0
    )

    # Create heatmap visualizations for both models
    vanilla_heatmaps = []
    finetune_heatmaps = []

    if vanilla_has_detections:
        vanilla_result = vanilla_results[0]
        vanilla_names = [vanilla_model.model.names[int(cls)] for cls in vanilla_result.boxes.cls]
        vanilla_heatmaps = create_heatmap_snippets(img_bgr, vanilla_result, vanilla_names, "Vanilla")

    if finetune_has_detections:
        finetune_result = finetune_results[0]
        finetune_names = [finetune_model.model.names[int(cls)] for cls in finetune_result.boxes.cls]
        finetune_heatmaps = create_heatmap_snippets(img_bgr, finetune_result, finetune_names, "Fine-tuned")

    # Create separate layouts for each model
    vanilla_layout = create_model_layout(vanilla_heatmaps, "Vanilla Model", (0, 100, 0))
    finetune_layout = create_model_layout(finetune_heatmaps, "Fine-tuned Model", (0, 0, 200))

    # Convert BGR to RGB for display
    vanilla_output = cv2.cvtColor(vanilla_layout, cv2.COLOR_BGR2RGB) if vanilla_layout is not None else None
    finetune_output = cv2.cvtColor(finetune_layout, cv2.COLOR_BGR2RGB) if finetune_layout is not None else None

    return vanilla_output, finetune_output


def create_heatmap_snippets(img_bgr, result, names, model_type):
    """Create heatmap snippets for detected objects."""
    snippets = []
    if hasattr(result, "pooled_feats") and result.pooled_feats:
        last_pooled = result.pooled_feats[-1]
        for i in range(len(result.boxes)):
            box = result.boxes.xyxy[i]
            fmap = last_pooled[i]
            heatmap_full = draw_feature_heatmap(img_bgr.copy(), box, fmap)

            # Crop the heatmap to the detection box, clamped to the image bounds
            x1, y1, x2, y2 = box.cpu().numpy().astype(int)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img_bgr.shape[1], x2), min(img_bgr.shape[0], y2)
            if x2 <= x1 or y2 <= y1:
                continue
            snippet = heatmap_full[y1:y2, x1:x2]

            # Add caption with model type and object info
            caption = f"{model_type}: {names[i]}"
            font = cv2.FONT_HERSHEY_SIMPLEX
            (tw, th), _ = cv2.getTextSize(caption, font, 0.6, 1)
            canvas = np.full(
                (snippet.shape[0] + th + 15, max(snippet.shape[1], tw + 10), 3), 255, np.uint8
            )

            # Center the snippet
            cx = (canvas.shape[1] - snippet.shape[1]) // 2
            canvas[0 : snippet.shape[0], cx : cx + snippet.shape[1]] = snippet

            # Put caption
            tx = (canvas.shape[1] - tw) // 2
            cv2.putText(canvas, caption, (tx, snippet.shape[0] + th + 5), font, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
            cv2.rectangle(canvas, (0, 0), (canvas.shape[1] - 1, canvas.shape[0] - 1), (180, 180, 180), 1)

            snippets.append(canvas)
    return snippets


def create_model_layout(heatmaps, title, color):
    """Create a layout for one model's heatmaps."""
    pad = 20

    if not heatmaps:
        # Create empty section with title
        font = cv2.FONT_HERSHEY_SIMPLEX
        (tw, th), _ = cv2.getTextSize(title, font, 1.0, 2)
        canvas = np.full((th + 40, tw + 20, 3), 255, np.uint8)
        cv2.putText(canvas, title, (10, th + 20), font, 1.0, color, 2, cv2.LINE_AA)
        return canvas

    # Arrange heatmaps in a row
    max_h = max(h.shape[0] for h in heatmaps)
    total_w = sum(h.shape[1] for h in heatmaps) + (len(heatmaps) - 1) * 10

    # Add title space
    title_font = cv2.FONT_HERSHEY_SIMPLEX
    (tw, th), _ = cv2.getTextSize(title, title_font, 1.0, 2)
    section_h = max_h + th + 40
    section_w = max(total_w, tw + 20)

    # Create canvas with padding
    canvas_h = section_h + 2 * pad
    canvas_w = section_w + 2 * pad
    canvas = np.full((canvas_h, canvas_w, 3), 255, np.uint8)

    # Add title
    cv2.putText(canvas, title, (pad + 10, pad + th + 20), title_font, 1.0, color, 2, cv2.LINE_AA)

    # Arrange heatmaps
    cur_x = pad
    for h in heatmaps:
        y_off = pad + th + 30 + (max_h - h.shape[0]) // 2
        canvas[y_off : y_off + h.shape[0], cur_x : cur_x + h.shape[1]] = h
        cur_x += h.shape[1] + 10

    return canvas
# ---------------------------
# Gradio UI definition
# ---------------------------
def build_demo():
    with gr.Blocks(title="YOLOv10 XAI Heatmap Comparison") as demo:
        gr.Markdown("# YOLOv10 XAI Heatmap Comparison")
        gr.Markdown("Upload an image to compare XAI heatmaps between vanilla and fine-tuned YOLOv10 models.")

        with gr.Row():
            # Left side - Input controls
            with gr.Column(scale=1):
                image_input = gr.Image(type="numpy", label="Input Image")
                conf_input = gr.Slider(
                    minimum=0.05, maximum=1.0, step=0.05, value=0.25, label="Confidence Threshold"
                )

                gr.Markdown("### Example Images")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=[image_input, conf_input],
                    label="Click to load example",
                )

            # Right side - Output visualizations (separated vertically)
            with gr.Column(scale=2):
                vanilla_output = gr.Image(type="numpy", label="Vanilla Model Heatmap")
                finetune_output = gr.Image(type="numpy", label="Fine-tuned Model Heatmap")

        # Connect inputs to the function
        def update_heatmap(image, confidence):
            if image is None:
                return None, None
            return generate_heatmap_layout(image, confidence)

        # Set up the interface
        image_input.change(
            fn=update_heatmap,
            inputs=[image_input, conf_input],
            outputs=[vanilla_output, finetune_output],
        )
        conf_input.change(
            fn=update_heatmap,
            inputs=[image_input, conf_input],
            outputs=[vanilla_output, finetune_output],
        )

    return demo


def main():
    demo = build_demo()
    demo.launch()


if __name__ == "__main__":
    main()