# Hugging Face Spaces page status captured together with this source (not Python):
# Spaces:
# Configuration error
import os
import sys

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
import yaml
from PIL import Image

# Add the 'visual-quality-inspection' sub-directory of this Space to sys.path so
# the project modules below can be imported. The path is resolved from this
# file's location, not from the current working directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'visual-quality-inspection')))

# Core anomaly-detection helpers; these imports must come after the sys.path tweak.
# NOTE(review): a directory named 'visual-quality-inspection' (hyphens) cannot
# itself be imported as the package 'visual_quality_inspection'; these imports
# presumably resolve to a 'visual_quality_inspection' package located inside that
# directory (or installed separately) — confirm the Space's layout.
from visual_quality_inspection.anomaly_detection import (  # noqa: E402
    get_PCA_kernel,
    get_partial_model,
    get_train_features,
    inference_score,
    load_custom_model,
    prepare_torchscript_model,
)
from visual_quality_inspection.dataset import Mvtec  # noqa: E402  # project dataset class
# --- Configuration Loading ---
# Resolve bundled resources relative to this file rather than the current
# working directory, so the app also starts correctly when launched elsewhere.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
# Path to the evaluation config shipped with the Space.
CONFIG_FILE_PATH = os.path.join(_APP_DIR, 'visual-quality-inspection', 'configs', 'eval.yaml')
# Directory holding the trained model; passed to load_custom_model().
MODEL_OUTPUT_PATH = os.path.join(_APP_DIR, 'visual-quality-inspection', 'models')

# Load the config once at startup; the rest of the module reads this dict.
with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)
# yaml.safe_load returns None for an empty document; fail here with a clear
# message instead of a TypeError on the first config[...] lookup.
if not isinstance(config, dict):
    raise RuntimeError(f"Config file '{CONFIG_FILE_PATH}' is empty or is not a YAML mapping.")
# --- Global Model and PCA Kernel Loading (run once when the app starts) ---
# Everything below runs at import time, so the model, the feature extractor and
# the PCA kernel are built once and reused by every inference call.
print("Loading model and preparing PCA kernel...")
# The feature extractor and category come from eval.yaml (e.g. 'simsiam' with
# 'bottle' or 'all'). To override them, assign before this point:
#   config['model']['feature_extractor'] = 'simsiam'
#   config['dataset']['category_type'] = 'bottle'

model = load_custom_model(MODEL_OUTPUT_PATH, config)
if model is None:
    raise RuntimeError("Failed to load the custom model. Check model path and file integrity.")

_dataset_cfg = config['dataset']
current_category = _dataset_cfg['category_type']
# A single PCA kernel is fitted at startup, so 'all' is narrowed to one category.
if current_category != 'all':
    pca_train_category = current_category
else:
    print("Config category is 'all'. Using 'bottle' for initial PCA training for demo purposes.")
    pca_train_category = 'bottle'

# Training split of that category: used to infer the feature shape and to fit PCA.
trainset = Mvtec(
    root_dir=_dataset_cfg['root_dir'],
    object_type=pca_train_category,
    split='train',
    im_size=_dataset_cfg['image_size'],
)

