File size: 1,856 Bytes
9485251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import pandas as pd
from .outbreak_detection import (
    LSTMforOutbreakDetection,
    ARIMAforOutbreakDetection,
    IQRforOutbreakDetection
)
from .plotting.visualization import plot_anomalies
from .utils import prepare_time_series_dataframe


# Human-readable threshold-method label -> integer code. detect_anomalies
# looks a label up here and hands the resulting code to the LSTM detector
# as its ``threshold_method`` argument.
# Each label's code is its position in the tuple below, so the ORDER of the
# labels is significant: append new labels at the end, never reorder.
THRESHOLD_METHODS = {
    label: code
    for code, label in enumerate((
        "IQR on (ground truth - forecast)",
        "IQR on |ground truth - forecast|",
        "IQR on |ground truth - forecast|/forecast",
        "Percentile threshold on absolute loss",
        "Percentile threshold on raw loss",
    ))
}

def _resolve_threshold_method(threshold_method):
    """Translate a threshold-method label or integer code into its code."""
    # Descriptive label (a key of THRESHOLD_METHODS) -> its integer code.
    if threshold_method in THRESHOLD_METHODS:
        return THRESHOLD_METHODS[threshold_method]
    # Already one of the integer codes: pass it through unchanged. bool is
    # excluded explicitly because True == 1 and False == 0 in Python.
    if (
        not isinstance(threshold_method, (str, bool))
        and threshold_method in THRESHOLD_METHODS.values()
    ):
        return int(threshold_method)
    raise KeyError(
        f"Unknown threshold method {threshold_method!r}; expected one of "
        f"{list(THRESHOLD_METHODS)} or a code in "
        f"{sorted(THRESHOLD_METHODS.values())}"
    )


def detect_anomalies(
    file_path: str,
    method: str,
    k: int,
    percentile: float,
    threshold_method: "int | str",
):
    """
    Run one anomaly detector over a CSV time series and plot the result.

    The CSV is loaded with ``pd.read_csv``, passed through
    ``prepare_time_series_dataframe``, and handed to the detector selected
    by ``method``. Only the selected detector is instantiated.

    Args:
        file_path (str): Path to the CSV file containing time series data
        method (str): Detection method to use ('LSTM', 'ARIMA', or 'IQR')
        k (int): Method-dependent parameter, forwarded to the chosen
            detector as ``k``
        percentile (float): Forwarded to the LSTM detector as ``percentile``;
            not used by 'ARIMA' or 'IQR'
        threshold_method (int | str): Either a descriptive label (a key of
            ``THRESHOLD_METHODS``) or the corresponding integer code. Only
            consulted when ``method`` is 'LSTM'.
    Returns:
        The value returned by ``plot_anomalies`` for the detector output
        (NOTE(review): presumably a plotly.graph_objects.Figure, as the
        previous docstring stated - confirm against plot_anomalies).
    Raises:
        KeyError: If ``method`` is not a known detection method, or if
            ``method`` is 'LSTM' and ``threshold_method`` is neither a known
            label nor a known code.
    """
    # Factories rather than instances: building every detector up front
    # would construct the LSTM detector (and resolve threshold_method) even
    # when a different method was requested.
    detector_factories = {
        'LSTM': lambda: LSTMforOutbreakDetection(
            checkpoint_path='models/lstm_forec_40_11_06.pth',
            k=k,
            percentile=percentile,
            threshold_method=_resolve_threshold_method(threshold_method),
        ),
        'ARIMA': lambda: ARIMAforOutbreakDetection(k=k),
        'IQR': lambda: IQRforOutbreakDetection(k=k),
    }

    # Reject an unknown method before doing any file I/O. KeyError is kept
    # as the exception type so existing callers that catch it still work.
    try:
        build_detector = detector_factories[method]
    except KeyError:
        raise KeyError(
            f"Unknown detection method {method!r}; expected one of "
            f"{sorted(detector_factories)}"
        ) from None

    df = pd.read_csv(file_path)
    df = prepare_time_series_dataframe(df)

    detector = build_detector()
    # The detector returns a pair; the second element is passed to
    # plot_anomalies as the anomaly column.
    test, new_label = detector.detect_anomalies(df)
    return plot_anomalies(test, anomaly_col=new_label)