File size: 7,221 Bytes
46f679b
56f90b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46f679b
56f90b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46f679b
56f90b5
46f679b
 
56f90b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr
import os
import sys
import yaml
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms # Added this import

# Make the bundled 'visual-quality-inspection' directory importable. The path
# is derived from this script's own location, so the import below does not
# depend on the directory the app is launched from.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'visual-quality-inspection'))

# Import your core anomaly detection functions
# Make sure these imports work relative to the sys.path adjustments
from visual_quality_inspection.anomaly_detection import load_custom_model, prepare_torchscript_model, inference_score, get_PCA_kernel, get_partial_model, get_train_features # Added get_partial_model, get_train_features
from visual_quality_inspection.dataset import Mvtec # Your custom dataset class

# --- Configuration Loading ---
# Resolve Space-relative paths against this script's directory (the same base
# the sys.path entry uses) instead of the process working directory, so the app
# still finds its files when launched from elsewhere.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
# Path to eval.yaml within the Space.
CONFIG_FILE_PATH = os.path.join(_APP_DIR, 'visual-quality-inspection', 'configs', 'eval.yaml')
# Directory holding the pre-trained model, passed to load_custom_model at startup.
MODEL_OUTPUT_PATH = os.path.join(_APP_DIR, 'visual-quality-inspection', 'models')

# Load config once at startup. safe_load only builds plain YAML types.
with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# safe_load returns None for an empty file; fail fast with a clear message
# instead of a TypeError on the first config[...] lookup.
if not isinstance(config, dict):
    raise RuntimeError(f"Config file '{CONFIG_FILE_PATH}' is empty or is not a YAML mapping.")

# --- Global Model and PCA Kernel Loading (run once when the app starts) ---
# Everything here executes once at import time; predict_anomaly then reuses the
# resulting globals (model, partial_model, pca_kernel) on every request.
print("Loading model and preparing PCA kernel...")

# eval.yaml is expected to already select the feature extractor and category.
# To override them programmatically instead, set e.g.:
#   config['model']['feature_extractor'] = 'simsiam'
#   config['dataset']['category_type'] = 'bottle'  # or 'all'

# Pre-trained model; load_custom_model signals failure by returning None.
model = load_custom_model(MODEL_OUTPUT_PATH, config)
if model is None:
    raise RuntimeError("Failed to load the custom model. Check model path and file integrity.")

# The PCA kernel is fitted on one category's training split; the 'all'
# setting falls back to 'bottle' for this demo.
current_category = config['dataset']['category_type']
if current_category == 'all':
    print("Config category is 'all'. Using 'bottle' for initial PCA training for demo purposes.")
pca_train_category = 'bottle' if current_category == 'all' else current_category

trainset = Mvtec(
    root_dir=config['dataset']['root_dir'],
    object_type=pca_train_category,
    split='train',
    im_size=config['dataset']['image_size'],
)

# Feature extractor used at inference, plus the feature shape derived from trainset.
partial_model, feature_shape = get_partial_model(model, trainset, config['model'])
# NOTE(review): training features are extracted with the TorchScript model_ts,
# while predict_anomaly runs partial_model directly; confirm both yield
# matching features.
model_ts = prepare_torchscript_model(partial_model, config)
train_features, _ = get_train_features(model_ts, trainset, feature_shape, config)
pca_kernel = get_PCA_kernel(train_features, config)

print("Model and PCA kernel loaded successfully.")

