import os
import sys

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
import yaml
from PIL import Image

# Add the bundled project directory to sys.path so the project imports below
# resolve. Assumes app.py sits in the root of the Space and
# 'visual-quality-inspection' is a subdirectory next to it.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'visual-quality-inspection')))

# Core anomaly detection functions. These imports rely on the sys.path
# adjustment above, so they must stay below it.
from visual_quality_inspection.anomaly_detection import (
    load_custom_model,
    prepare_torchscript_model,
    inference_score,
    get_PCA_kernel,
    get_partial_model,
    get_train_features,
)
from visual_quality_inspection.dataset import Mvtec  # Custom dataset class

# --- Configuration Loading ---
# Path to eval.yaml within the Space (resolved against the working directory).
CONFIG_FILE_PATH = 'visual-quality-inspection/configs/eval.yaml'
# Directory within the Space that holds the trained model files.
MODEL_OUTPUT_PATH = 'visual-quality-inspection/models'

# Load the config once at startup.
with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# yaml.safe_load returns None for an empty file (and a non-mapping for e.g. a
# bare list); fail here with a clear message instead of an opaque TypeError
# on the first config['...'] lookup further down.
if not isinstance(config, dict):
    raise ValueError(
        f"Config file '{CONFIG_FILE_PATH}' must contain a YAML mapping, "
        f"got {type(config).__name__}."
    )

# --- Global Model and PCA Kernel Loading (run once when the app starts) ---
# This ensures the model is loaded only once, not on every inference call.
print("Loading model and preparing PCA kernel...")

# The feature extractor and category are taken from eval.yaml as loaded above.
# To override them programmatically, set them here before the model is loaded:
#   config['model']['feature_extractor'] = 'simsiam'
#   config['dataset']['category_type'] = 'bottle'  # or 'all'

# Load the pre-trained model.
model = load_custom_model(MODEL_OUTPUT_PATH, config)
if model is None:
    raise RuntimeError("Failed to load the custom model. Check model path and file integrity.")

# Pick the category whose training split is used for feature-shape inference
# and for fitting the PCA kernel. 'all' is not a single category, so the demo
# falls back to 'bottle' in that case.
current_category = config['dataset']['category_type']
if current_category == 'all':
    print("Config category is 'all'. Using 'bottle' for initial PCA training for demo purposes.")
    pca_train_category = 'bottle'
else:
    pca_train_category = current_category

trainset = Mvtec(
    root_dir=config['dataset']['root_dir'],
    object_type=pca_train_category,
    split='train',
    im_size=config['dataset']['image_size'],
)
partial_model, feature_shape = get_partial_model(model, trainset, config['model'])
model_ts = prepare_torchscript_model(partial_model, config)
train_features, _ = get_train_features(model_ts, trainset, feature_shape, config)
pca_kernel = get_PCA_kernel(train_features, config)
print("Model and PCA kernel loaded successfully.")

# Illustrative threshold only: scores below it are reported as anomalous.
# It should be derived from training/validation scores for the actual model.
ANOMALY_THRESHOLD = -100.0


