Spaces:
Configuration error
Configuration error
Update app.py
Browse files
app.py
CHANGED
@@ -1,161 +1,153 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import matplotlib.pyplot as plt
|
3 |
-
from threading import Thread
|
4 |
-
from matplotlib.colors import ListedColormap
|
5 |
-
from sklearn.datasets import make_moons, make_circles, make_classification
|
6 |
-
from sklearn.datasets import make_blobs, make_circles, make_moons
|
7 |
import gradio as gr
|
8 |
-
import
|
9 |
-
|
10 |
-
import
|
11 |
-
|
12 |
-
import
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
|
|
82 |
|
83 |
-
plot_num = 1
|
84 |
-
rng = np.random.RandomState(42)
|
85 |
-
X = DATA_MAPPING[input_data]
|
86 |
-
X = np.concatenate([X, rng.uniform(low=-6, high=6, size=(n_outliers, 2))], axis=0)
|
87 |
-
|
88 |
-
t0 = time.time()
|
89 |
-
clf.fit(X)
|
90 |
-
t1 = time.time()
|
91 |
-
# fit the data and tag outliers
|
92 |
-
if clf_name == "Local Outlier Factor":
|
93 |
-
y_pred = clf.fit_predict(X)
|
94 |
-
else:
|
95 |
-
y_pred = clf.fit(X).predict(X)
|
96 |
-
|
97 |
-
# plot the levels lines and the points
|
98 |
-
if clf_name != "Local Outlier Factor":
|
99 |
-
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
|
100 |
-
Z = Z.reshape(xx.shape)
|
101 |
-
plt.contour(xx, yy, Z, levels=[0], linewidths=10, colors="black")
|
102 |
-
|
103 |
-
colors = np.array(["#377eb8", "#ff7f00"])
|
104 |
-
plt.scatter(X[:, 0], X[:, 1], s=100, color=colors[(y_pred + 1) // 2])
|
105 |
-
|
106 |
-
plt.xlim(-7, 7)
|
107 |
-
plt.ylim(-7, 7)
|
108 |
-
plt.xticks(())
|
109 |
-
plt.yticks(())
|
110 |
-
plt.text(
|
111 |
-
0.99,
|
112 |
-
0.01,
|
113 |
-
("%.2fs" % (t1 - t0)).lstrip("0"),
|
114 |
-
transform=plt.gca().transAxes,
|
115 |
-
size=60,
|
116 |
-
horizontalalignment="right",
|
117 |
-
)
|
118 |
-
plot_num += 1
|
119 |
-
|
120 |
-
return plt
|
121 |
-
|
122 |
-
description = "Learn how different anomaly detection algorithms perform in different datasets."
|
123 |
-
|
124 |
-
def iter_grid(n_rows, n_cols):
|
125 |
-
# create a grid using gradio Block
|
126 |
-
for _ in range(n_rows):
|
127 |
-
with gr.Row():
|
128 |
-
for _ in range(n_cols):
|
129 |
-
with gr.Column():
|
130 |
-
yield
|
131 |
-
|
132 |
-
title = "🕵️♀️ compare anomaly detection algorithms 🕵️♂️"
|
133 |
-
with gr.Blocks() as demo:
|
134 |
-
gr.Markdown(f"## {title}")
|
135 |
-
gr.Markdown(description)
|
136 |
-
|
137 |
-
input_models = ["Robust covariance","One-Class SVM","One-Class SVM (SGD)","Isolation Forest",
|
138 |
-
"Local Outlier Factor"]
|
139 |
-
input_data = gr.Radio(
|
140 |
-
choices=["Central Blob", "Two Blobs", "Blob with Noise", "Moons", "Noise"],
|
141 |
-
value="Moons"
|
142 |
-
)
|
143 |
-
n_samples = gr.Slider(minimum=100, maximum=500, step=25, label="Number of Samples")
|
144 |
-
outliers_fraction = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, label="Fraction of Outliers")
|
145 |
-
counter = 0
|
146 |
-
|
147 |
-
|
148 |
-
for _ in iter_grid(5, 5):
|
149 |
-
if counter >= len(input_models):
|
150 |
-
break
|
151 |
-
|
152 |
-
input_model = input_models[counter]
|
153 |
-
plot = gr.Plot(label=input_model)
|
154 |
-
fn = partial(train_models, clf_name=input_model)
|
155 |
-
input_data.change(fn=fn, inputs=[input_data, outliers_fraction, n_samples], outputs=plot)
|
156 |
-
n_samples.change(fn=fn, inputs=[input_data, outliers_fraction, n_samples], outputs=plot)
|
157 |
-
outliers_fraction.change(fn=fn, inputs=[input_data, outliers_fraction, n_samples], outputs=plot)
|
158 |
-
counter += 1
|
159 |
-
|
160 |
-
demo.launch(enable_queue=True, debug=True)
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import yaml
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
import torchvision.transforms as transforms # Added this import
|
9 |
+
|
10 |
+
# Add the project root to sys.path to allow imports from sibling directories
|
11 |
+
# Assuming app.py is in the root of the space, and visual-quality-inspection is a subdirectory
|
12 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'visual-quality-inspection')))
|
13 |
+
|
14 |
+
# Import your core anomaly detection functions
|
15 |
+
# Make sure these imports work relative to the sys.path adjustments
|
16 |
+
from visual_quality_inspection.anomaly_detection import load_custom_model, prepare_torchscript_model, inference_score, get_PCA_kernel, get_partial_model, get_train_features # Added get_partial_model, get_train_features
|
17 |
+
from visual_quality_inspection.dataset import Mvtec # Your custom dataset class
|
18 |
+
|
19 |
+
# --- Configuration Loading ---
|
20 |
+
# Define the path to your eval.yaml within the Space
|
21 |
+
CONFIG_FILE_PATH = 'visual-quality-inspection/configs/eval.yaml'
|
22 |
+
# Define the path where your model is located within the Space
|
23 |
+
MODEL_OUTPUT_PATH = 'visual-quality-inspection/models' # This should point to the 'models' directory you create
|
24 |
+
|
25 |
+
# Load config once at startup
|
26 |
+
with open(CONFIG_FILE_PATH, "r") as f:
|
27 |
+
config = yaml.safe_load(f)
|
28 |
+
|
29 |
+
# --- Global Model and PCA Kernel Loading (run once when the app starts) ---
|
30 |
+
# This ensures the model is loaded only once, not on every inference call.
|
31 |
+
print("Loading model and preparing PCA kernel...")
|
32 |
+
|
33 |
+
# Ensure the correct feature_extractor and category_type are set in config
|
34 |
+
# This assumes you've pre-modified eval.yaml or you set them here programmatically
|
35 |
+
# For this example, let's assume eval.yaml is already set to 'simsiam' and 'bottle' or 'all'
|
36 |
+
# If you need to override:
|
37 |
+
# config['model']['feature_extractor'] = 'simsiam'
|
38 |
+
# config['dataset']['category_type'] = 'bottle' # Or 'all' if you want to iterate
|
39 |
+
|
40 |
+
# Load the pre-trained model
|
41 |
+
model = load_custom_model(MODEL_OUTPUT_PATH, config)
|
42 |
+
if model is None:
|
43 |
+
raise RuntimeError("Failed to load the custom model. Check model path and file integrity.")
|
44 |
+
|
45 |
+
# Prepare a dummy dataset for feature shape inference and PCA training
|
46 |
+
current_category = config['dataset']['category_type']
|
47 |
+
if current_category == 'all':
|
48 |
+
print("Config category is 'all'. Using 'bottle' for initial PCA training for demo purposes.")
|
49 |
+
pca_train_category = 'bottle'
|
50 |
+
else:
|
51 |
+
pca_train_category = current_category
|
52 |
+
|
53 |
+
trainset = Mvtec(
|
54 |
+
root_dir=config['dataset']['root_dir'],
|
55 |
+
object_type=pca_train_category,
|
56 |
+
split='train',
|
57 |
+
im_size=config['dataset']['image_size']
|
58 |
+
)
|
59 |
+
|
60 |
+
partial_model, feature_shape = get_partial_model(model, trainset, config['model'])
|
61 |
+
model_ts = prepare_torchscript_model(partial_model, config)
|
62 |
+
train_features, _ = get_train_features(model_ts, trainset, feature_shape, config)
|
63 |
+
pca_kernel = get_PCA_kernel(train_features, config)
|
64 |
+
|
65 |
+
print("Model and PCA kernel loaded successfully.")
|
66 |
+
|
67 |
+
# --- Anomaly Detection Function for Gradio ---
|
68 |
+
def predict_anomaly(input_image: Image.Image, current_category_choice: str):
|
69 |
+
\"\"\"
|
70 |
+
Performs anomaly detection on a single input image for a chosen category.
|
71 |
+
\"\"\"
|
72 |
+
# Ensure the model is in evaluation mode
|
73 |
+
model.eval()
|
74 |
+
|
75 |
+
# Apply the same transformations as defined in Mvtec
|
76 |
+
im_size = config['dataset']['image_size']
|
77 |
+
transform = transforms.Compose([
|
78 |
+
transforms.Resize((im_size, im_size)),
|
79 |
+
transforms.ToTensor(),
|
80 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
81 |
+
])
|
82 |
|
83 |
+
transformed_image = transform(input_image.convert('RGB')).unsqueeze(0) # Add batch dimension
|
84 |
+
|
85 |
+
# Dynamically update category in config for inference if 'all' is chosen or new category
|
86 |
+
# Note: This config change is local to this function call and won't affect global `config`
|
87 |
+
# for subsequent calls, which is fine for Gradio's stateless nature per call.
|
88 |
+
original_category_config = config['dataset']['category_type'] # Store original
|
89 |
+
config['dataset']['category_type'] = current_category_choice # Use user's choice for this inference
|
90 |
+
|
91 |
+
with torch.cpu.amp.autocast(enabled=config['precision']=='bfloat16'):
|
92 |
+
inputs = transformed_image.contiguous(memory_format=torch.channels_last)
|
93 |
+
if config['precision'] == 'bfloat16':
|
94 |
+
inputs = inputs.to(torch.bfloat16)
|
95 |
+
|
96 |
+
features = partial_model(inputs)[config['model']['layer']]
|
97 |
+
pool_out = torch.nn.functional.avg_pool2d(features, config['model']['pool']) if config['model']['pool'] > 1 else features
|
98 |
+
outputs = pool_out.contiguous().view(pool_out.size(0), -1)
|
99 |
+
|
100 |
+
oi = outputs
|
101 |
+
oi_or = oi
|
102 |
+
oi_j = pca_kernel.transform(oi)
|
103 |
+
oi_reconstructed = pca_kernel.inverse_transform(oi_j)
|
104 |
+
fre = torch.square(oi_or - oi_reconstructed).reshape(outputs.shape)
|
105 |
+
fre_score = torch.sum(fre, dim=1)
|
106 |
+
score = -fre_score.item() # Get the single scalar score
|
107 |
+
|
108 |
+
# Revert category_type in config if it was changed (good practice, though not strictly needed for Gradio)
|
109 |
+
config['dataset']['category_type'] = original_category_config
|
110 |
+
|
111 |
+
# Simple anomaly threshold for display
|
112 |
+
# You might want to get a threshold from your eval.yaml or a pre-computed one
|
113 |
+
# For now, a simple rule: if score is very low (highly negative), it's anomalous.
|
114 |
+
# This threshold is illustrative and should be determined from training/validation.
|
115 |
+
ANOMALY_THRESHOLD = -100.0 # Example threshold, adjust based on your model's score range
|
116 |
+
|
117 |
+
status = "Anomaly Detected!" if score < ANOMALY_THRESHOLD else "Normal"
|
118 |
|
119 |
+
return f"Status: {status} | Anomaly Score: {score:.4f}", input_image
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
+
# Get available categories from the data directory
|
123 |
+
DATA_ROOT_DIR = config['dataset']['root_dir']
|
124 |
+
# Ensure DATA_ROOT_DIR exists before listing
|
125 |
+
if not os.path.isdir(DATA_ROOT_DIR):
|
126 |
+
print(f"Warning: Data root directory '{DATA_ROOT_DIR}' not found. Falling back to default categories.")
|
127 |
+
available_categories = ["bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather", "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper"]
|
128 |
+
else:
|
129 |
+
available_categories = [
|
130 |
+
os.path.basename(d) for d in os.listdir(DATA_ROOT_DIR)
|
131 |
+
if os.path.isdir(os.path.join(DATA_ROOT_DIR, d)) and d not in ['ground_truth'] # Exclude ground_truth if it's a top-level dir
|
132 |
+
]
|
133 |
+
available_categories.sort()
|
134 |
+
|
135 |
+
if not available_categories:
|
136 |
+
available_categories = ["bottle"] # Final fallback if no categories found
|
137 |
+
|
138 |
+
# --- Gradio Interface ---
|
139 |
+
iface = gr.Interface(
|
140 |
+
fn=predict_anomaly,
|
141 |
+
inputs=[
|
142 |
+
gr.Image(type="pil", label="Upload Image for Anomaly Detection"),
|
143 |
+
gr.Dropdown(choices=available_categories, label="Select Category", value=available_categories[0] if available_categories else "bottle")
|
144 |
+
],
|
145 |
+
outputs=[
|
146 |
+
gr.Textbox(label="Anomaly Detection Result"),
|
147 |
+
gr.Image(type="pil", label="Input Image")
|
148 |
+
],
|
149 |
+
title="Visual Anomaly Detection (SimSiam + PCA)",
|
150 |
+
description="Upload an image and select its category to detect anomalies using a pre-trained SimSiam model with PCA-based anomaly scoring. Note: The anomaly threshold is illustrative and may need tuning."
|
151 |
+
)
|
152 |
+
|
153 |
+
iface.launch()
|