ParamDev commited on
Commit
56f90b5
·
verified ·
1 Parent(s): ee6aae3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -156
app.py CHANGED
@@ -1,161 +1,153 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
- from threading import Thread
4
- from matplotlib.colors import ListedColormap
5
- from sklearn.datasets import make_moons, make_circles, make_classification
6
- from sklearn.datasets import make_blobs, make_circles, make_moons
7
  import gradio as gr
8
- import math
9
- from functools import partial
10
- import time
11
-
12
- import matplotlib
13
-
14
- from sklearn import svm
15
- from sklearn.datasets import make_moons, make_blobs
16
- from sklearn.covariance import EllipticEnvelope
17
- from sklearn.ensemble import IsolationForest
18
- from sklearn.neighbors import LocalOutlierFactor
19
- from sklearn.linear_model import SGDOneClassSVM
20
- from sklearn.kernel_approximation import Nystroem
21
- from sklearn.pipeline import make_pipeline
22
-
23
- def get_groundtruth_model(X, labels):
24
- # dummy model to show true label distribution
25
- class Dummy:
26
- def __init__(self, y):
27
- self.labels_ = labels
28
-
29
- return Dummy(labels)
30
-
31
- #### PLOT
32
- FIGSIZE = 10,10
33
- figure = plt.figure(figsize=(25, 10))
34
-
35
-
36
- def train_models(input_data, outliers_fraction, n_samples, clf_name):
37
- n_outliers = int(outliers_fraction * n_samples)
38
- n_inliers = n_samples - n_outliers
39
- blobs_params = dict(random_state=0, n_samples=n_inliers, n_features=2)
40
- NAME_CLF_MAPPING = {"Robust covariance": EllipticEnvelope(contamination=outliers_fraction),
41
- "One-Class SVM": svm.OneClassSVM(nu=outliers_fraction, kernel="rbf", gamma=0.1),
42
- "One-Class SVM (SGD)":make_pipeline(
43
- Nystroem(gamma=0.1, random_state=42, n_components=150),
44
- SGDOneClassSVM(
45
- nu=outliers_fraction,
46
- shuffle=True,
47
- fit_intercept=True,
48
- random_state=42,
49
- tol=1e-6,
50
- ),
51
- ),
52
- "Isolation Forest": IsolationForest(contamination=outliers_fraction, random_state=42),
53
- "Local Outlier Factor": LocalOutlierFactor(n_neighbors=35, contamination=outliers_fraction),
54
- }
55
- DATA_MAPPING = {
56
- "Central Blob":make_blobs(centers=[[0, 0], [0, 0]], cluster_std=0.5, **blobs_params)[0],
57
- "Two Blobs": make_blobs(centers=[[2, 2], [-2, -2]], cluster_std=[0.5, 0.5], **blobs_params)[0],
58
- "Blob with Noise": make_blobs(centers=[[2, 2], [-2, -2]], cluster_std=[1.5, 0.3], **blobs_params)[0],
59
- "Moons": 4.0
60
- * (
61
- make_moons(n_samples=n_samples, noise=0.05, random_state=0)[0]
62
- - np.array([0.5, 0.25])
63
- ),
64
- "Noise": 14.0 * (np.random.RandomState(42).rand(n_samples, 2) - 0.5),
65
- }
66
- DATASETS = [
67
- make_blobs(centers=[[0, 0], [0, 0]], cluster_std=0.5, **blobs_params)[0],
68
- make_blobs(centers=[[2, 2], [-2, -2]], cluster_std=[0.5, 0.5], **blobs_params)[0],
69
- make_blobs(centers=[[2, 2], [-2, -2]], cluster_std=[1.5, 0.3], **blobs_params)[0],
70
- 4.0
71
- * (
72
- make_moons(n_samples=n_samples, noise=0.05, random_state=0)[0]
73
- - np.array([0.5, 0.25])
74
- ),
75
- 14.0 * (np.random.RandomState(42).rand(n_samples, 2) - 0.5),
76
- ]
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- xx, yy = np.meshgrid(np.linspace(-7, 7, 150), np.linspace(-7, 7, 150))
79
- clf = NAME_CLF_MAPPING[clf_name]
80
- plt.figure(figsize=(len(NAME_CLF_MAPPING) * 2 + 4, 12.5))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
 