# --- Anomaly Detection Function for Gradio ---
def predict_anomaly(input_image: Image.Image, current_category_choice: str):
    """Score a single image with the globally loaded model and PCA kernel.

    The image is resized, normalized, passed through ``partial_model``, and the
    selected layer's (optionally average-pooled) features are projected through
    ``pca_kernel`` and back. The score is the negated sum of squared
    reconstruction errors, so lower (more negative) means more anomalous.

    Args:
        input_image: PIL image from the Gradio image component; ``None`` when
            the user has not uploaded anything.
        current_category_choice: Category selected in the dropdown. It is
            written into ``config['dataset']['category_type']`` for the
            duration of the call, but the PCA kernel is the one fitted at
            startup on ``pca_train_category``, so the choice does not change
            which kernel is used for scoring.

    Returns:
        A ``(message, image)`` tuple: the status/score text and the input
        image echoed back for display.
    """
    # Gradio passes None when no image was uploaded; report that instead of
    # crashing on `.convert`.
    if input_image is None:
        return "Please upload an image to run anomaly detection.", None

    if current_category_choice != pca_train_category:
        # The kernel is fitted once at startup, so a different selection is
        # scored against the wrong reference distribution.
        print(
            f"Warning: selected category '{current_category_choice}' differs from the "
            f"PCA training category '{pca_train_category}'; scores may not be meaningful."
        )

    # Ensure the model is in evaluation mode.
    model.eval()

    # Preprocessing intended to mirror the Mvtec dataset transforms
    # (ImageNet mean/std) -- NOTE(review): confirm against the Mvtec class.
    im_size = config['dataset']['image_size']
    transform = transforms.Compose([
        transforms.Resize((im_size, im_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transformed_image = transform(input_image.convert('RGB')).unsqueeze(0)  # Add batch dimension

    use_bf16 = config['precision'] == 'bfloat16'

    # This mutates the shared global config. Nothing in this function reads
    # the value back; it is kept in case helpers that were handed `config`
    # consult it -- TODO confirm. try/finally guarantees the original value
    # is restored even when inference raises.
    original_category_config = config['dataset']['category_type']
    config['dataset']['category_type'] = current_category_choice
    try:
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad(), torch.cpu.amp.autocast(enabled=use_bf16):
            inputs = transformed_image.contiguous(memory_format=torch.channels_last)
            if use_bf16:
                inputs = inputs.to(torch.bfloat16)

            features = partial_model(inputs)[config['model']['layer']]
            pool = config['model']['pool']
            pool_out = torch.nn.functional.avg_pool2d(features, pool) if pool > 1 else features
            outputs = pool_out.contiguous().view(pool_out.size(0), -1)

            # Feature reconstruction error: project through the PCA kernel and back.
            projected = pca_kernel.transform(outputs)
            reconstructed = pca_kernel.inverse_transform(projected)
            fre = torch.square(outputs - reconstructed).reshape(outputs.shape)
            fre_score = torch.sum(fre, dim=1)
            score = -fre_score.item()  # Single scalar score for the one-image batch
    finally:
        config['dataset']['category_type'] = original_category_config

    status = "Anomaly Detected!" if score < ANOMALY_THRESHOLD else "Normal"
    return f"Status: {status} | Anomaly Score: {score:.4f}", input_image


# --- Category discovery ---
# Categories are the sub-directories of the dataset root; fall back to a
# default list when the dataset is not present.
DATA_ROOT_DIR = config['dataset']['root_dir']
if not os.path.isdir(DATA_ROOT_DIR):
    print(f"Warning: Data root directory '{DATA_ROOT_DIR}' not found. Falling back to default categories.")
    available_categories = [
        "bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather",
        "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor",
        "wood", "zipper",
    ]
else:
    available_categories = sorted(
        entry
        for entry in os.listdir(DATA_ROOT_DIR)
        # Exclude ground_truth if it is a top-level directory.
        if entry != 'ground_truth' and os.path.isdir(os.path.join(DATA_ROOT_DIR, entry))
    )

if not available_categories:
    available_categories = ["bottle"]  # Final fallback if no categories found

# --- Gradio Interface ---
iface = gr.Interface(
    fn=predict_anomaly,
    inputs=[
        gr.Image(type="pil", label="Upload Image for Anomaly Detection"),
        gr.Dropdown(
            choices=available_categories,
            label="Select Category",
            # available_categories is guaranteed non-empty by the fallback above.
            value=available_categories[0],
        ),
    ],
    outputs=[
        gr.Textbox(label="Anomaly Detection Result"),
        gr.Image(type="pil", label="Input Image"),
    ],
    title="Visual Anomaly Detection (SimSiam + PCA)",
    description="Upload an image and select its category to detect anomalies using a pre-trained SimSiam model with PCA-based anomaly scoring. Note: The anomaly threshold is illustrative and may need tuning.",
)

iface.launch()