JuanJoseMV commited on
Commit
9485251
·
1 Parent(s): 0d5a7ab

add methods for each strategy

Browse files
.gitignore CHANGED
@@ -16,4 +16,7 @@ build/
16
 
17
  # VSCode
18
  .vscode/
19
- *.code-workspace
 
 
 
 
16
 
17
  # VSCode
18
  .vscode/
19
+ *.code-workspace
20
+
21
+ # Solution template
22
+ solution_template/
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import gradio as gr
2
- from src.visualization import plot_time_series
 
 
 
3
 
4
  PROJECT_HTML_PATH = "sections/project_description.html"
5
  WELCOME_HTML_PATH = "sections/welcome_section.html"
@@ -26,11 +29,48 @@ with gr.Blocks(css=blocks_css) as demo:
26
  gr.Markdown(demo_section_html)
27
  gr.HTML(try_it_yourself_html)
28
 
29
- file_input = gr.File(label="Upload Time Series CSV")
 
 
 
 
30
  plot_btn = gr.Button("Plot Time Series")
31
  plot_output = gr.Plot(label="Time Series Plot")
32
 
33
- method_input = gr.Dropdown(choices=ANOMALY_METHODS, label="Select Anomaly Detection Method", value="LSTM")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  analyze_btn = gr.Button("Detect Anomalies")
35
  anomaly_output = gr.Plot(label="Anomaly Detection Results")
36
 
@@ -40,10 +80,16 @@ with gr.Blocks(css=blocks_css) as demo:
40
  outputs=[plot_output]
41
  )
42
 
43
- # analyze_btn.click(
44
- # fn=detect_anomalies,
45
- # inputs=[file_input, method_input],
46
- # outputs=[anomaly_output]
47
- # )
 
 
 
 
 
 
48
 
49
  demo.launch(show_api=False)
 
1
  import gradio as gr
2
+ from src import plot_time_series
3
+ from src import detect_anomalies
4
+ from src.utils import update_controls
5
+
6
 
7
  PROJECT_HTML_PATH = "sections/project_description.html"
8
  WELCOME_HTML_PATH = "sections/welcome_section.html"
 
29
  gr.Markdown(demo_section_html)
30
  gr.HTML(try_it_yourself_html)
31
 
32
+ file_input = gr.File(
33
+ label="Upload Time Series CSV",
34
+ file_types=[".csv"],
35
+ value="examples/mpox.csv" # Set default example file
36
+ )
37
  plot_btn = gr.Button("Plot Time Series")
38
  plot_output = gr.Plot(label="Time Series Plot")
39
 
40
+ with gr.Row():
41
+ method_input = gr.Dropdown(
42
+ choices=ANOMALY_METHODS,
43
+ interactive=True,
44
+ label="Select Anomaly Detection Method",
45
+ value="LSTM"
46
+ )
47
+ k_input = gr.Slider(
48
+ minimum=1,
49
+ maximum=3,
50
+ step=0.1,
51
+ label="k",
52
+ value=1.5
53
+ )
54
+ percentile_input = gr.Slider(
55
+ minimum=0,
56
+ maximum=100,
57
+ step=1,
58
+ label="Percentile",
59
+ value=95,
60
+ interactive=True
61
+ )
62
+ threshold_method_input = gr.Dropdown(
63
+ choices=[
64
+ "IQR on (ground truth - forecast)",
65
+ "IQR on |ground truth - forecast|",
66
+ "IQR on |ground truth - forecast|/forecast",
67
+ "Percentile threshold on absolute loss",
68
+ "Percentile threshold on raw loss"
69
+ ],
70
+ label="Threshold Method",
71
+ value="IQR on (ground truth - forecast)",
72
+ interactive=True
73
+ )
74
  analyze_btn = gr.Button("Detect Anomalies")
75
  anomaly_output = gr.Plot(label="Anomaly Detection Results")
76
 
 
80
  outputs=[plot_output]
81
  )
82
 
83
+ method_input.change(
84
+ fn=update_controls,
85
+ inputs=[method_input],
86
+ outputs=[percentile_input, threshold_method_input]
87
+ )
88
+
89
+ analyze_btn.click(
90
+ fn=detect_anomalies,
91
+ inputs=[file_input, method_input, k_input, percentile_input, threshold_method_input],
92
+ outputs=[anomaly_output]
93
+ )
94
 
95
  demo.launch(show_api=False)