82
 
83
- plot_num = 1
84
- rng = np.random.RandomState(42)
85
- X = DATA_MAPPING[input_data]
86
- X = np.concatenate([X, rng.uniform(low=-6, high=6, size=(n_outliers, 2))], axis=0)
87
-
88
- t0 = time.time()
89
- clf.fit(X)
90
- t1 = time.time()
91
- # fit the data and tag outliers
92
- if clf_name == "Local Outlier Factor":
93
- y_pred = clf.fit_predict(X)
94
- else:
95
- y_pred = clf.fit(X).predict(X)
96
-
97
- # plot the levels lines and the points
98
- if clf_name != "Local Outlier Factor":
99
- Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
100
- Z = Z.reshape(xx.shape)
101
- plt.contour(xx, yy, Z, levels=[0], linewidths=10, colors="black")
102
-
103
- colors = np.array(["#377eb8", "#ff7f00"])
104
- plt.scatter(X[:, 0], X[:, 1], s=100, color=colors[(y_pred + 1) // 2])
105
-
106
- plt.xlim(-7, 7)
107
- plt.ylim(-7, 7)
108
- plt.xticks(())
109
- plt.yticks(())
110
- plt.text(
111
- 0.99,
112
- 0.01,
113
- ("%.2fs" % (t1 - t0)).lstrip("0"),
114
- transform=plt.gca().transAxes,
115
- size=60,
116
- horizontalalignment="right",
117
- )
118
- plot_num += 1
119
-
120
- return plt
121
-
122
- description = "Learn how different anomaly detection algorithms perform in different datasets."
123
-
124
- def iter_grid(n_rows, n_cols):
125
- # create a grid using gradio Block
126
- for _ in range(n_rows):
127
- with gr.Row():
128
- for _ in range(n_cols):
129
- with gr.Column():
130
- yield
131
-
132
- title = "🕵️‍♀️ compare anomaly detection algorithms 🕵️‍♂️"
133
- with gr.Blocks() as demo:
134
- gr.Markdown(f"## {title}")
135
- gr.Markdown(description)
136
-
137
- input_models = ["Robust covariance","One-Class SVM","One-Class SVM (SGD)","Isolation Forest",
138
- "Local Outlier Factor"]
139
- input_data = gr.Radio(
140
- choices=["Central Blob", "Two Blobs", "Blob with Noise", "Moons", "Noise"],
141
- value="Moons"
142
- )
143
- n_samples = gr.Slider(minimum=100, maximum=500, step=25, label="Number of Samples")
144
- outliers_fraction = gr.Slider(minimum=0.1, maximum=0.9, step=0.1, label="Fraction of Outliers")
145
- counter = 0
146
-
147
-
148
- for _ in iter_grid(5, 5):
149
- if counter >= len(input_models):
150
- break
151
-
152
- input_model = input_models[counter]
153
- plot = gr.Plot(label=input_model)
154
- fn = partial(train_models, clf_name=input_model)
155
- input_data.change(fn=fn, inputs=[input_data, outliers_fraction, n_samples], outputs=plot)
156
- n_samples.change(fn=fn, inputs=[input_data, outliers_fraction, n_samples], outputs=plot)
157
- outliers_fraction.change(fn=fn, inputs=[input_data, outliers_fraction, n_samples], outputs=plot)
158
- counter += 1
159
-
160
- demo.launch(enable_queue=True, debug=True)
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ import sys
4
+ import yaml
5
+ import torch
6
+ from PIL import Image
7
+ import numpy as np
8
+ import torchvision.transforms as transforms # Added this import
9
+
10
+ # Add the project root to sys.path to allow imports from sibling directories
11
+ # Assuming app.py is in the root of the space, and visual-quality-inspection is a subdirectory
12
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'visual-quality-inspection')))
13
+
14
+ # Import your core anomaly detection functions
15
+ # Make sure these imports work relative to the sys.path adjustments
16
+ from visual_quality_inspection.anomaly_detection import load_custom_model, prepare_torchscript_model, inference_score, get_PCA_kernel, get_partial_model, get_train_features # Added get_partial_model, get_train_features
17
+ from visual_quality_inspection.dataset import Mvtec # Your custom dataset class
18
+
19
+ # --- Configuration Loading ---
20
+ # Define the path to your eval.yaml within the Space
21
+ CONFIG_FILE_PATH = 'visual-quality-inspection/configs/eval.yaml'
22
+ # Define the path where your model is located within the Space
23
+ MODEL_OUTPUT_PATH = 'visual-quality-inspection/models' # This should point to the 'models' directory you create
24
+
25
+ # Load config once at startup
26
+ with open(CONFIG_FILE_PATH, "r") as f:
27
+ config = yaml.safe_load(f)
28
+
29
+ # --- Global Model and PCA Kernel Loading (run once when the app starts) ---
30
+ # This ensures the model is loaded only once, not on every inference call.
31
+ print("Loading model and preparing PCA kernel...")
32
+
33
+ # Ensure the correct feature_extractor and category_type are set in config
34
+ # This assumes you've pre-modified eval.yaml or you set them here programmatically
35
+ # For this example, let's assume eval.yaml is already set to 'simsiam' and 'bottle' or 'all'
36
+ # If you need to override:
37
+ # config['model']['feature_extractor'] = 'simsiam'
38
+ # config['dataset']['category_type'] = 'bottle' # Or 'all' if you want to iterate
39
+
40
+ # Load the pre-trained model
41
+ model = load_custom_model(MODEL_OUTPUT_PATH, config)
42
+ if model is None:
43
+ raise RuntimeError("Failed to load the custom model. Check model path and file integrity.")
44
+
45
+ # Prepare a dummy dataset for feature shape inference and PCA training
46
+ current_category = config['dataset']['category_type']
47
+ if current_category == 'all':
48
+ print("Config category is 'all'. Using 'bottle' for initial PCA training for demo purposes.")
49
+ pca_train_category = 'bottle'
50
+ else:
51
+ pca_train_category = current_category
52
+
53
+ trainset = Mvtec(
54
+ root_dir=config['dataset']['root_dir'],
55
+ object_type=pca_train_category,
56
+ split='train',
57
+ im_size=config['dataset']['image_size']
58
+ )
59
+
60
+ partial_model, feature_shape = get_partial_model(model, trainset, config['model'])
61
+ model_ts = prepare_torchscript_model(partial_model, config)
62
+ train_features, _ = get_train_features(model_ts, trainset, feature_shape, config)
63
+ pca_kernel = get_PCA_kernel(train_features, config)
64
+
65
+ print("Model and PCA kernel loaded successfully.")
66
+
67
+ # --- Anomaly Detection Function for Gradio ---
68
+ def predict_anomaly(input_image: Image.Image, current_category_choice: str):
69
+ \"\"\"
70
+ Performs anomaly detection on a single input image for a chosen category.
71
+ \"\"\"
72
+ # Ensure the model is in evaluation mode
73
+ model.eval()
74
+
75
+ # Apply the same transformations as defined in Mvtec
76
+ im_size = config['dataset']['image_size']
77
+ transform = transforms.Compose([
78
+ transforms.Resize((im_size, im_size)),
79
+ transforms.ToTensor(),
80
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
81
+ ])
82
 
