import gradio as gr
import os
import sys
import yaml
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
# Add the project root to sys.path to allow imports from sibling directories
# Assuming app.py is in the root of the space, and visual-quality-inspection is a subdirectory
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'visual-quality-inspection')))
# Import your core anomaly detection functions
# Make sure these imports work relative to the sys.path adjustments
from visual_quality_inspection.anomaly_detection import (
    load_custom_model,
    prepare_torchscript_model,
    inference_score,
    get_PCA_kernel,
    get_partial_model,
    get_train_features,
)
from visual_quality_inspection.dataset import Mvtec # Your custom dataset class
# --- Configuration Loading ---
# Define the path to your eval.yaml within the Space
CONFIG_FILE_PATH = 'visual-quality-inspection/configs/eval.yaml'
# Define the path where your model is located within the Space
MODEL_OUTPUT_PATH = 'visual-quality-inspection/models' # This should point to the 'models' directory you create
# Load config once at startup
with open(CONFIG_FILE_PATH, "r") as f:
    config = yaml.safe_load(f)
# --- Global Model and PCA Kernel Loading (run once when the app starts) ---
# This ensures the model is loaded only once, not on every inference call.
print("Loading model and preparing PCA kernel...")
# Ensure the correct feature_extractor and category_type are set in config
# This assumes you've pre-modified eval.yaml or you set them here programmatically
# For this example, let's assume eval.yaml is already set to 'simsiam' and 'bottle' or 'all'
# If you need to override:
# config['model']['feature_extractor'] = 'simsiam'
# config['dataset']['category_type'] = 'bottle' # Or 'all' if you want to iterate
# Load the pre-trained model
model = load_custom_model(MODEL_OUTPUT_PATH, config)
if model is None:
    raise RuntimeError("Failed to load the custom model. Check model path and file integrity.")
# Build the training dataset used for feature-shape inference and PCA fitting
current_category = config['dataset']['category_type']
if current_category == 'all':
    print("Config category is 'all'. Using 'bottle' for initial PCA training for demo purposes.")
    pca_train_category = 'bottle'
else:
    pca_train_category = current_category
trainset = Mvtec(
    root_dir=config['dataset']['root_dir'],
    object_type=pca_train_category,
    split='train',
    im_size=config['dataset']['image_size']
)
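# The MVTec AD 'train' split contains only defect-free images, so the PCA
# kernel below is fitted exclusively on normal-feature statistics.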
partial_model, feature_shape = get_partial_model(model, trainset, config['model'])
model_ts = prepare_torchscript_model(partial_model, config)
train_features, _ = get_train_features(model_ts, trainset, feature_shape, config)
pca_kernel = get_PCA_kernel(train_features, config)
print("Model and PCA kernel loaded successfully.")
# --- Anomaly Detection Function for Gradio ---
def predict_anomaly(input_image: Image.Image, current_category_choice: str):
    """
    Performs anomaly detection on a single input image for a chosen category.
    """
    if input_image is None:
        return "Please upload an image first.", None
    # Ensure the model is in evaluation mode
    model.eval()
    # Apply the same transformations as defined in Mvtec
    im_size = config['dataset']['image_size']
    transform = transforms.Compose([
        transforms.Resize((im_size, im_size)),
        transforms.ToTensor(),
        # Standard ImageNet normalization statistics
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    transformed_image = transform(input_image.convert('RGB')).unsqueeze(0)  # Add batch dimension
    # Record the user's category choice in the config for this inference.
    # Note: `config` is a module-level dict, so this mutation is visible globally;
    # the original value is restored after inference below.
    original_category_config = config['dataset']['category_type']  # Store original
    config['dataset']['category_type'] = current_category_choice  # Use user's choice for this inference
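    # Note: the PCA kernel was fitted once at startup on `pca_train_category`, so
    # selecting a different category here does not re-fit it. A per-category demo
    # would need to fit and cache one kernel per category at startup.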
    # Disable autograd for inference so the features can be handed to scikit-learn
    # as NumPy arrays without a detach step.
    with torch.no_grad(), torch.cpu.amp.autocast(enabled=config['precision'] == 'bfloat16'):
        inputs = transformed_image.contiguous(memory_format=torch.channels_last)
        if config['precision'] == 'bfloat16':
            inputs = inputs.to(torch.bfloat16)
        features = partial_model(inputs)[config['model']['layer']]
        pool_out = torch.nn.functional.avg_pool2d(features, config['model']['pool']) if config['model']['pool'] > 1 else features
        outputs = pool_out.contiguous().view(pool_out.size(0), -1)
        # Cast to float32: scikit-learn's PCA operates on float32/float64 NumPy
        # arrays and cannot consume bfloat16 tensors directly.
        oi = outputs.float()
        oi_or = oi
        oi_j = pca_kernel.transform(oi)
        oi_reconstructed = pca_kernel.inverse_transform(oi_j)
        fre = torch.square(oi_or - oi_reconstructed).reshape(outputs.shape)
        fre_score = torch.sum(fre, dim=1)
    score = -fre_score.item()  # Get the single scalar score
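    # Sign convention: `score` is the negated feature-reconstruction error (FRE),
    # so lower (more negative) values mean a larger reconstruction error and
    # therefore a more likely anomaly.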
    # Restore the original category_type so later requests see an unmodified config
    config['dataset']['category_type'] = original_category_config
    # Simple anomaly threshold for display. Ideally this would come from eval.yaml or a
    # value pre-computed on a validation split; here, a score far below zero is treated
    # as anomalous. The threshold is illustrative and should be tuned per model/category.
    ANOMALY_THRESHOLD = -100.0  # Example threshold, adjust based on your model's score range
    status = "Anomaly Detected!" if score < ANOMALY_THRESHOLD else "Normal"
    return f"Status: {status} | Anomaly Score: {score:.4f}", input_image
# Get available categories from the data directory
DATA_ROOT_DIR = config['dataset']['root_dir']
# Ensure DATA_ROOT_DIR exists before listing
if not os.path.isdir(DATA_ROOT_DIR):
    print(f"Warning: Data root directory '{DATA_ROOT_DIR}' not found. Falling back to default categories.")
    available_categories = ["bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather", "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper"]
else:
    available_categories = [
        os.path.basename(d) for d in os.listdir(DATA_ROOT_DIR)
        if os.path.isdir(os.path.join(DATA_ROOT_DIR, d)) and d not in ['ground_truth']  # Exclude ground_truth if it's a top-level dir
    ]
available_categories.sort()
if not available_categories:
    available_categories = ["bottle"]  # Final fallback if no categories found
# --- Gradio Interface ---
iface = gr.Interface(
    fn=predict_anomaly,
    inputs=[
        gr.Image(type="pil", label="Upload Image for Anomaly Detection"),
        gr.Dropdown(choices=available_categories, label="Select Category", value=available_categories[0] if available_categories else "bottle")
    ],
    outputs=[
        gr.Textbox(label="Anomaly Detection Result"),
        gr.Image(type="pil", label="Input Image")
    ],
    title="Visual Anomaly Detection (SimSiam + PCA)",
    description="Upload an image and select its category to detect anomalies using a pre-trained SimSiam model with PCA-based anomaly scoring. Note: The anomaly threshold is illustrative and may need tuning."
)
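# On Hugging Face Spaces, launch() with default arguments is typically sufficient;
# the platform handles the public URL and server port.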
iface.launch()