examples/mpox.csv ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ date,news
2
+ 2022-04-22,15.0
3
+ 2022-04-23,6.0
4
+ 2022-04-24,3.0
5
+ 2022-04-25,8.0
6
+ 2022-04-26,15.0
7
+ 2022-04-27,19.0
8
+ 2022-04-28,16.0
9
+ 2022-04-29,12.0
10
+ 2022-04-30,5.0
11
+ 2022-05-01,5.0
12
+ 2022-05-02,8.0
13
+ 2022-05-03,6.0
14
+ 2022-05-04,11.0
15
+ 2022-05-05,4.0
16
+ 2022-05-06,9.0
17
+ 2022-05-07,8.0
18
+ 2022-05-08,3.0
19
+ 2022-05-09,15.0
20
+ 2022-05-10,18.0
21
+ 2022-05-11,8.0
22
+ 2022-05-12,15.0
23
+ 2022-05-13,11.0
24
+ 2022-05-14,7.0
25
+ 2022-05-15,10.0
26
+ 2022-05-16,12.0
27
+ 2022-05-17,25.0
28
+ 2022-05-18,45.0
29
+ 2022-05-19,163.0
30
+ 2022-05-20,271.0
31
+ 2022-05-21,146.0
32
+ 2022-05-22,152.0
33
+ 2022-05-23,308.0
34
+ 2022-05-24,259.0
35
+ 2022-05-25,174.0
36
+ 2022-05-26,163.0
37
+ 2022-05-27,116.0
38
+ 2022-05-28,64.0
39
+ 2022-05-29,29.0
40
+ 2022-05-30,80.0
41
+ 2022-05-31,75.0
42
+ 2022-06-01,74.0
43
+ 2022-06-02,79.0
44
+ 2022-06-03,76.0
45
+ 2022-06-04,57.0
46
+ 2022-06-05,24.0
47
+ 2022-06-06,62.0
48
+ 2022-06-07,72.0
49
+ 2022-06-08,74.0
50
+ 2022-06-09,75.0
51
+ 2022-06-10,77.0
52
+ 2022-06-11,31.0
53
+ 2022-06-12,18.0
54
+ 2022-06-13,42.0
55
+ 2022-06-14,52.0
56
+ 2022-06-15,94.0
57
+ 2022-06-16,66.0
58
+ 2022-06-17,44.0
59
+ 2022-06-18,15.0
60
+ 2022-06-19,22.0
61
+ 2022-06-20,24.0
62
+ 2022-06-21,30.0
63
+ 2022-06-22,39.0
64
+ 2022-06-23,71.0
65
+ 2022-06-24,50.0
66
+ 2022-06-25,39.0
67
+ 2022-06-26,51.0
68
+ 2022-06-27,36.0
69
+ 2022-06-28,45.0
70
+ 2022-06-29,79.0
71
+ 2022-06-30,57.0
72
+ 2022-07-01,73.0
73
+ 2022-07-02,42.0
74
+ 2022-07-03,13.0
75
+ 2022-07-04,23.0
76
+ 2022-07-05,38.0
77
+ 2022-07-06,47.0
78
+ 2022-07-07,60.0
79
+ 2022-07-08,63.0
80
+ 2022-07-09,53.0
81
+ 2022-07-10,15.0
82
+ 2022-07-11,39.0
83
+ 2022-07-12,62.0
84
+ 2022-07-13,53.0
85
+ 2022-07-14,41.0
86
+ 2022-07-15,72.0
87
+ 2022-07-16,35.0
88
+ 2022-07-17,25.0
89
+ 2022-07-18,56.0
90
+ 2022-07-19,50.0
91
+ 2022-07-20,45.0
92
+ 2022-07-21,88.0
93
+ 2022-07-22,73.0
94
+ 2022-07-23,112.0
95
+ 2022-07-24,128.0
96
+ 2022-07-25,166.0
97
+ 2022-07-26,169.0
98
+ 2022-07-27,173.0
99
+ 2022-07-28,161.0
100
+ 2022-07-29,130.0
101
+ 2022-07-30,136.0
102
+ 2022-07-31,87.0
103
+ 2022-08-01,136.0
104
+ 2022-08-02,189.0
105
+ 2022-08-03,138.0
106
+ 2022-08-04,163.0
107
+ 2022-08-05,167.0
108
+ 2022-08-06,66.0
109
+ 2022-08-07,37.0
110
+ 2022-08-08,49.0
111
+ 2022-08-09,91.0
112
+ 2022-08-10,112.0
113
+ 2022-08-11,92.0
114
+ 2022-08-12,66.0
115
+ 2022-08-13,41.0
116
+ 2022-08-14,16.0
117
+ 2022-08-15,55.0
118
+ 2022-08-16,93.0
119
+ 2022-08-17,90.0
120
+ 2022-08-18,90.0
121
+ 2022-08-19,74.0
122
+ 2022-08-20,43.0
123
+ 2022-08-21,34.0
124
+ 2022-08-22,36.0
125
+ 2022-08-23,75.0
126
+ 2022-08-24,59.0
127
+ 2022-08-25,74.0
128
+ 2022-08-26,60.0
129
+ 2022-08-27,32.0
130
+ 2022-08-28,13.0
131
+ 2022-08-29,19.0
132
+ 2022-08-30,59.0
133
+ 2022-08-31,36.0
134
+ 2022-09-01,36.0
135
+ 2022-09-02,46.0
136
+ 2022-09-03,15.0
137
+ 2022-09-04,7.0
138
+ 2022-09-05,9.0
139
+ 2022-09-06,30.0
140
+ 2022-09-07,37.0
141
+ 2022-09-08,47.0
142
+ 2022-09-09,39.0
143
+ 2022-09-10,10.0
144
+ 2022-09-11,12.0
145
+ 2022-09-12,21.0
146
+ 2022-09-13,24.0
147
+ 2022-09-14,32.0
148
+ 2022-09-15,19.0
149
+ 2022-09-16,27.0
150
+ 2022-09-17,11.0
151
+ 2022-09-18,8.0
152
+ 2022-09-19,34.0
153
+ 2022-09-20,24.0
154
+ 2022-09-21,12.0
155
+ 2022-09-22,18.0
156
+ 2022-09-23,14.0
157
+ 2022-09-24,1.0
158
+ 2022-09-25,3.0
159
+ 2022-09-26,5.0
160
+ 2022-09-27,10.0
161
+ 2022-09-28,13.0
162
+ 2022-09-29,16.0
163
+ 2022-09-30,12.0
164
+ 2022-10-01,9.0
165
+ 2022-10-02,3.0
166
+ 2022-10-03,9.0
167
+ 2022-10-04,9.0
168
+ 2022-10-05,9.0
169
+ 2022-10-06,7.0
170
+ 2022-10-07,3.0
171
+ 2022-10-08,4.0
172
+ 2022-10-09,3.0
173
+ 2022-10-10,3.0
174
+ 2022-10-11,7.0
175
+ 2022-10-12,15.0
176
+ 2022-10-13,22.0
177
+ 2022-10-14,13.0
178
+ 2022-10-15,4.0
179
+ 2022-10-16,2.0
180
+ 2022-10-17,14.0
181
+ 2022-10-18,12.0
182
+ 2022-10-19,10.0
183
+ 2022-10-20,11.0
184
+ 2022-10-21,10.0
185
+ 2022-10-22,7.0
186
+ 2022-10-23,6.0
187
+ 2022-10-24,7.0
188
+ 2022-10-25,4.0
189
+ 2022-10-26,8.0
190
+ 2022-10-27,10.0
191
+ 2022-10-28,15.0
192
+ 2022-10-29,5.0
193
+ 2022-10-30,4.0
194
+ 2022-10-31,15.0
195
+ 2022-11-01,7.0
196
+ 2022-11-02,8.0
197
+ 2022-11-03,6.0
198
+ 2022-11-04,10.0
199
+ 2022-11-05,3.0
200
+ 2022-11-06,3.0
201
+ 2022-11-07,9.0
202
+ 2022-11-08,7.0
203
+ 2022-11-09,3.0
204
+ 2022-11-10,4.0
205
+ 2022-11-11,2.0
206
+ 2022-11-12,3.0
207
+ 2022-11-13,3.0
208
+ 2022-11-14,8.0
209
+ 2022-11-15,7.0
210
+ 2022-11-16,5.0
211
+ 2022-11-17,6.0
212
+ 2022-11-18,2.0
213
+ 2022-11-19,2.0
214
+ 2022-11-20,0.0
215
+ 2022-11-21,0.0
216
+ 2022-11-22,4.0
217
+ 2022-11-23,7.0
218
+ 2022-11-24,11.0
219
+ 2022-11-25,1.0
220
+ 2022-11-26,3.0
221
+ 2022-11-27,0.0
222
+ 2022-11-28,14.0
223
+ 2022-11-29,16.0
224
+ 2022-11-30,4.0
225
+ 2022-12-01,11.0
226
+ 2022-12-02,5.0
227
+ 2022-12-03,2.0
228
+ 2022-12-04,1.0
229
+ 2022-12-05,6.0
230
+ 2022-12-06,6.0
231
+ 2022-12-07,8.0
232
+ 2022-12-08,7.0
233
+ 2022-12-09,7.0
234
+ 2022-12-10,3.0
235
+ 2022-12-11,3.0
236
+ 2022-12-12,4.0
237
+ 2022-12-13,5.0
238
+ 2022-12-14,6.0
239
+ 2022-12-15,2.0
240
+ 2022-12-16,4.0
241
+ 2022-12-17,3.0
242
+ 2022-12-18,2.0
243
+ 2022-12-19,2.0
244
+ 2022-12-20,4.0
245
+ 2022-12-21,12.0
246
+ 2022-12-22,16.0
247
+ 2022-12-23,9.0
248
+ 2022-12-24,4.0
249
+ 2022-12-25,3.0
250
+ 2022-12-26,4.0
251
+ 2022-12-27,4.0
252
+ 2022-12-28,5.0
253
+ 2022-12-29,14.0
254
+ 2022-12-30,13.0
255
+ 2022-12-31,4.0
256
+ 2023-01-01,3.0
257
+ 2023-01-02,2.0
258
+ 2023-01-03,10.0
259
+ 2023-01-04,10.0
260
+ 2023-01-05,19.0
261
+ 2023-01-06,7.0
262
+ 2023-01-07,3.0
263
+ 2023-01-08,10.0
264
+ 2023-01-09,12.0
265
+ 2023-01-10,17.0
266
+ 2023-01-11,14.0
267
+ 2023-01-12,15.0
268
+ 2023-01-13,10.0
269
+ 2023-01-14,15.0
270
+ 2023-01-15,1.0
271
+ 2023-01-16,9.0
272
+ 2023-01-17,6.0
273
+ 2023-01-18,4.0
274
+ 2023-01-19,2.0
275
+ 2023-01-20,4.0
276
+ 2023-01-21,2.0
277
+ 2023-01-22,3.0
278
+ 2023-01-23,1.0
279
+ 2023-01-24,4.0
280
+ 2023-01-25,7.0
281
+ 2023-01-26,5.0
282
+ 2023-01-27,1.0
283
+ 2023-01-28,0.0
284
+ 2023-01-29,2.0
285
+ 2023-01-30,3.0
286
+ 2023-01-31,6.0
287
+ 2023-02-01,8.0
288
+ 2023-02-02,7.0
289
+ 2023-02-03,4.0
290
+ 2023-02-04,3.0
291
+ 2023-02-05,1.0
292
+ 2023-02-06,13.0
293
+ 2023-02-07,4.0
294
+ 2023-02-08,6.0
295
+ 2023-02-09,3.0
296
+ 2023-02-10,7.0
297
+ 2023-02-11,3.0
298
+ 2023-02-12,0.0
299
+ 2023-02-13,3.0
models/lstm_forec_40_11_06.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df812ddfa22e1d13ed44b8645a69b5a5c6d00ccd3008975f59dca3b9e3637b5c
3
+ size 19807
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  gradio==5.29.0
2
- plotly==6.0.1
 
 
 
 
1
  gradio==5.29.0