83
+ transformed_image = transform(input_image.convert('RGB')).unsqueeze(0) # Add batch dimension
84
+
85
+ # Dynamically update category in config for inference if 'all' is chosen or new category
86
+ # Note: This config change is local to this function call and won't affect global `config`
87
+ # for subsequent calls, which is fine for Gradio's stateless nature per call.
88
+ original_category_config = config['dataset']['category_type'] # Store original
89
+ config['dataset']['category_type'] = current_category_choice # Use user's choice for this inference
90
+
91
+ with torch.cpu.amp.autocast(enabled=config['precision']=='bfloat16'):
92
+ inputs = transformed_image.contiguous(memory_format=torch.channels_last)
93
+ if config['precision'] == 'bfloat16':
94
+ inputs = inputs.to(torch.bfloat16)
95
+
96
+ features = partial_model(inputs)[config['model']['layer']]
97
+ pool_out = torch.nn.functional.avg_pool2d(features, config['model']['pool']) if config['model']['pool'] > 1 else features
98
+ outputs = pool_out.contiguous().view(pool_out.size(0), -1)
99
+
100
+ oi = outputs
101
+ oi_or = oi
102
+ oi_j = pca_kernel.transform(oi)
103
+ oi_reconstructed = pca_kernel.inverse_transform(oi_j)
104
+ fre = torch.square(oi_or - oi_reconstructed).reshape(outputs.shape)
105
+ fre_score = torch.sum(fre, dim=1)
106
+ score = -fre_score.item() # Get the single scalar score
107
+
108
+ # Revert category_type in config if it was changed (good practice, though not strictly needed for Gradio)
109
+ config['dataset']['category_type'] = original_category_config
110
+
111
+ # Simple anomaly threshold for display
112
+ # You might want to get a threshold from your eval.yaml or a pre-computed one
113
+ # For now, a simple rule: if score is very low (highly negative), it's anomalous.
114
+ # This threshold is illustrative and should be determined from training/validation.
115
+ ANOMALY_THRESHOLD = -100.0 # Example threshold, adjust based on your model's score range
116
+
117
+ status = "Anomaly Detected!" if score < ANOMALY_THRESHOLD else "Normal"
118
 
