|
import pandas as pd |
|
from .outbreak_detection import ( |
|
LSTMforOutbreakDetection, |
|
ARIMAforOutbreakDetection, |
|
IQRforOutbreakDetection |
|
) |
|
from .plotting.visualization import plot_anomalies |
|
from .utils import prepare_time_series_dataframe |
|
|
|
|
|
# Maps the human-readable name of each thresholding strategy to the integer
# code that is forwarded to LSTMforOutbreakDetection as ``threshold_method``.
# The label strings are looked up at runtime, so they must not be reworded.
THRESHOLD_METHODS = {
    "IQR on (ground truth - forecast)": 0,
    "IQR on |ground truth - forecast|": 1,
    "IQR on |ground truth - forecast|/forecast": 2,
    "Percentile threshold on absolute loss": 3,
    "Percentile threshold on raw loss": 4,
}
|
|
|
def detect_anomalies(file_path: str, method: str, k: int, percentile: float, threshold_method: "str | int"):
    """Detect anomalies in a CSV time series and return a plot of the result.

    The CSV is read with pandas, passed through
    ``prepare_time_series_dataframe``, and handed to the detector selected by
    ``method``. Only the requested detector is instantiated, so choosing
    'ARIMA' or 'IQR' never constructs the LSTM detector (which is built from
    a checkpoint path).

    Args:
        file_path (str): Path to the CSV file containing time series data.
        method (str): Detection method to use: 'LSTM', 'ARIMA', or 'IQR'.
        k (int): Forwarded to the selected detector as ``k``
            (method-dependent parameter).
        percentile (float): Percentile threshold; only used by the 'LSTM'
            detector.
        threshold_method (str | int): Thresholding strategy for the 'LSTM'
            detector. Either one of the label keys of ``THRESHOLD_METHODS``
            or one of its integer codes. Ignored for other methods.

    Returns:
        The value returned by ``plot_anomalies`` for the detector output
        (documented upstream as a ``plotly.graph_objects.Figure`` showing the
        time series with highlighted anomalies).

    Raises:
        KeyError: If ``method`` is not a supported detector name, or if
            ``method`` is 'LSTM' and ``threshold_method`` is neither a known
            label nor a known code.
    """
    df = pd.read_csv(file_path)
    df = prepare_time_series_dataframe(df)

    def _build_lstm():
        # Resolve the thresholding strategy to its integer code. Labels are
        # the historical input; codes are accepted as well because the
        # parameter has always been documented as an int.
        if threshold_method in THRESHOLD_METHODS:
            threshold_code = THRESHOLD_METHODS[threshold_method]
        elif threshold_method in THRESHOLD_METHODS.values():
            threshold_code = threshold_method
        else:
            raise KeyError(
                f"Unknown threshold_method {threshold_method!r}; expected one of "
                f"{list(THRESHOLD_METHODS)} or a code in {list(THRESHOLD_METHODS.values())}"
            )
        return LSTMforOutbreakDetection(
            checkpoint_path='models/lstm_forec_40_11_06.pth',
            k=k,
            percentile=percentile,
            threshold_method=threshold_code,
        )

    # Factories instead of instances: building every detector up front would
    # make each call pay for (and depend on) detectors it does not use.
    detector_factories = {
        'LSTM': _build_lstm,
        'ARIMA': lambda: ARIMAforOutbreakDetection(k=k),
        'IQR': lambda: IQRforOutbreakDetection(k=k),
    }

    # An unsupported method name raises KeyError, as it always has.
    detector = detector_factories[method]()
    test, new_label = detector.detect_anomalies(df)
    return plot_anomalies(test, anomaly_col=new_label)
|
|
|
|
|
|