2
+ torch==2.7.0
3
+ scikit-learn==1.6.1
4
+ plotly==6.0.1
5
+ statsmodels==0.14.4
sections/try_it_yourself.html CHANGED
@@ -4,7 +4,52 @@
4
  <li>📈 Upload a CSV file with two columns: dates in the first column and disease mention counts in the second</li>
5
  <li>🎯 Click "Plot Time Series" to visualize your data</li>
6
  <li>🔍 Select an anomaly detection method from the dropdown</li>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  <li>⚡ Click "Detect Anomalies" to identify unusual patterns in your time series</li>
8
  </ul>
9
- <p>This tool combines time series analysis and anomaly detection to help identify potential disease outbreaks based on news coverage patterns. 💡</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  </section>
 
4
  <li>📈 Upload a CSV file with two columns: dates in the first column and disease mention counts in the second</li>
5
  <li>🎯 Click "Plot Time Series" to visualize your data</li>
6
  <li>🔍 Select an anomaly detection method from the dropdown</li>
7
+ <li>⚙️ Configure the detection parameters:
8
+ <ul>
9
+ <li><strong>For LSTM method:</strong>
10
+ <ul>
11
+ <li><em>k</em>: Controls sensitivity (1-3). Higher values mean stricter anomaly detection.</li>
12
+ <li><em>Percentile</em>: Threshold percentile for anomaly detection (0-100).</li>
13
+ <li><em>Threshold Method</em>: Choose how to calculate anomaly thresholds:
14
+ <ul>
15
+ <li>IQR-based methods: Compare predictions with actual values using different metrics</li>
16
+ <li>Percentile-based methods: Use statistical thresholds on prediction errors</li>
17
+ </ul>
18
+ </li>
19
+ </ul>
20
+ </li>
21
+ <li><strong>For ARIMA method:</strong>
22
+ <ul>
23
+ <li><em>k</em>: Sensitivity multiplier for standard deviation-based thresholds (1-3).</li>
24
+ </ul>
25
+ </li>
26
+ <li><strong>For IQR method:</strong>
27
+ <ul>
28
+ <li><em>k</em>: IQR multiplier (1-3). Higher values detect more extreme outliers.</li>
29
+ </ul>
30
+ </li>
31
+ </ul>
32
+ </li>
33
  <li>⚡ Click "Detect Anomalies" to identify unusual patterns in your time series</li>
34
  </ul>