119
+ return f"Status: {status} | Anomaly Score: {score:.4f}", input_image
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ # Get available categories from the data directory
123
+ DATA_ROOT_DIR = config['dataset']['root_dir']
124
+ # Ensure DATA_ROOT_DIR exists before listing
125
+ if not os.path.isdir(DATA_ROOT_DIR):
126
+ print(f"Warning: Data root directory '{DATA_ROOT_DIR}' not found. Falling back to default categories.")
127
+ available_categories = ["bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather", "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper"]
128
+ else:
129
+ available_categories = [
130
+ os.path.basename(d) for d in os.listdir(DATA_ROOT_DIR)
131
+ if os.path.isdir(os.path.join(DATA_ROOT_DIR, d)) and d not in ['ground_truth'] # Exclude ground_truth if it's a top-level dir
132
+ ]
133
+ available_categories.sort()
134
+
135
+ if not available_categories:
136
+ available_categories = ["bottle"] # Final fallback if no categories found
137
+
138
+ # --- Gradio Interface ---
139
+ iface = gr.Interface(
140
+ fn=predict_anomaly,
141
+ inputs=[
142
+ gr.Image(type="pil", label="Upload Image for Anomaly Detection"),
143
+ gr.Dropdown(choices=available_categories, label="Select Category", value=available_categories[0] if available_categories else "bottle")
144
+ ],
145
+ outputs=[
146
+ gr.Textbox(label="Anomaly Detection Result"),
147
+ gr.Image(type="pil", label="Input Image")
148
+ ],
149
+ title="Visual Anomaly Detection (SimSiam + PCA)",
150
+ description="Upload an image and select its category to detect anomalies using a pre-trained SimSiam model with PCA-based anomaly scoring. Note: The anomaly threshold is illustrative and may need tuning."
151
+ )
152
+
153
+ iface.launch()