partial_model, feature_shape = get_partial_model(model, trainset, config['model'])
model_ts = prepare_torchscript_model(partial_model, config)
# NOTE(review): the PCA kernel is fitted on features produced by model_ts, while
# predict_anomaly() runs partial_model — confirm both yield identical features.
train_features, _ = get_train_features(model_ts, trainset, feature_shape, config)
pca_kernel = get_PCA_kernel(train_features, config)
print("Model and PCA kernel loaded successfully.")
# --- Anomaly Detection Function for Gradio ---
def predict_anomaly(input_image: Image.Image, current_category_choice: str):
    """Run anomaly detection on a single uploaded image.

    The image is resized to ``config['dataset']['image_size']`` and normalised
    with the same mean/std as the Mvtec dataset transforms. It is then passed
    through ``partial_model``, optionally average-pooled, flattened, and
    projected onto and back from the module-level ``pca_kernel``. The score is
    the negated sum of squared reconstruction errors, so more negative scores
    mean a larger reconstruction error.

    Args:
        input_image: Image from the Gradio upload widget; ``None`` when the
            form is submitted without an image.
        current_category_choice: Category selected in the dropdown. The model
            and ``pca_kernel`` are built once at startup for
            ``pca_train_category``; this choice does not change them. A
            mismatch is reported back to the user.

    Returns:
        ``(message, image)`` for the two Gradio outputs: a status/score string
        and the input image echoed back (``None`` when no image was given).
    """
    if input_image is None:
        # Gradio passes None when the user submits without uploading an image.
        return "Please upload an image.", None

    # Ensure the model is in evaluation mode.
    model.eval()

    # Apply the same transformations as defined in Mvtec.
    im_size = config['dataset']['image_size']
    transform = transforms.Compose([
        transforms.Resize((im_size, im_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transformed_image = transform(input_image.convert('RGB')).unsqueeze(0)  # add batch dimension

    model_cfg = config['model']
    use_bf16 = config['precision'] == 'bfloat16'
    # no_grad: pure inference, so no autograd graph is built.
    # torch.autocast(device_type="cpu", ...) replaces the deprecated torch.cpu.amp.autocast.
    with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=use_bf16):
        inputs = transformed_image.contiguous(memory_format=torch.channels_last)
        if use_bf16:
            inputs = inputs.to(torch.bfloat16)
        features = partial_model(inputs)[model_cfg['layer']]
        pool = model_cfg['pool']
        pool_out = torch.nn.functional.avg_pool2d(features, pool) if pool > 1 else features
        outputs = pool_out.contiguous().view(pool_out.size(0), -1)
        # Squared error between the features and their PCA reconstruction.
        reconstructed = pca_kernel.inverse_transform(pca_kernel.transform(outputs))
        fre = torch.square(outputs - reconstructed).reshape(outputs.shape)
        score = -torch.sum(fre, dim=1).item()  # batch of one -> Python scalar

    # Illustrative threshold only: very negative scores are flagged as anomalous.
    # It should be determined from training/validation data for the deployed model.
    ANOMALY_THRESHOLD = -100.0
    status = "Anomaly Detected!" if score < ANOMALY_THRESHOLD else "Normal"
    message = f"Status: {status} | Anomaly Score: {score:.4f}"
    if current_category_choice != pca_train_category:
        # The dropdown choice does not change the scoring model; say so instead
        # of silently ignoring it.
        message += (
            f" | Note: scored with the PCA kernel fitted on '{pca_train_category}',"
            f" not '{current_category_choice}'"
        )
    return message, input_image
# Get available categories from the data directory
DATA_ROOT_DIR = config['dataset']['root_dir']
if os.path.isdir(DATA_ROOT_DIR):
    # Each sub-directory of the data root is offered as a category. Hidden
    # directories (e.g. '.git', '.ipynb_checkpoints') and a top-level
    # 'ground_truth' directory are skipped. os.listdir already returns bare
    # names, so no basename() is needed.
    available_categories = sorted(
        name
        for name in os.listdir(DATA_ROOT_DIR)
        if not name.startswith('.')
        and name != 'ground_truth'
        and os.path.isdir(os.path.join(DATA_ROOT_DIR, name))
    )
else:
    print(f"Warning: Data root directory '{DATA_ROOT_DIR}' not found. Falling back to default categories.")
    available_categories = ["bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather", "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper"]
if not available_categories:
    available_categories = ["bottle"]  # Final fallback if no categories found
# --- Gradio Interface ---
# Default the dropdown to the category the PCA kernel was fitted on, since that
# kernel is what predict_anomaly() scores against whatever is selected.
if pca_train_category in available_categories:
    _default_category = pca_train_category
else:
    _default_category = available_categories[0] if available_categories else "bottle"

iface = gr.Interface(
    fn=predict_anomaly,
    inputs=[
        gr.Image(type="pil", label="Upload Image for Anomaly Detection"),
        gr.Dropdown(choices=available_categories, label="Select Category", value=_default_category),
    ],
    outputs=[
        gr.Textbox(label="Anomaly Detection Result"),
        gr.Image(type="pil", label="Input Image"),
    ],
    title="Visual Anomaly Detection (SimSiam + PCA)",
    description="Upload an image and select its category to detect anomalies using a pre-trained SimSiam model with PCA-based anomaly scoring. Note: The anomaly threshold is illustrative and may need tuning.",
)

# Start the server only when executed as a script (`python app.py`); importing
# the module just builds `iface` without launching it.
if __name__ == "__main__":
    iface.launch()