ParamDev's picture
Update app.py
56f90b5 verified
raw
history blame
7.22 kB
import gradio as gr
import os
import sys
import yaml
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms # Added this import
# Add the project root to sys.path to allow imports from sibling directories
# Assuming app.py is in the root of the space, and visual-quality-inspection is a subdirectory
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'visual-quality-inspection')))
# Import your core anomaly detection functions
# Make sure these imports work relative to the sys.path adjustments
from visual_quality_inspection.anomaly_detection import load_custom_model, prepare_torchscript_model, inference_score, get_PCA_kernel, get_partial_model, get_train_features # Added get_partial_model, get_train_features
from visual_quality_inspection.dataset import Mvtec # Your custom dataset class
# --- Configuration Loading ---
# Resolve paths against this file's directory (as the sys.path tweak does) so the
# app also works when launched from a different working directory.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
# Path to eval.yaml within the Space.
CONFIG_FILE_PATH = os.path.join(APP_DIR, 'visual-quality-inspection', 'configs', 'eval.yaml')
# Directory holding the trained model files (the 'models' directory you create).
MODEL_OUTPUT_PATH = os.path.join(APP_DIR, 'visual-quality-inspection', 'models')
# Load config once at startup.
# NOTE(review): paths *inside* the config (e.g. dataset.root_dir) are still
# interpreted relative to the working directory.
with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# safe_load returns None for an empty file; fail here with a clear message instead
# of a confusing TypeError at the first config[...] lookup.
if not isinstance(config, dict):
    raise RuntimeError(f"Config file '{CONFIG_FILE_PATH}' is empty or not a YAML mapping.")
# --- Global Model and PCA Kernel Loading (run once when the app starts) ---
# Everything below runs once at import time, so each Gradio request reuses the
# loaded model, feature extractor and PCA kernel instead of rebuilding them.
# eval.yaml is expected to already select the feature extractor and category
# (e.g. 'simsiam' with 'bottle' or 'all'). To override programmatically:
#   config['model']['feature_extractor'] = 'simsiam'
#   config['dataset']['category_type'] = 'bottle'  # or 'all'
print("Loading model and preparing PCA kernel...")

model = load_custom_model(MODEL_OUTPUT_PATH, config)
if model is None:
    raise RuntimeError("Failed to load the custom model. Check model path and file integrity.")

# The PCA kernel is fitted on one category's training split. 'all' is not a
# concrete category, so 'bottle' stands in for it in this demo.
current_category = config['dataset']['category_type']
pca_train_category = current_category
if pca_train_category == 'all':
    print("Config category is 'all'. Using 'bottle' for initial PCA training for demo purposes.")
    pca_train_category = 'bottle'

_dataset_cfg = config['dataset']
trainset = Mvtec(
    root_dir=_dataset_cfg['root_dir'],
    object_type=pca_train_category,
    split='train',
    im_size=_dataset_cfg['image_size'],
)

# Truncated feature extractor plus the shape of the features it produces
# (get_partial_model uses the training set to infer that shape).
partial_model, feature_shape = get_partial_model(model, trainset, config['model'])
# TorchScript-prepared extractor, used to compute the training features for the PCA fit.
model_ts = prepare_torchscript_model(partial_model, config)
train_features, _ = get_train_features(model_ts, trainset, feature_shape, config)
pca_kernel = get_PCA_kernel(train_features, config)

print("Model and PCA kernel loaded successfully.")
# --- Anomaly Detection Function for Gradio ---
# Illustrative decision threshold on the score computed below (negated
# reconstruction error, so lower = more anomalous). It is NOT derived from
# training/validation data; tune it for your model and category.
ANOMALY_THRESHOLD = -100.0