# --- Anomaly Detection Function for Gradio ---
def predict_anomaly(input_image: Image.Image, current_category_choice: str):
    """Score one uploaded image for anomalies and return a status message.

    The image is resized to ``config['dataset']['image_size']`` and normalized
    with the ImageNet mean/std constants. It is then passed through
    ``partial_model``, and the features at ``config['model']['layer']`` are
    average-pooled (when ``config['model']['pool'] > 1``) and flattened. The
    score is the negated squared PCA reconstruction error of those features,
    so lower (more negative) scores are more anomalous.

    Args:
        input_image: Image uploaded through the Gradio UI, or ``None`` when the
            user submits without an image.
        current_category_choice: Category selected in the dropdown. Accepted
            for the UI, but nothing in this function reads it. The PCA kernel
            is the single global one fitted at startup, so scoring does not
            currently depend on this choice.

    Returns:
        A ``(message, image)`` tuple: the status/score text and the input
        image echoed back for display (``None`` if no image was given).
    """
    if input_image is None:
        return "Please upload an image.", None

    # Ensure the model is in evaluation mode.
    model.eval()

    # Preprocessing intended to mirror the Mvtec dataset transforms.
    im_size = config['dataset']['image_size']
    transform = transforms.Compose([
        transforms.Resize((im_size, im_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transformed_image = transform(input_image.convert('RGB')).unsqueeze(0)  # add batch dimension

    use_bf16 = config['precision'] == 'bfloat16'
    pool = config['model']['pool']

    # Inference only: no_grad avoids building an autograd graph. The global
    # config is deliberately not modified here, so concurrent requests cannot
    # interfere with each other.
    with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=use_bf16):
        inputs = transformed_image.contiguous(memory_format=torch.channels_last)
        if use_bf16:
            inputs = inputs.to(torch.bfloat16)

        features = partial_model(inputs)[config['model']['layer']]
        pooled = torch.nn.functional.avg_pool2d(features, pool) if pool > 1 else features
        outputs = pooled.contiguous().view(pooled.size(0), -1)

        # Project onto the PCA subspace and back; the residual energy is the
        # feature reconstruction error (FRE).
        reconstructed = pca_kernel.inverse_transform(pca_kernel.transform(outputs))
        fre = torch.square(outputs - reconstructed).reshape(outputs.shape)
        score = -torch.sum(fre, dim=1).item()  # single scalar for the batch of one

    # Illustrative threshold only; it should be calibrated on training or
    # validation scores for the deployed model.
    anomaly_threshold = -100.0
    status = "Anomaly Detected!" if score < anomaly_threshold else "Normal"

    return f"Status: {status} | Anomaly Score: {score:.4f}", input_image


# Get available categories from the data directory.
DATA_ROOT_DIR = config['dataset']['root_dir']


def _list_category_dirs(root_dir):
    """Return the sorted names of category sub-directories of *root_dir*.

    Hidden entries (e.g. ``.ipynb_checkpoints``) and a top-level
    ``ground_truth`` directory are skipped, since they are not object
    categories.
    """
    return sorted(
        name for name in os.listdir(root_dir)
        if not name.startswith('.')
        and name != 'ground_truth'
        and os.path.isdir(os.path.join(root_dir, name))
    )


if os.path.isdir(DATA_ROOT_DIR):
    available_categories = _list_category_dirs(DATA_ROOT_DIR)
else:
    print(f"Warning: Data root directory '{DATA_ROOT_DIR}' not found. Falling back to default categories.")
    available_categories = ["bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather", "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper"]

if not available_categories:
    available_categories = ["bottle"]  # Final fallback if no categories found

# --- Gradio Interface ---
iface = gr.Interface(
    fn=predict_anomaly,
    inputs=[
        gr.Image(type="pil", label="Upload Image for Anomaly Detection"),
        gr.Dropdown(choices=available_categories, label="Select Category", value=available_categories[0] if available_categories else "bottle")
    ],
    outputs=[
        gr.Textbox(label="Anomaly Detection Result"),
        gr.Image(type="pil", label="Input Image")
    ],
    title="Visual Anomaly Detection (SimSiam + PCA)",
    description="Upload an image and select its category to detect anomalies using a pre-trained SimSiam model with PCA-based anomaly scoring. Note: The anomaly threshold is illustrative and may need tuning."
)

iface.launch()