35
+
36
+ <div class="example-section">
37
+ <h3>📋 Example Dataset</h3>
38
+ <p>Try out the tool with our sample dataset:</p>
39
+ <ul>
40
+ <li><strong>Dataset:</strong> <code>mpox.csv</code> - News coverage time series for Monkeypox/Mpox outbreak</li>
41
+ <li><strong>Time Period:</strong> Daily counts from early 2022</li>
42
+ <li><strong>Recommended Settings:</strong>
43
+ <ul>
44
+ <li>Method: LSTM</li>
45
+ <li>k: 1.5</li>
46
+ <li>Percentile: 95</li>
47
+ <li>Threshold Method: "IQR on |ground truth - forecast|"</li>
48
+ </ul>
49
+ </li>
50
+ <li><strong>Expected Results:</strong> The analysis should identify significant spikes in news coverage that corresponded to major outbreak events and public health announcements during the 2022 Mpox outbreak.</li>
51
+ </ul>
52
+ </div>
53
+
54
+ <p>This tool combines time series analysis and anomaly detection to help identify potential disease outbreaks based on news coverage patterns. The results can be used to alert public health officials about emerging health concerns. 💡</p>
55
  </section>
src/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .plotting.visualization import plot_time_series, plot_anomalies
2
+ from .anomaly_detection import detect_anomalies
3
+ from .outbreak_detection import (
4
+ LSTMforOutbreakDetection,
5
+ ARIMAforOutbreakDetection,
6
+ IQRforOutbreakDetection
7
+ )
8
+ from .utils import (
9
+ timestamp_wise_evaluation,
10
+ tolerance_based_evaluation
11
+ )
12
+
13
+ __all__ = [
14
+ 'plot_time_series',
15
+ 'plot_anomalies',
16
+ 'detect_anomalies',
17
+ 'LSTMforOutbreakDetection',
18
+ 'ARIMAforOutbreakDetection',
19
+ 'IQRforOutbreakDetection',
20
+ 'timestamp_wise_evaluation',
21
+ 'tolerance_based_evaluation'
22
+ 'prepare_time_series_dataframe',
23
+ 'tolerance_based_evaluation',
24
+ ]
src/anomaly_detection.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from .outbreak_detection import (
3
+ LSTMforOutbreakDetection,
4
+ ARIMAforOutbreakDetection,
5
+ IQRforOutbreakDetection
6
+ )
7
+ from .plotting.visualization import plot_anomalies
8
+ from .utils import prepare_time_series_dataframe
9
+
10
+
11
+ THRESHOLD_METHODS = {
12
+ "IQR on (ground truth - forecast)": 0,
13
+ "IQR on |ground truth - forecast|": 1,
14
+ "IQR on |ground truth - forecast|/forecast": 2,
15
+ "Percentile threshold on absolute loss": 3,
16
+ "Percentile threshold on raw loss": 4
17
+ }
18
+
19
+ def detect_anomalies(file_path: str, method: str, k: int, percentile: float, threshold_method: int):
20
+ """
21
+ Detects anomalies in time series data using various detection methods.
22
+ Args:
23
+ file_path (str): Path to the CSV file containing time series data
24
+ method (str): Detection method to use ('LSTM', 'ARIMA', or 'IQR')
25
+ k (int): Number of neighbors or window size (method-dependent parameter)
26
+ percentile (float): Percentile threshold for anomaly detection
27
+ threshold_method (int): Method to determine threshold for anomaly detection
28
+ Returns:
29
+ plotly.graph_objects.Figure: Plotly figure containing the time series with highlighted anomalies
30
+ """
31
+ df = pd.read_csv(file_path)
32
+ df = prepare_time_series_dataframe(df)
33
+
34
+ # Map threshold methods to their descriptions for better readability
35
+
36
+ detectors = {
37
+ 'LSTM': LSTMforOutbreakDetection(
38
+ checkpoint_path='models/lstm_forec_40_11_06.pth',
39
+ k=k,
40
+ percentile=percentile,
41
+ threshold_method=THRESHOLD_METHODS[threshold_method]
42
+ ),
43
+ 'ARIMA': ARIMAforOutbreakDetection(k=k),
44
+ 'IQR': IQRforOutbreakDetection(k=k)
45
+ }
46
+
47
+ detector = detectors[method]
48
+ test, new_label = detector.detect_anomalies(df)
49
+ return plot_anomalies(test, anomaly_col=new_label)
50
+
51
+
src/outbreak_detection/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .lstm import LSTMforOutbreakDetection
2
+ from .arima import ARIMAforOutbreakDetection
3
+ from .iqr import IQRforOutbreakDetection
4
+ from .lstm_model import LstmModel, testing
5
+
6
+ __all__ = [
7
+ 'LSTMforOutbreakDetection',
8
+ 'ARIMAforOutbreakDetection',
9
+ 'IQRforOutbreakDetection',
10
+ 'LstmModel',
11
+ 'testing'
12
+ ]
src/outbreak_detection/arima.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ from statsmodels.tsa.stattools import adfuller, acf, pacf
4
+ from statsmodels.tsa.arima.model import ARIMA
5
+
6
+
7
+ NEW_ANOMALY_COLUMN_NAME = 'anomaly'
8
+
9
+ class ARIMAforOutbreakDetection:
10
+ def __init__(self, window_size=7, stride=1, k=1.5, significance=0.05, max_lag=30):
11
+ self.window_size = window_size
12
+ self.stride = stride
13
+ self.k = k
14
+ self.significance = significance
15
+ self.max_lag = max_lag
16
+
17
+ def test_stationarity(self, ts_data, column=''):
18
+ if isinstance(ts_data, pd.Series):
19
+ adf_test = adfuller(ts_data, autolag='AIC')
20
+ else:
21
+ adf_test = adfuller(ts_data[column], autolag='AIC')
22
+ return "Stationary" if adf_test[1] <= self.significance else "Non-Stationary"
23
+
24
+ def make_stationary(self, dataframe, column):
25
+ df_to_return = None
26
+ result = self.test_stationarity(dataframe, column)
27
+
28
+ if result == "Stationary":
29
+ return dataframe
30
+
31
+ diff_series = dataframe.copy()
32
+ for diff_count in range(5):
33
+ diff_series = diff_series.diff().fillna(0)
34
+ if self.test_stationarity(diff_series, column) == "Stationary":
35
+ return diff_series
36
+ return diff_series
37
+
38
+ def create_windows(self, df):
39
+ windows, gts = [], []
40
+ for i in range(0, len(df) - self.window_size, self.stride):
41
+ end_id = i + self.window_size
42
+ windows.append(df.iloc[i:end_id, :])
43
+ gts.append(df.iloc[end_id, :])
44
+ return np.stack(windows), np.stack(gts)
45
+
46
+ def find_p_q(self, series):
47
+ N = len(series)
48
+ acf_values, _ = acf(series, nlags=self.max_lag, alpha=self.significance, fft=False)
49
+ pacf_values, _ = pacf(series, nlags=self.max_lag, alpha=self.significance)
50
+ threshold = 1.96 / np.sqrt(N)
51
+
52
+ def find_last_consecutive_outlier(values):
53
+ for i in range(1, len(values)):
54
+ if values[i] < 0 or (values[i] > 0 and abs(values[i]) < threshold):
55
+ return i
56
+ return len(values) - 1
57
+
58
+ return find_last_consecutive_outlier(pacf_values), find_last_consecutive_outlier(acf_values)
59
+
60
+ def detect_anomalies(self, dataset, news_or_cases='news'):
61
+ stationary_data = self.make_stationary(dataset, news_or_cases)
62
+ p, q = self.find_p_q(stationary_data[news_or_cases])
63
+ anomalies, means, stdevs, residuals, predictions, gts = self._train_arima_model(stationary_data, p, q)
64
+ result_df = self._prepare_resulting_dataframe(
65
+ residuals, means, stdevs, dataset.iloc[self.window_size:],
66
+ anomalies, gts, predictions
67
+ )
68
+ return self._postprocess_anomalies(result_df, news_or_cases), NEW_ANOMALY_COLUMN_NAME
69
+
70
+ def _train_arima_model(self, dataset, p, q):
71
+ predictions, residuals, means, stdevs, anomalies = [], [], [], [], []
72
+ windows, gts = self.create_windows(dataset)
73
+
74
+ for window, gt in zip(windows, gts):
75
+ model = ARIMA(window, order=(p, 0, q))
76
+ model.initialize_approximate_diffuse()
77
+ fit = model.fit()
78
+
79
+ pred = fit.forecast(steps=1)[0]
80
+ residual = np.abs(gt - pred)
81
+ mu, std = np.mean(fit.resid), np.std(fit.resid)
82
+
83
+ anomalies.append(
84
+ 1 if residual > mu + self.k * std or residual < mu - self.k * std else 0
85
+ )
86
+
87
+ means.append(mu)
88
+ stdevs.append(std)
89
+ residuals.append(residual)
90
+ predictions.append(pred)
91
+
92
+ return anomalies, means, stdevs, residuals, predictions, gts
93
+
94
+ def _prepare_resulting_dataframe(self, residuals, means, stdevs, original_dataset,
95
+ anomalies, gts, predictions):
96
+ result_df = original_dataset.copy()
97
+ result_df['residuals'] = residuals
98
+ result_df['mu'] = means
99
+ result_df['sigma'] = stdevs
100
+ result_df['anomaly'] = anomalies
101
+ result_df['gts_diff'] = gts
102
+ result_df['pred_diff'] = predictions
103
+ return result_df
104
+
105
+ def _postprocess_anomalies(self, dataframe, col_name='news'):
106
+ dataframe['derivative'] = dataframe[col_name].diff().fillna(0)
107
+ dataframe['new_anomaly'] = [
108
+ 0 if row.derivative < 0 and row.anomaly == 1 else row.anomaly
109
+ for _, row in dataframe.iterrows()
110
+ ]
111
+ return dataframe
src/outbreak_detection/iqr.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
+
5
+ NEW_LABEL_COLUMN_NAME = 'new_label'
6
+
7
+ class IQRforOutbreakDetection:
8
+ def __init__(self, window_size=7, stride=1, k=1.5):
9
+ self.window_size = window_size
10
+ self.stride = stride
11
+ self.k = k
12
+
13
+ def _iqr_rolling(self, timeseries):
14
+ q1 = np.percentile(timeseries, 25)
15
+ q3 = np.percentile(timeseries, 75)
16
+ iqr = q3 - q1
17
+ ub = q3 + self.k * iqr
18
+ lb = q1 - self.k * iqr
19
+ return ub, lb
20
+
21
+ def detect_anomalies(self, df, news_or_cases='news'):
22
+ """"
23
+ input methods: k
24
+ """
25
+ if isinstance(df, pd.Series):
26
+ timeseries = df
27
+ else:
28
+ timeseries = df[news_or_cases]
29
+
30
+ tot_peaks, final_peaks, _ = self._windowed_iqr(timeseries)
31
+ result_df = self._prepare_resulting_dataframe(final_peaks, timeseries)
32
+ processed_df = self._postprocess_anomalies(result_df, news_or_cases)
33
+ print(processed_df)
34
+
35
+ return processed_df, NEW_LABEL_COLUMN_NAME
36
+
37
+ def _windowed_iqr(self, df):
38
+ tot_peaks = {}
39
+ for i in range(0, len(df) - self.window_size + 1, self.stride):
40
+ end_id = i + self.window_size
41
+ window = df[i:end_id]
42
+ ub, _ = self._iqr_rolling(window)
43
+
44
+ for j in window.index:
45
+ peaks_list = tot_peaks.setdefault(f'{j}', [])
46
+ peaks_list.append(window.loc[j] > ub)
47
+
48
+ final_peaks = {k: True if True in v else False
49
+ for k, v in tot_peaks.items()}
50
+
51
+ return tot_peaks, final_peaks, end_id
52
+
53
+ def _prepare_resulting_dataframe(self, peaks_df, news_or_cases_df):
54
+ final_df_iqr = pd.DataFrame.from_dict(peaks_df, orient='index')
55
+ dff = pd.DataFrame(news_or_cases_df)
56
+ dff['peaks'] = final_df_iqr.loc[:, 0].values
57
+ dff['peaks'] = dff['peaks'].map({True: 1, False: 0})
58
+ return dff
59
+
60
+ def _postprocess_anomalies(self, dataframe, col_name='news'):
61
+ dataframe['derivative'] = dataframe[col_name].diff().fillna(0)
62
+ dataframe['new_label'] = [0 if v.derivative < 0 and v.peaks == 1 else v.peaks
63
+ for _, v in dataframe.iterrows()]
64
+ return dataframe
src/outbreak_detection/lstm.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch.utils.data as data_utils
5
+ from sklearn.preprocessing import MinMaxScaler
6
+ from .lstm_model import LstmModel, testing
7
+
8
+
9
+ PRETRAINED_MODEL_N_CHANNELS = 1
10
+ PRETRAINED_MODEL_Z_SIZE = 32
11
+
12
+ class LSTMforOutbreakDetection:
13
+ def __init__(
14
+ self,
15
+ checkpoint_path=None,
16
+ n_channels=PRETRAINED_MODEL_N_CHANNELS,
17
+ z_size=PRETRAINED_MODEL_Z_SIZE,
18
+ device='cpu',
19
+ window=7,
20
+ batch_size=32,
21
+ k=1.5,
22
+ percentile=95,
23
+ threshold_method=0
24
+ ):
25
+ self.device = torch.device(device)
26
+ self.window = window
27
+ self.batch_size = batch_size
28
+ self.n_channels = n_channels
29
+ self.z_size = z_size
30
+ self.scaler = MinMaxScaler(feature_range=(0,1))
31
+ self.k = k
32
+ self.percentile = percentile
33
+ self.threshold_method = threshold_method
34
+ if checkpoint_path:
35
+ self.model = self._load_model(checkpoint_path)
36
+
37
+ def _load_model(self, checkpoint_path):
38
+ model = LstmModel(self.n_channels, self.z_size)
39
+ model = model.to(self.device)
40
+ model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
41
+ return model
42
+
43
+ def create_test_sequences(self, dataframe, time_steps, news_or_cases='news'):
44
+ if news_or_cases not in ['news', 'cases']:
45
+ raise ValueError("news_or_cases should be either 'news' or 'cases'")
46
+
47
+ output, output2 = [], []
48
+ dataframe[[news_or_cases]] = self.scaler.fit_transform(dataframe[[news_or_cases]])
49
+ norm = np.array(dataframe[[news_or_cases]]).astype(float)
50
+
51
+ for i in range(len(norm)):
52
+ end_ix = i + time_steps
53
+
54
+ if end_ix > len(norm)-1:
55
+ break
56
+
57
+ seq_x, seq_y = norm[i:end_ix, :], norm[end_ix, 0]
58
+ output.append(seq_x)
59
+ output2.append(seq_y)
60
+
61
+ return np.stack(output), np.stack(output2)
62
+
63
+ def prepare_input_dataframe(self, dataframe, news_column_name='news'):
64
+ X_test, y_test = self.create_test_sequences(dataframe, self.window, news_column_name)
65
+ test_loader = torch.utils.data.DataLoader(
66
+ data_utils.TensorDataset(
67
+ torch.from_numpy(X_test).float(),
68
+ torch.from_numpy(y_test).float()
69
+ ),
70
+ batch_size=self.batch_size,
71
+ shuffle=False,
72
+ num_workers=0
73
+ )
74
+ return test_loader, y_test
75
+
76
+ def predict(self, dataframe, news_column_name='news'):
77
+ test_loader, y_test = self.prepare_input_dataframe(dataframe, news_column_name)
78
+ results, w = testing(self.model, test_loader, self.device)
79
+ forecast_test = np.concatenate([
80
+ torch.stack(w[:-1]).flatten().detach().cpu().numpy(),
81
+ w[-1].flatten().detach().cpu().numpy()
82
+ ])
83
+
84
+ test_df = dataframe[self.window:].copy()
85
+ test_df['y_test'] = y_test
86
+ test_df['pred_forec'] = forecast_test
87
+ test_df['abs_loss'] = np.abs(test_df.y_test - test_df.pred_forec)
88
+ test_df['rel_loss'] = np.abs((test_df['pred_forec'] - test_df['y_test']) / (1 + test_df['pred_forec']))
89
+ test_df['diff'] = test_df['y_test'] - test_df['pred_forec']
90
+
91
+ return test_df
92
+
93
+ @staticmethod
94
+ def _iqr_rolling(timeseries, k):
95
+ q1, q3 = np.percentile(timeseries, [25, 75])
96
+ iqr = q3 - q1
97
+
98
+ return q3 + k * iqr
99
+
100
+ def windowed_iqr(self, df, k, type_of_loss='diff'):
101
+ peaks = {}
102
+
103
+ for i in range(len(df)):
104
+ end_ix = i + self.window
105
+
106
+ if end_ix > len(df)-1:
107
+ break
108
+
109
+ seq_x = df.iloc[i:end_ix, :]
110
+ ub = self._iqr_rolling(seq_x[type_of_loss], k)
111
+
112
+ for j in seq_x.index:
113
+ condition = int(seq_x.loc[j, type_of_loss] > ub)
114
+ peaks.setdefault(f'{j}', []).append(condition)
115
+
116
+ return {k: 1 if sum(v) > 0 else 0 for k, v in peaks.items()}
117
+
118
+ def get_perc_threshold(self, test_df, percentile, col='abs_loss'):
119
+ if col not in ['abs_loss', 'loss']:
120
+ raise ValueError("col should be either 'abs_loss' or 'loss'")
121
+
122
+ test1 = test_df[:-1].copy()
123
+ anom_perc_loss = {}
124
+
125
+ for i in range(len(test_df)):
126
+ end_ix = i + self.window
127
+ if end_ix > len(test_df)-1:
128
+ break
129
+
130
+ seq_x = test_df.iloc[i:end_ix, :].copy()
131
+ mae = seq_x['abs_loss'].values if col == 'abs_loss' else seq_x['y_test'] - seq_x['pred_forec']
132
+
133
+ threshold = np.percentile(mae, percentile)
134
+ seq_x['threshold'] = threshold
135
+
136
+ for j in seq_x.index:
137
+ condition = int(seq_x.loc[j, col] > seq_x.loc[j, 'threshold'])
138
+ anom_perc_loss.setdefault(f'{j}', []).append(condition)
139
+
140
+ final_anom = {k: 1 if sum(v) > 0 else 0 for k, v in anom_perc_loss.items()}
141
+ new_col = 'anom_perc_abs_loss' if col == 'abs_loss' else 'anom_perc_diff_gt_pred'
142
+ test1[new_col] = pd.Series(final_anom)
143
+
144
+ return test1
145
+
146
+ def postprocess_anomalies(self, test_df, new_col, old_col, news_or_cases):
147
+ test_df = test_df.copy()
148
+ test_df['derivative'] = test_df[news_or_cases].diff().fillna(0)
149
+ test_df[new_col] = [0 if v.derivative < 0 and v[old_col] == 1 else v[old_col]
150
+ for k, v in test_df.iterrows()]
151
+
152
+ return test_df
153
+
154
+ def detect_anomalies(self, test_df, news_or_cases='news'):
155
+ """
156
+ Detect anomalies using different methods:
157
+ 0: IQR on (ground truth - forecast)
158
+ 1: IQR on |ground truth - forecast|
159
+ 2: IQR on |ground truth - forecast|/forecast
160
+ 3: Percentile threshold on absolute loss
161
+ 4: Percentile threshold on raw loss
162
+
163
+ input parameters: k (1-3), threshold_method, percentile
164
+ """
165
+ test_df = test_df.copy()
166
+
167
+ test = self.predict(test_df, news_column_name=news_or_cases)
168
+
169
+ if self.threshold_method in [0, 1, 2]:
170
+ loss_type = {0: 'diff', 1: 'abs_loss', 2: 'rel_loss'}[self.threshold_method]
171
+ iqr_suffix = {0: 'f_iqr', 1: 'abs_iqr', 2: 'rel_iqr'}[self.threshold_method]
172
+ new_label = {0: 'f_new_label', 1: 'abs_new_label', 2: 'rel_new_label'}[self.threshold_method]
173
+
174
+ peaks = self.windowed_iqr(test, self.k, loss_type)
175
+ peak_series = pd.Series(peaks)
176
+ peak_series.index = pd.to_datetime(peak_series.index)
177
+ test[iqr_suffix] = peak_series
178
+ test = self.postprocess_anomalies(test, new_label, iqr_suffix, news_or_cases)
179
+ return test, new_label
180
+
181
+ elif self.threshold_method in [3, 4]:
182
+ loss_type = 'abs_loss' if self.threshold_method == 3 else 'loss'
183
+ new_label = 'new_anom_absl' if self.threshold_method == 3 else 'new_anom_diff'
184
+ old_label = 'anom_perc_abs_loss' if self.threshold_method == 3 else 'anom_perc_diff_gt_pred'
185
+
186
+ test = self.get_perc_threshold(test, self.percentile, loss_type)
187
+ test = self.postprocess_anomalies(test, new_label, old_label, news_or_cases)
188
+ return test, new_label
189
+
190
+ raise ValueError("threshold_method must be between 0 and 4")
191
+
src/outbreak_detection/lstm_model.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class LstmModel(nn.Module):
6
+ def __init__(self, in_size: int, latent_size: int) -> None:
7
+ """
8
+ Initialize the LSTM autoencoder model.
9
+
10
+ Parameters:
11
+ -----------
12
+ in_size : int
13
+ Number of features in the input (input dimension).
14
+ latent_size : int
15
+ Size of the latent space representation in the LSTM.
16
+
17
+ Example:
18
+ --------
19
+ For in_size = 5, latent_size = 50, the model will:
20
+ - Take inputs with 5 features
21
+ - Encode them into a 50-dimensional latent space
22
+ - Decode back to the original 5 feature dimensions
23
+
24
+ Architecture:
25
+ - LSTM layer for encoding with dropout 0.2
26
+ - Additional dropout layer (0.2)
27
+ - ReLU activation
28
+ - Fully connected layer for decoding
29
+ """
30
+ super().__init__() # Corrected the position of super().__init__()
31
+ self.lstm = nn.LSTM(
32
+ input_size=in_size,
33
+ hidden_size=latent_size,
34
+ num_layers=1,
35
+ batch_first=True,
36
+ dropout=0.2
37
+ ) # input and output tensors are provided as (batch, seq_len, feature(size))
38
+ self.dropout = nn.Dropout(0.2)
39
+ self.relu = nn.ReLU()
40
+ self.fc = nn.Linear(latent_size, in_size)
41
+
42
+ def forward(self, w: torch.Tensor) -> torch.Tensor:
43
+ """
44
+ Forward pass through the LSTM model.
45
+
46
+ Parameters:
47
+ -----------
48
+ w : torch.Tensor
49
+ Input tensor of shape (batch_size, seq_len, in_size).
50
+
51
+ Returns:
52
+ --------
53
+ torch.Tensor
54
+ Output tensor of shape (batch_size, in_size).
55
+
56
+ Example:
57
+ --------
58
+ If the input tensor w has shape (32, 10, 5), the output will have shape (32, 5).
59
+ The LSTM processes the input sequence and returns the last output of the sequence.
60
+ The output is then passed through a ReLU activation function and a dropout layer.
61
+ Finally, it is passed through a fully connected layer to produce the final output.
62
+ The output is the reconstructed input sequence.
63
+ """
64
+ z, (h_n, c_n) = self.lstm(w)
65
+ forecast = z[:, -1, :]
66
+ forecast = self.relu(forecast)
67
+ forecast = self.dropout(forecast)
68
+ output = self.fc(forecast)
69
+
70
+ return output
71
+
72
+ def testing(model, test_loader, device):
73
+ results=[]
74
+ forecast = []
75
+ with torch.no_grad():
76
+ for X_batch, y_batch in test_loader:
77
+ X_batch = X_batch.to(device)
78
+ y_batch = y_batch.to(device)
79
+ w=model(X_batch)
80
+ results.append(torch.mean((y_batch.unsqueeze(1)-w)**2, axis=1))
81
+ forecast.append(w)
82
+
83
+ return results, forecast
src/plotting/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .visualization import plot_time_series, plot_anomalies
2
+
3
+ __all__ = ['plot_time_series', 'plot_anomalies']
src/plotting/visualization.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import plotly.graph_objects as go
3
+
4
+
5
+ NEWS_COUNT_COLUMN = 0
6
+
7
def plot_time_series(file_path):
    """Load a CSV and plot its first two columns as a line chart.

    Parameters
    ----------
    file_path : str or path-like
        CSV whose first column is the date axis and second the count.

    Returns
    -------
    plotly.graph_objects.Figure
        Line chart of the mention counts over time.
    """
    frame = pd.read_csv(file_path)
    dates = frame.iloc[:, 0]
    counts = frame.iloc[:, 1]
    figure = go.Figure(
        go.Scatter(x=dates, y=counts, mode='lines', name='Time Series')
    )
    figure.update_layout(
        title='Disease Mention Time Series',
        xaxis_title='Date',
        yaxis_title='Count',
    )
    return figure
