JuanJoseMV's picture
add methods for each strategy
9485251
import pandas as pd
from .outbreak_detection import (
LSTMforOutbreakDetection,
ARIMAforOutbreakDetection,
IQRforOutbreakDetection
)
from .plotting.visualization import plot_anomalies
from .utils import prepare_time_series_dataframe
# Human-readable labels of the supported thresholding strategies, mapped to
# integer codes. The code of each label is its position in the tuple below,
# so the order of the entries is significant: append new strategies at the
# end rather than inserting them in the middle.
THRESHOLD_METHODS = {
    label: code
    for code, label in enumerate(
        (
            "IQR on (ground truth - forecast)",
            "IQR on |ground truth - forecast|",
            "IQR on |ground truth - forecast|/forecast",
            "Percentile threshold on absolute loss",
            "Percentile threshold on raw loss",
        )
    )
}
def detect_anomalies(file_path: str, method: str, k: int, percentile: float, threshold_method: str):
    """
    Detect anomalies in a time series CSV and return a plot of the result.

    The CSV is loaded with pandas, passed through
    ``prepare_time_series_dataframe``, handed to the detector selected by
    ``method``, and the detector's output is rendered with ``plot_anomalies``.

    Args:
        file_path (str): Path to the CSV file containing the time series data.
        method (str): Detection method to use: 'LSTM', 'ARIMA' or 'IQR'.
        k (int): Forwarded unchanged as ``k`` to the selected detector; its
            exact meaning is defined by that detector.
        percentile (float): Forwarded to the LSTM detector only; ignored for
            'ARIMA' and 'IQR'.
        threshold_method (str): One of the labels (keys) of
            ``THRESHOLD_METHODS``. It is translated to its integer code and
            forwarded to the LSTM detector only; ignored for 'ARIMA' and
            'IQR'.

    Returns:
        The value returned by ``plot_anomalies`` for the detector's output
        (presumably a plotly figure -- confirm against ``plot_anomalies``).

    Raises:
        KeyError: If ``method`` is not one of the supported names, or if
            ``method`` is 'LSTM' and ``threshold_method`` is not a key of
            ``THRESHOLD_METHODS``.
    """
    df = pd.read_csv(file_path)
    df = prepare_time_series_dataframe(df)

    # Zero-argument factories rather than ready-made instances: only the
    # detector that was actually requested gets constructed. Building all of
    # them up front would instantiate the LSTM detector (with its checkpoint
    # path) and require a valid ``threshold_method`` even for 'ARIMA'/'IQR'.
    detector_factories = {
        'LSTM': lambda: LSTMforOutbreakDetection(
            checkpoint_path='models/lstm_forec_40_11_06.pth',
            k=k,
            percentile=percentile,
            # Translate the human-readable label into its integer code.
            threshold_method=THRESHOLD_METHODS[threshold_method],
        ),
        'ARIMA': lambda: ARIMAforOutbreakDetection(k=k),
        'IQR': lambda: IQRforOutbreakDetection(k=k),
    }

    # An unknown method name raises KeyError here, as it always has.
    build_detector = detector_factories[method]
    detector = build_detector()

    # The detector returns the processed data together with the name of the
    # column that holds the anomaly labels.
    test, new_label = detector.detect_anomalies(df)
    return plot_anomalies(test, anomaly_col=new_label)