# NOTE: The lines below are residue from the web file viewer this source was
# copied from, kept as comments so they do not break parsing of the module.
#   Spaces: Configuration error
#   File size: 7,221 Bytes
#   Commits shown in the viewer: 46f679b, 56f90b5
#   (line-number gutter 1-153 omitted)
import gradio as gr
import os
import sys
import yaml
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms # Added this import
# Make the bundled project directory importable. The directory is resolved
# relative to this file, so it does not depend on the working directory.
# NOTE(review): the directory name uses hyphens while the imports that follow
# use the package name `visual_quality_inspection` -- presumably that package
# lives inside this directory; confirm against the repository layout.
_APP_DIR = os.path.dirname(__file__)
_PROJECT_DIR = os.path.abspath(os.path.join(_APP_DIR, 'visual-quality-inspection'))
sys.path.append(_PROJECT_DIR)
# Import your core anomaly detection functions
# Make sure these imports work relative to the sys.path adjustments
from visual_quality_inspection.anomaly_detection import load_custom_model, prepare_torchscript_model, inference_score, get_PCA_kernel, get_partial_model, get_train_features # Added get_partial_model, get_train_features
from visual_quality_inspection.dataset import Mvtec # Your custom dataset class
# --- Configuration Loading ---
# Path to eval.yaml within the Space (resolved against the working directory).
CONFIG_FILE_PATH = 'visual-quality-inspection/configs/eval.yaml'
# Directory holding the model files within the Space.
MODEL_OUTPUT_PATH = 'visual-quality-inspection/models' # This should point to the 'models' directory you create
# Load config once at startup. The encoding is explicit so parsing does not
# depend on the platform's default locale.
with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# yaml.safe_load returns None for an empty file (and a scalar/list for
# non-mapping content); fail here with a clear message instead of an obscure
# subscript error at the first config[...] lookup further down.
if not isinstance(config, dict):
    raise RuntimeError(
        f"Configuration file '{CONFIG_FILE_PATH}' is empty or is not a YAML mapping."
    )
# --- Global Model and PCA Kernel Loading (run once when the app starts) ---
# Everything in this section executes at import time, so each inference call
# reuses the same model and PCA kernel instead of rebuilding them.
print("Loading model and preparing PCA kernel...")

# The feature extractor and category come from eval.yaml as loaded above.
# To override them programmatically, set the values before this point, e.g.:
# config['model']['feature_extractor'] = 'simsiam'
# config['dataset']['category_type'] = 'bottle' # Or 'all' if you want to iterate

# Load the pre-trained model; a None result is treated as a hard failure.
model = load_custom_model(MODEL_OUTPUT_PATH, config)
if model is None:
    raise RuntimeError("Failed to load the custom model. Check model path and file integrity.")

# Choose the category whose training split is used below. The special value
# 'all' is replaced by 'bottle' for this demo.
current_category = config['dataset']['category_type']
pca_train_category = current_category
if current_category == 'all':
    print("Config category is 'all'. Using 'bottle' for initial PCA training for demo purposes.")
    pca_train_category = 'bottle'

# Training split handed to the feature-shape and train-feature helpers.
_dataset_cfg = config['dataset']
trainset = Mvtec(
    root_dir=_dataset_cfg['root_dir'],
    object_type=pca_train_category,
    split='train',
    im_size=_dataset_cfg['image_size'],
)

# Pipeline: partial model -> TorchScript model -> training features -> PCA kernel.
partial_model, feature_shape = get_partial_model(model, trainset, config['model'])
model_ts = prepare_torchscript_model(partial_model, config)
train_features, _ = get_train_features(model_ts, trainset, feature_shape, config)
pca_kernel = get_PCA_kernel(train_features, config)