14
+
15
def plot_anomalies(df, anomaly_col='new_label'):
    """Plot the time series with detected anomalies highlighted in red.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame indexed by date; column ``NEWS_COUNT_COLUMN`` (positionally)
        holds the counts to plot.
    anomaly_col : str, optional
        Name of the 0/1 column flagging anomalous timestamps.

    Returns
    -------
    plotly.graph_objects.Figure
        Line chart of the series with marker overlays on anomalies.
    """
    # NOTE: removed a leftover debug ``print(df)`` that dumped the whole
    # dataframe to stdout on every call.
    fig = go.Figure()

    # Base series.
    fig.add_trace(go.Scatter(
        x=df.index,
        y=df.iloc[:, NEWS_COUNT_COLUMN],
        mode='lines',
        name='Time Series',
        line=dict(color='blue')
    ))

    # Overlay only the rows flagged as anomalous.
    anomalies = df[df[anomaly_col] == 1]

    fig.add_trace(go.Scatter(
        x=anomalies.index,
        y=anomalies.iloc[:, NEWS_COUNT_COLUMN],
        mode='markers',
        name='Anomalies',
        marker=dict(color='red', size=10, symbol='circle')
    ))

    fig.update_layout(
        title='Disease Mention Time Series with Detected Anomalies',
        xaxis_title='Date',
        yaxis_title='Count',
        showlegend=True
    )

    return fig
