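"""Gradio demo comparing RoI feature-map heatmaps (XAI) between a vanilla and a
fine-tuned YOLOv10 model on the same input image. Model checkpoints and example
images are pulled from the Hugging Face Hub repo HugoHE/X-YOLOv10; the feature
extraction helpers come from the local yolov10_RoIFX module."""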
import functools
import cv2
import numpy as np
import gradio as gr
import os
from types import MethodType
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
# Import helper functions from the existing feature-extractor script
from yolov10_RoIFX import (
    _predict_once,
    get_result_with_features_yolov10_simple,
    draw_modern_bbox,
    draw_feature_heatmap,
)
# ---------------------------
# Constants & Setup
# ---------------------------
# Set up model and example paths
REPO_ID = "HugoHE/X-YOLOv10"
MODELS_DIR = "models"
os.makedirs(MODELS_DIR, exist_ok=True)
# Download models from Hugging Face Hub
def download_models():
    for model_file in ["vanilla.pt", "finetune.pt"]:
        if not os.path.exists(os.path.join(MODELS_DIR, model_file)):
            try:
                hf_hub_download(
                    repo_id=REPO_ID,
                    filename=f"models/{model_file}",
                    local_dir=".",
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                print(f"Error downloading {model_file}: {e}")
# Download example images from Hugging Face Hub
def download_examples():
    for img_file in ["1.png", "2.png"]:
        if not os.path.exists(img_file):
            try:
                hf_hub_download(
                    repo_id=REPO_ID,
                    filename=img_file,
                    local_dir=".",
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                print(f"Error downloading {img_file}: {e}")
# Download required files
download_models()
download_examples()
AVAILABLE_MODELS = {
    "Vanilla VOC": "vanilla.pt",
    "Finetune VOC": "finetune.pt",
}
# Example images paired with a default confidence threshold
EXAMPLES = [
    ["1.png", 0.25],
    ["2.png", 0.25],
]
# ---------------------------
# Model loading & caching
# ---------------------------
@functools.lru_cache(maxsize=2)
def load_model(model_name: str):
    """Load a YOLOv10 model and cache it so subsequent calls are fast."""
    model_path = os.path.join(MODELS_DIR, AVAILABLE_MODELS[model_name])
    model = YOLO(model_path)
    # Monkey-patch the predictor so we can extract feature maps on demand
    model.model._predict_once = MethodType(_predict_once, model.model)
    # Run a dummy inference to initialise internals
    model(np.zeros((640, 640, 3), dtype=np.uint8), verbose=False)
    # Automatically determine which layers to use for feature extraction
    detect_layer_idx = -1
    for i, m in enumerate(model.model.model):
        if "Detect" in type(m).__name__:
            detect_layer_idx = i
            break
    if detect_layer_idx != -1:
        # The Detect head's `f` attribute lists the indices of the layers feeding it
        input_layer_idxs = model.model.model[detect_layer_idx].f
        embed_layers = sorted(input_layer_idxs) + [detect_layer_idx]
    else:
        embed_layers = [16, 19, 22, 23]  # fallback
    return model, tuple(embed_layers)
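# Optional warm-up sketch (not part of the original flow): because load_model is
# wrapped in functools.lru_cache, both checkpoints can be loaded once at startup
# so the first Gradio request does not pay the model-loading cost. This assumes
# the downloads above succeeded; callers may invoke it before building the demo.
def warm_up_models():
    for name in AVAILABLE_MODELS:
        try:
            load_model(name)
        except Exception as e:
            print(f"Warm-up failed for {name}: {e}")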
# ---------------------------
# Composite heat-map layout
# ---------------------------
def generate_heatmap_layout(img_rgb: np.ndarray, conf: float = 0.25):
    """Return separate XAI heatmap layouts for vanilla and fine-tuned models."""
    # Convert RGB (Gradio default) ➜ BGR (OpenCV default)
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    # Load both models
    vanilla_model, vanilla_embed_layers = load_model("Vanilla VOC")
    finetune_model, finetune_embed_layers = load_model("Finetune VOC")
    # Run inference on both models
    vanilla_results = get_result_with_features_yolov10_simple(
        vanilla_model, img_bgr, vanilla_embed_layers, conf=conf
    )
    finetune_results = get_result_with_features_yolov10_simple(
        finetune_model, img_bgr, finetune_embed_layers, conf=conf
    )
    # Check if any detections were made
    vanilla_has_detections = (
        vanilla_results and len(vanilla_results) > 0
        and hasattr(vanilla_results[0], "boxes") and len(vanilla_results[0].boxes) > 0
    )
    finetune_has_detections = (
        finetune_results and len(finetune_results) > 0
        and hasattr(finetune_results[0], "boxes") and len(finetune_results[0].boxes) > 0
    )
    # Create heatmap visualizations for both models
    vanilla_heatmaps = []
    finetune_heatmaps = []
    if vanilla_has_detections:
        vanilla_result = vanilla_results[0]
        vanilla_names = [vanilla_model.model.names[int(cls)] for cls in vanilla_result.boxes.cls]
        vanilla_heatmaps = create_heatmap_snippets(img_bgr, vanilla_result, vanilla_names, "Vanilla")
    if finetune_has_detections:
        finetune_result = finetune_results[0]
        finetune_names = [finetune_model.model.names[int(cls)] for cls in finetune_result.boxes.cls]
        finetune_heatmaps = create_heatmap_snippets(img_bgr, finetune_result, finetune_names, "Fine-tuned")
    # Create separate layouts for each model
    vanilla_layout = create_model_layout(vanilla_heatmaps, "Vanilla Model", (0, 100, 0))
    finetune_layout = create_model_layout(finetune_heatmaps, "Fine-tuned Model", (0, 0, 200))
    # Convert BGR to RGB for display
    vanilla_output = cv2.cvtColor(vanilla_layout, cv2.COLOR_BGR2RGB) if vanilla_layout is not None else None
    finetune_output = cv2.cvtColor(finetune_layout, cv2.COLOR_BGR2RGB) if finetune_layout is not None else None
    return vanilla_output, finetune_output
def create_heatmap_snippets(img_bgr, result, names, model_type):
    """Create heatmap snippets for detected objects."""
    snippets = []
    if hasattr(result, "pooled_feats") and result.pooled_feats:
        last_pooled = result.pooled_feats[-1]
        for i in range(len(result.boxes)):
            box = result.boxes.xyxy[i]
            fmap = last_pooled[i]
            heatmap_full = draw_feature_heatmap(img_bgr.copy(), box, fmap)
            x1, y1, x2, y2 = box.cpu().numpy().astype(int)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(img_bgr.shape[1], x2), min(img_bgr.shape[0], y2)
            if x2 <= x1 or y2 <= y1:
                continue
            snippet = heatmap_full[y1:y2, x1:x2]
            # Add caption with model type and object info
            caption = f"{model_type}: {names[i]}"
            font = cv2.FONT_HERSHEY_SIMPLEX
            (tw, th), _ = cv2.getTextSize(caption, font, 0.6, 1)
            canvas = np.full((snippet.shape[0] + th + 15, max(snippet.shape[1], tw + 10), 3), 255, np.uint8)
            # center the snippet
            cx = (canvas.shape[1] - snippet.shape[1]) // 2
            canvas[0 : snippet.shape[0], cx : cx + snippet.shape[1]] = snippet
            # put caption
            tx = (canvas.shape[1] - tw) // 2
            cv2.putText(canvas, caption, (tx, snippet.shape[0] + th + 5), font, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
            cv2.rectangle(canvas, (0, 0), (canvas.shape[1] - 1, canvas.shape[0] - 1), (180, 180, 180), 1)
            snippets.append(canvas)
    return snippets
def create_model_layout(heatmaps, title, color):
    """Create a layout for one model's heatmaps."""
    pad = 20
    if not heatmaps:
        # Create empty section with title
        font = cv2.FONT_HERSHEY_SIMPLEX
        (tw, th), _ = cv2.getTextSize(title, font, 1.0, 2)
        canvas = np.full((th + 40, tw + 20, 3), 255, np.uint8)
        cv2.putText(canvas, title, (10, th + 20), font, 1.0, color, 2, cv2.LINE_AA)
        return canvas
    # Arrange heatmaps in a row
    max_h = max(h.shape[0] for h in heatmaps)
    total_w = sum(h.shape[1] for h in heatmaps) + (len(heatmaps) - 1) * 10
    # Add title space
    title_font = cv2.FONT_HERSHEY_SIMPLEX
    (tw, th), _ = cv2.getTextSize(title, title_font, 1.0, 2)
    section_h = max_h + th + 40
    section_w = max(total_w, tw + 20)
    # Create canvas with padding
    canvas_h = section_h + 2 * pad
    canvas_w = section_w + 2 * pad
    canvas = np.full((canvas_h, canvas_w, 3), 255, np.uint8)
    # Add title
    cv2.putText(canvas, title, (pad + 10, pad + th + 20), title_font, 1.0, color, 2, cv2.LINE_AA)
    # Arrange heatmaps
    cur_x = pad
    for h in heatmaps:
        y_off = pad + th + 30 + (max_h - h.shape[0]) // 2
        canvas[y_off : y_off + h.shape[0], cur_x : cur_x + h.shape[1]] = h
        cur_x += h.shape[1] + 10
    return canvas
# ---------------------------
# Gradio UI definition
# ---------------------------
def build_demo():
    with gr.Blocks(title="YOLOv10 XAI Heatmap Comparison") as demo:
        gr.Markdown("# YOLOv10 XAI Heatmap Comparison")
        gr.Markdown("Upload an image to compare XAI heatmaps between vanilla and fine-tuned YOLOv10 models.")
        with gr.Row():
            # Left side - Input controls
            with gr.Column(scale=1):
                image_input = gr.Image(type="numpy", label="Input Image")
                conf_input = gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.25, label="Confidence Threshold")
                gr.Markdown("### Example Images")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=[image_input, conf_input],
                    label="Click to load example",
                )
            # Right side - Output visualizations (separated vertically)
            with gr.Column(scale=2):
                vanilla_output = gr.Image(type="numpy", label="Vanilla Model Heatmap")
                finetune_output = gr.Image(type="numpy", label="Fine-tuned Model Heatmap")
        # Connect inputs to the function
        def update_heatmap(image, confidence):
            if image is None:
                return None, None
            return generate_heatmap_layout(image, confidence)
        # Set up the interface
        image_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=[vanilla_output, finetune_output])
        conf_input.change(fn=update_heatmap, inputs=[image_input, conf_input], outputs=[vanilla_output, finetune_output])
    return demo
def main():
    demo = build_demo()
    demo.launch()
if __name__ == "__main__":
    main()