print("Model and PCA kernel loaded successfully.")
# --- Anomaly Detection Function for Gradio ---
def predict_anomaly(input_image: "Image.Image", current_category_choice: str):
    """Score a single image for anomalies and report the result.

    The image is resized and normalized, passed through the module-level
    ``partial_model``, optionally average-pooled, and flattened. The flattened
    features are projected through the module-level ``pca_kernel`` and
    reconstructed; the score is the negated sum of squared reconstruction
    errors, so a lower (more negative) score means a more anomalous image.

    Args:
        input_image: PIL image supplied by the Gradio image component. May be
            ``None`` when the user submits without uploading an image.
        current_category_choice: Category selected in the dropdown. Accepted
            for interface compatibility only: scoring always uses the PCA
            kernel prepared at startup, so this value does not influence the
            score.

    Returns:
        A ``(message, image)`` tuple. ``message`` contains the status and the
        score formatted to four decimals; ``image`` is the unmodified input
        image (``None`` when no image was provided).
    """
    # Gradio passes None when nothing was uploaded; answer with a readable
    # message instead of failing on `None.convert(...)`.
    if input_image is None:
        return "Please upload an image before running anomaly detection.", None

    # Ensure the model is in evaluation mode.
    # NOTE(review): the forward pass below goes through `partial_model`, not
    # `model`; presumably they share modules -- confirm eval mode carries over.
    model.eval()

    # Apply the same transformations as defined in Mvtec
    im_size = config['dataset']['image_size']
    transform = transforms.Compose([
        transforms.Resize((im_size, im_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transformed_image = transform(input_image.convert('RGB')).unsqueeze(0)  # Add batch dimension

    # The shared global `config` is deliberately NOT modified here: nothing in
    # this function reads config['dataset']['category_type'], and temporarily
    # rewriting a module-level dict is unsafe when requests overlap.

    use_bf16 = config['precision'] == 'bfloat16'
    pool = config['model']['pool']

    # no_grad: pure inference, so autograd bookkeeping is not needed.
    # autocast is only active when bfloat16 precision is configured.
    with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=use_bf16):
        inputs = transformed_image.contiguous(memory_format=torch.channels_last)
        if use_bf16:
            inputs = inputs.to(torch.bfloat16)
        features = partial_model(inputs)[config['model']['layer']]
        pooled = torch.nn.functional.avg_pool2d(features, pool) if pool > 1 else features
        outputs = pooled.contiguous().view(pooled.size(0), -1)
        # Feature reconstruction error: project, reconstruct, and compare
        # against the original flattened features.
        projected = pca_kernel.transform(outputs)
        reconstructed = pca_kernel.inverse_transform(projected)
        fre = torch.square(outputs - reconstructed).reshape(outputs.shape)
        fre_score = torch.sum(fre, dim=1)

    score = -fre_score.item()  # Single scalar score for the one-image batch

    # Illustrative threshold only; it should be determined from
    # training/validation data for the model's actual score range.
    ANOMALY_THRESHOLD = -100.0
    status = "Anomaly Detected!" if score < ANOMALY_THRESHOLD else "Normal"
    return f"Status: {status} | Anomaly Score: {score:.4f}", input_image
# Build the list of selectable categories from the dataset root directory.
DATA_ROOT_DIR = config['dataset']['root_dir']
if os.path.isdir(DATA_ROOT_DIR):
    # Every sub-directory of the data root is a category, except a top-level
    # 'ground_truth' directory. os.listdir already yields bare entry names.
    available_categories = sorted(
        entry
        for entry in os.listdir(DATA_ROOT_DIR)
        if entry != 'ground_truth' and os.path.isdir(os.path.join(DATA_ROOT_DIR, entry))
    )
else:
    # No data directory available: fall back to a fixed category list.
    print(f"Warning: Data root directory '{DATA_ROOT_DIR}' not found. Falling back to default categories.")
    available_categories = [
        "bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather",
        "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor",
        "wood", "zipper",
    ]
# Final fallback if no categories were found.
if not available_categories:
    available_categories = ["bottle"]
# --- Gradio Interface ---
# Input components: the image to score and the category dropdown.
_image_input = gr.Image(type="pil", label="Upload Image for Anomaly Detection")
_category_input = gr.Dropdown(
    choices=available_categories,
    label="Select Category",
    value=available_categories[0] if available_categories else "bottle",
)
# Output components: the textual result and the input image shown back.
_result_output = gr.Textbox(label="Anomaly Detection Result")
_image_output = gr.Image(type="pil", label="Input Image")

iface = gr.Interface(
    fn=predict_anomaly,
    inputs=[_image_input, _category_input],
    outputs=[_result_output, _image_output],
    title="Visual Anomaly Detection (SimSiam + PCA)",
    description="Upload an image and select its category to detect anomalies using a pre-trained SimSiam model with PCA-based anomaly scoring. Note: The anomaly threshold is illustrative and may need tuning.",
)
# Start the web app as soon as the module is executed.
iface.launch()