src/utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from sklearn.metrics import classification_report
4
+
5
def timestamp_wise_evaluation(anomalies_cases, anomalies_news, threshold):
    """Print a per-timestamp classification report.

    Treats the case-series anomaly flags as ground truth and the
    news-series flags as predictions.
    """
    print(f"Classification report for threshold {threshold} (timestamp-wise evaluation):")
    report = classification_report(anomalies_cases, anomalies_news)
    print(report)
8
+
9
def tolerance_based_evaluation(anomalies_cases, anomalies_news, cases_df, news_df, threshold):
    """Evaluate news-based anomaly flags against case-based ground truth.

    A predicted (news) anomaly counts as a true positive if a case anomaly
    occurs at the same timestamp or within the following two timestamps
    (forward tolerance window, clipped at the end of the series). Misses
    are judged with no tolerance: a case anomaly with no news flag at the
    same timestamp is a false negative.

    Parameters
    ----------
    anomalies_cases : str
        Column name of the 0/1 anomaly flags in ``cases_df``.
    anomalies_news : str
        Column name of the 0/1 anomaly flags in ``news_df``.
    cases_df, news_df : pandas.DataFrame
        Aligned frames of equal length holding the flag columns.
    threshold :
        Identifier printed in the report header.

    Returns
    -------
    dict
        Counts and metrics: ``tp``, ``fp``, ``fn``, ``tn``, ``precision``,
        ``recall``, ``f1`` (metrics are 0.0 when undefined). Returning the
        dict is backward-compatible: the function previously returned None
        and only printed.
    """
    Tp = Fp = Fn = Tn = 0
    n = len(news_df)
    for i in range(n):
        if news_df.iloc[i][anomalies_news] == 1:
            # Forward tolerance: any case anomaly at i, i+1 or i+2 (the
            # iloc slice clips automatically at the series end).
            window = cases_df.iloc[i:i + 3][anomalies_cases]
            if (window == 1).any():
                Tp += 1
            else:
                Fp += 1
        else:
            # No tolerance for misses: same-timestamp comparison only.
            if cases_df.iloc[i][anomalies_cases] == 1:
                Fn += 1
            else:
                Tn += 1
    print(f"Tolerance-based evaluation for method {threshold}:")
    print(f"True Positives: {Tp}, False Positives: {Fp}, False Negatives: {Fn}, True Negatives: {Tn}")
    # Guard against ZeroDivisionError when a class never occurs; the
    # conventional value for an undefined metric is 0.0.
    precision = Tp / (Tp + Fp) if (Tp + Fp) else 0.0
    recall = Tp / (Tp + Fn) if (Tp + Fn) else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0.0
    print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
    return {
        "tp": Tp,
        "fp": Fp,
        "fn": Fn,
        "tn": Tn,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
54
+
55
+
56
def prepare_time_series_dataframe(df):
    """Prepare a dataframe for time-series analysis (modifies ``df`` in place).

    Moves the first column into a DatetimeIndex and renames the first
    remaining data column to "news".

    Parameters
    ----------
    df : pandas.DataFrame
        Raw frame whose first column holds parseable dates. Mutated in place.

    Returns
    -------
    pandas.DataFrame
        The same frame, with a datetime index and a "news" data column.

    Raises
    ------
    ValueError
        If the first column cannot be parsed as datetimes.
    """
    df.set_index(df.columns[0], inplace=True)

    try:
        df.index = pd.to_datetime(df.index)
    except (ValueError, TypeError) as exc:
        # to_datetime raises ValueError for unparsable strings but
        # TypeError for unsupported dtypes; normalise both, and chain the
        # original error for debugging.
        raise ValueError("The first column of the CSV file must be a datetime column.") from exc

    # After set_index, columns[0] is the first *data* column.
    df.rename(columns={df.columns[0]: "news"}, inplace=True)
    return df
67
+
68
def update_controls(method):
    """Toggle the LSTM-specific controls based on the selected method.

    Args:
        method (str): The selected anomaly detection method.

    Returns:
        list: Two ``gr.update`` payloads, interactive only when the
        method is "LSTM".
    """
    enable = method == "LSTM"
    return [gr.update(interactive=enable) for _ in range(2)]
src/visualization.py DELETED
@@ -1,20 +0,0 @@
1
- import pandas as pd
2
- import plotly.express as px
3
-
4
-
5
- def plot_time_series(file):
6
- """
7
- Plots a time series graph from a CSV file.
8
-
9
- This function reads the CSV file and generates a line plot
10
- showing the disease mentions over time.
11
-
12
- """
13
- df = pd.read_csv(file.name)
14
- fig = px.line(
15
- df,
16
- x=df.columns[0],
17
- y=df.columns[1],
18
- title='Disease Mentions Over Time'
19
- )
20
- return fig