# Preprocessing, intended to mirror the Mvtec dataset transforms (TODO confirm
# against visual_quality_inspection.dataset): resize to the configured square
# size, convert to a tensor, normalize with ImageNet mean/std. Built once because
# it depends only on the static config.
_PREPROCESS = transforms.Compose([
    transforms.Resize((config['dataset']['image_size'], config['dataset']['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# PCA kernels keyed by category. The kernel fitted at startup is reused; other
# categories are fitted on first request and cached for subsequent calls.
_PCA_KERNELS = {pca_train_category: pca_kernel}


def _get_pca_kernel(category: str):
    """Return the PCA kernel fitted on *category*'s training split, fitting it on first use.

    Uses the same pipeline as the startup kernel (Mvtec train split ->
    get_train_features -> get_PCA_kernel). Errors from those helpers (e.g. when the
    category's data is missing) propagate to the caller.
    """
    kernel = _PCA_KERNELS.get(category)
    if kernel is None:
        print(f"Fitting PCA kernel for category '{category}'...")
        category_trainset = Mvtec(
            root_dir=config['dataset']['root_dir'],
            object_type=category,
            split='train',
            im_size=config['dataset']['image_size'],
        )
        category_features, _ = get_train_features(model_ts, category_trainset, feature_shape, config)
        kernel = get_PCA_kernel(category_features, config)
        _PCA_KERNELS[category] = kernel
    return kernel


def predict_anomaly(input_image: Image.Image, current_category_choice: str):
    """Score a single image for anomalies against the chosen category's PCA kernel.

    Args:
        input_image: The uploaded image; ``None`` when the user submits without one.
        current_category_choice: Category whose training data defines "normal".
            An empty choice falls back to the category fitted at startup.

    Returns:
        A ``(message, image)`` pair for the two Gradio outputs: a status/score
        string, and the input image echoed back (``None`` if there was no image).
    """
    if input_image is None:
        return "Please upload an image.", None

    category = current_category_choice or pca_train_category
    kernel = _get_pca_kernel(category)

    # Ensure the model is in evaluation mode.
    model.eval()

    inputs = _PREPROCESS(input_image.convert('RGB')).unsqueeze(0)  # add batch dimension
    use_bf16 = config['precision'] == 'bfloat16'

    # Pure inference: no_grad avoids building an autograd graph on every request.
    with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=use_bf16):
        inputs = inputs.contiguous(memory_format=torch.channels_last)
        if use_bf16:
            inputs = inputs.to(torch.bfloat16)
        features = partial_model(inputs)[config['model']['layer']]
        pool_size = config['model']['pool']
        pool_out = torch.nn.functional.avg_pool2d(features, pool_size) if pool_size > 1 else features
        outputs = pool_out.contiguous().view(pool_out.size(0), -1)
        # Feature reconstruction error: project onto the PCA kernel and back, then
        # sum the squared differences per sample.
        reconstructed = kernel.inverse_transform(kernel.transform(outputs))
        fre = torch.square(outputs - reconstructed).reshape(outputs.shape)
        fre_score = torch.sum(fre, dim=1)

    # Batch of one -> a single scalar, negated so a larger error gives a lower score.
    score = -fre_score.item()

    status = "Anomaly Detected!" if score < ANOMALY_THRESHOLD else "Normal"
    return f"Status: {status} | Anomaly Score: {score:.4f}", input_image
# --- Category discovery ---
# MVTec AD category names, offered when the data directory cannot be listed.
DEFAULT_CATEGORIES = [
    "bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather", "metal_nut",
    "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper",
]


def _discover_categories(root_dir):
    """Return the sorted sub-directory names of *root_dir* to offer as categories.

    Skips hidden directories (names starting with '.', e.g. '.ipynb_checkpoints')
    and a top-level 'ground_truth' directory. Returns None if *root_dir* is not a
    directory, so the caller can tell "missing" apart from "empty".
    """
    if not os.path.isdir(root_dir):
        return None
    return sorted(
        name for name in os.listdir(root_dir)
        if not name.startswith('.')
        and name != 'ground_truth'
        and os.path.isdir(os.path.join(root_dir, name))
    )


DATA_ROOT_DIR = config['dataset']['root_dir']
available_categories = _discover_categories(DATA_ROOT_DIR)
if available_categories is None:
    print(f"Warning: Data root directory '{DATA_ROOT_DIR}' not found. Falling back to default categories.")
    available_categories = list(DEFAULT_CATEGORIES)
if not available_categories:
    available_categories = ["bottle"]  # Final fallback if no categories found
# --- Gradio Interface ---
# Preselect the category the startup PCA kernel was fitted on when it is offered;
# otherwise use the first listed category (the list is never empty at this point).
_default_category = (
    pca_train_category if pca_train_category in available_categories else available_categories[0]
)

iface = gr.Interface(
    fn=predict_anomaly,
    inputs=[
        gr.Image(type="pil", label="Upload Image for Anomaly Detection"),
        gr.Dropdown(choices=available_categories, label="Select Category", value=_default_category),
    ],
    outputs=[
        gr.Textbox(label="Anomaly Detection Result"),
        gr.Image(type="pil", label="Input Image"),
    ],
    title="Visual Anomaly Detection (SimSiam + PCA)",
    description="Upload an image and select its category to detect anomalies using a pre-trained SimSiam model with PCA-based anomaly scoring. Note: The anomaly threshold is illustrative and may need tuning.",
)

# Launch only when run as a script (`python app.py`), so importing this module
# (e.g. for tests or Gradio reload mode) does not start a server.
if __name__ == "__main__":
    iface.launch()