Commit
·
9485251
1
Parent(s):
0d5a7ab
add methods for each strategy
Browse files- .gitignore +4 -1
- app.py +54 -8
- examples/mpox.csv +299 -0
- models/lstm_forec_40_11_06.pth +3 -0
- requirements.txt +4 -1
- sections/try_it_yourself.html +46 -1
- src/__init__.py +24 -0
- src/anomaly_detection.py +51 -0
- src/outbreak_detection/__init__.py +12 -0
- src/outbreak_detection/arima.py +111 -0
- src/outbreak_detection/iqr.py +64 -0
- src/outbreak_detection/lstm.py +191 -0
- src/outbreak_detection/lstm_model.py +83 -0
- src/plotting/__init__.py +3 -0
- src/plotting/visualization.py +44 -0
- src/utils.py +82 -0
- src/visualization.py +0 -20
.gitignore
CHANGED
@@ -16,4 +16,7 @@ build/
|
|
16 |
|
17 |
# VSCode
|
18 |
.vscode/
|
19 |
-
*.code-workspace
|
|
|
|
|
|
|
|
16 |
|
17 |
# VSCode
|
18 |
.vscode/
|
19 |
+
*.code-workspace
|
20 |
+
|
21 |
+
# Solution template
|
22 |
+
solution_template/
|
app.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1 |
import gradio as gr
|
2 |
-
from src
|
|
|
|
|
|
|
3 |
|
4 |
PROJECT_HTML_PATH = "sections/project_description.html"
|
5 |
WELCOME_HTML_PATH = "sections/welcome_section.html"
|
@@ -26,11 +29,48 @@ with gr.Blocks(css=blocks_css) as demo:
|
|
26 |
gr.Markdown(demo_section_html)
|
27 |
gr.HTML(try_it_yourself_html)
|
28 |
|
29 |
-
file_input = gr.File(
|
|
|
|
|
|
|
|
|
30 |
plot_btn = gr.Button("Plot Time Series")
|
31 |
plot_output = gr.Plot(label="Time Series Plot")
|
32 |
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
analyze_btn = gr.Button("Detect Anomalies")
|
35 |
anomaly_output = gr.Plot(label="Anomaly Detection Results")
|
36 |
|
@@ -40,10 +80,16 @@ with gr.Blocks(css=blocks_css) as demo:
|
|
40 |
outputs=[plot_output]
|
41 |
)
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
demo.launch(show_api=False)
|
|
|
1 |
import gradio as gr
|
2 |
+
from src import plot_time_series
|
3 |
+
from src import detect_anomalies
|
4 |
+
from src.utils import update_controls
|
5 |
+
|
6 |
|
7 |
PROJECT_HTML_PATH = "sections/project_description.html"
|
8 |
WELCOME_HTML_PATH = "sections/welcome_section.html"
|
|
|
29 |
gr.Markdown(demo_section_html)
|
30 |
gr.HTML(try_it_yourself_html)
|
31 |
|
32 |
+
file_input = gr.File(
|
33 |
+
label="Upload Time Series CSV",
|
34 |
+
file_types=[".csv"],
|
35 |
+
value="examples/mpox.csv" # Set default example file
|
36 |
+
)
|
37 |
plot_btn = gr.Button("Plot Time Series")
|
38 |
plot_output = gr.Plot(label="Time Series Plot")
|
39 |
|
40 |
+
with gr.Row():
|
41 |
+
method_input = gr.Dropdown(
|
42 |
+
choices=ANOMALY_METHODS,
|
43 |
+
interactive=True,
|
44 |
+
label="Select Anomaly Detection Method",
|
45 |
+
value="LSTM"
|
46 |
+
)
|
47 |
+
k_input = gr.Slider(
|
48 |
+
minimum=1,
|
49 |
+
maximum=3,
|
50 |
+
step=0.1,
|
51 |
+
label="k",
|
52 |
+
value=1.5
|
53 |
+
)
|
54 |
+
percentile_input = gr.Slider(
|
55 |
+
minimum=0,
|
56 |
+
maximum=100,
|
57 |
+
step=1,
|
58 |
+
label="Percentile",
|
59 |
+
value=95,
|
60 |
+
interactive=True
|
61 |
+
)
|
62 |
+
threshold_method_input = gr.Dropdown(
|
63 |
+
choices=[
|
64 |
+
"IQR on (ground truth - forecast)",
|
65 |
+
"IQR on |ground truth - forecast|",
|
66 |
+
"IQR on |ground truth - forecast|/forecast",
|
67 |
+
"Percentile threshold on absolute loss",
|
68 |
+
"Percentile threshold on raw loss"
|
69 |
+
],
|
70 |
+
label="Threshold Method",
|
71 |
+
value="IQR on (ground truth - forecast)",
|
72 |
+
interactive=True
|
73 |
+
)
|
74 |
analyze_btn = gr.Button("Detect Anomalies")
|
75 |
anomaly_output = gr.Plot(label="Anomaly Detection Results")
|
76 |
|
|
|
80 |
outputs=[plot_output]
|
81 |
)
|
82 |
|
83 |
+
method_input.change(
|
84 |
+
fn=update_controls,
|
85 |
+
inputs=[method_input],
|
86 |
+
outputs=[percentile_input, threshold_method_input]
|
87 |
+
)
|
88 |
+
|
89 |
+
analyze_btn.click(
|
90 |
+
fn=detect_anomalies,
|
91 |
+
inputs=[file_input, method_input, k_input, percentile_input, threshold_method_input],
|
92 |
+
outputs=[anomaly_output]
|
93 |
+
)
|
94 |
|
95 |
demo.launch(show_api=False)
|
examples/mpox.csv
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
date,news
|
2 |
+
2022-04-22,15.0
|
3 |
+
2022-04-23,6.0
|
4 |
+
2022-04-24,3.0
|
5 |
+
2022-04-25,8.0
|
6 |
+
2022-04-26,15.0
|
7 |
+
2022-04-27,19.0
|
8 |
+
2022-04-28,16.0
|
9 |
+
2022-04-29,12.0
|
10 |
+
2022-04-30,5.0
|
11 |
+
2022-05-01,5.0
|
12 |
+
2022-05-02,8.0
|
13 |
+
2022-05-03,6.0
|
14 |
+
2022-05-04,11.0
|
15 |
+
2022-05-05,4.0
|
16 |
+
2022-05-06,9.0
|
17 |
+
2022-05-07,8.0
|
18 |
+
2022-05-08,3.0
|
19 |
+
2022-05-09,15.0
|
20 |
+
2022-05-10,18.0
|
21 |
+
2022-05-11,8.0
|
22 |
+
2022-05-12,15.0
|
23 |
+
2022-05-13,11.0
|
24 |
+
2022-05-14,7.0
|
25 |
+
2022-05-15,10.0
|
26 |
+
2022-05-16,12.0
|
27 |
+
2022-05-17,25.0
|
28 |
+
2022-05-18,45.0
|
29 |
+
2022-05-19,163.0
|
30 |
+
2022-05-20,271.0
|
31 |
+
2022-05-21,146.0
|
32 |
+
2022-05-22,152.0
|
33 |
+
2022-05-23,308.0
|
34 |
+
2022-05-24,259.0
|
35 |
+
2022-05-25,174.0
|
36 |
+
2022-05-26,163.0
|
37 |
+
2022-05-27,116.0
|
38 |
+
2022-05-28,64.0
|
39 |
+
2022-05-29,29.0
|
40 |
+
2022-05-30,80.0
|
41 |
+
2022-05-31,75.0
|
42 |
+
2022-06-01,74.0
|
43 |
+
2022-06-02,79.0
|
44 |
+
2022-06-03,76.0
|
45 |
+
2022-06-04,57.0
|
46 |
+
2022-06-05,24.0
|
47 |
+
2022-06-06,62.0
|
48 |
+
2022-06-07,72.0
|
49 |
+
2022-06-08,74.0
|
50 |
+
2022-06-09,75.0
|
51 |
+
2022-06-10,77.0
|
52 |
+
2022-06-11,31.0
|
53 |
+
2022-06-12,18.0
|
54 |
+
2022-06-13,42.0
|
55 |
+
2022-06-14,52.0
|
56 |
+
2022-06-15,94.0
|
57 |
+
2022-06-16,66.0
|
58 |
+
2022-06-17,44.0
|
59 |
+
2022-06-18,15.0
|
60 |
+
2022-06-19,22.0
|
61 |
+
2022-06-20,24.0
|
62 |
+
2022-06-21,30.0
|
63 |
+
2022-06-22,39.0
|
64 |
+
2022-06-23,71.0
|
65 |
+
2022-06-24,50.0
|
66 |
+
2022-06-25,39.0
|
67 |
+
2022-06-26,51.0
|
68 |
+
2022-06-27,36.0
|
69 |
+
2022-06-28,45.0
|
70 |
+
2022-06-29,79.0
|
71 |
+
2022-06-30,57.0
|
72 |
+
2022-07-01,73.0
|
73 |
+
2022-07-02,42.0
|
74 |
+
2022-07-03,13.0
|
75 |
+
2022-07-04,23.0
|
76 |
+
2022-07-05,38.0
|
77 |
+
2022-07-06,47.0
|
78 |
+
2022-07-07,60.0
|
79 |
+
2022-07-08,63.0
|
80 |
+
2022-07-09,53.0
|
81 |
+
2022-07-10,15.0
|
82 |
+
2022-07-11,39.0
|
83 |
+
2022-07-12,62.0
|
84 |
+
2022-07-13,53.0
|
85 |
+
2022-07-14,41.0
|
86 |
+
2022-07-15,72.0
|
87 |
+
2022-07-16,35.0
|
88 |
+
2022-07-17,25.0
|
89 |
+
2022-07-18,56.0
|
90 |
+
2022-07-19,50.0
|
91 |
+
2022-07-20,45.0
|
92 |
+
2022-07-21,88.0
|
93 |
+
2022-07-22,73.0
|
94 |
+
2022-07-23,112.0
|
95 |
+
2022-07-24,128.0
|
96 |
+
2022-07-25,166.0
|
97 |
+
2022-07-26,169.0
|
98 |
+
2022-07-27,173.0
|
99 |
+
2022-07-28,161.0
|
100 |
+
2022-07-29,130.0
|
101 |
+
2022-07-30,136.0
|
102 |
+
2022-07-31,87.0
|
103 |
+
2022-08-01,136.0
|
104 |
+
2022-08-02,189.0
|
105 |
+
2022-08-03,138.0
|
106 |
+
2022-08-04,163.0
|
107 |
+
2022-08-05,167.0
|
108 |
+
2022-08-06,66.0
|
109 |
+
2022-08-07,37.0
|
110 |
+
2022-08-08,49.0
|
111 |
+
2022-08-09,91.0
|
112 |
+
2022-08-10,112.0
|
113 |
+
2022-08-11,92.0
|
114 |
+
2022-08-12,66.0
|
115 |
+
2022-08-13,41.0
|
116 |
+
2022-08-14,16.0
|
117 |
+
2022-08-15,55.0
|
118 |
+
2022-08-16,93.0
|
119 |
+
2022-08-17,90.0
|
120 |
+
2022-08-18,90.0
|
121 |
+
2022-08-19,74.0
|
122 |
+
2022-08-20,43.0
|
123 |
+
2022-08-21,34.0
|
124 |
+
2022-08-22,36.0
|
125 |
+
2022-08-23,75.0
|
126 |
+
2022-08-24,59.0
|
127 |
+
2022-08-25,74.0
|
128 |
+
2022-08-26,60.0
|
129 |
+
2022-08-27,32.0
|
130 |
+
2022-08-28,13.0
|
131 |
+
2022-08-29,19.0
|
132 |
+
2022-08-30,59.0
|
133 |
+
2022-08-31,36.0
|
134 |
+
2022-09-01,36.0
|
135 |
+
2022-09-02,46.0
|
136 |
+
2022-09-03,15.0
|
137 |
+
2022-09-04,7.0
|
138 |
+
2022-09-05,9.0
|
139 |
+
2022-09-06,30.0
|
140 |
+
2022-09-07,37.0
|
141 |
+
2022-09-08,47.0
|
142 |
+
2022-09-09,39.0
|
143 |
+
2022-09-10,10.0
|
144 |
+
2022-09-11,12.0
|
145 |
+
2022-09-12,21.0
|
146 |
+
2022-09-13,24.0
|
147 |
+
2022-09-14,32.0
|
148 |
+
2022-09-15,19.0
|
149 |
+
2022-09-16,27.0
|
150 |
+
2022-09-17,11.0
|
151 |
+
2022-09-18,8.0
|
152 |
+
2022-09-19,34.0
|
153 |
+
2022-09-20,24.0
|
154 |
+
2022-09-21,12.0
|
155 |
+
2022-09-22,18.0
|
156 |
+
2022-09-23,14.0
|
157 |
+
2022-09-24,1.0
|
158 |
+
2022-09-25,3.0
|
159 |
+
2022-09-26,5.0
|
160 |
+
2022-09-27,10.0
|
161 |
+
2022-09-28,13.0
|
162 |
+
2022-09-29,16.0
|
163 |
+
2022-09-30,12.0
|
164 |
+
2022-10-01,9.0
|
165 |
+
2022-10-02,3.0
|
166 |
+
2022-10-03,9.0
|
167 |
+
2022-10-04,9.0
|
168 |
+
2022-10-05,9.0
|
169 |
+
2022-10-06,7.0
|
170 |
+
2022-10-07,3.0
|
171 |
+
2022-10-08,4.0
|
172 |
+
2022-10-09,3.0
|
173 |
+
2022-10-10,3.0
|
174 |
+
2022-10-11,7.0
|
175 |
+
2022-10-12,15.0
|
176 |
+
2022-10-13,22.0
|
177 |
+
2022-10-14,13.0
|
178 |
+
2022-10-15,4.0
|
179 |
+
2022-10-16,2.0
|
180 |
+
2022-10-17,14.0
|
181 |
+
2022-10-18,12.0
|
182 |
+
2022-10-19,10.0
|
183 |
+
2022-10-20,11.0
|
184 |
+
2022-10-21,10.0
|
185 |
+
2022-10-22,7.0
|
186 |
+
2022-10-23,6.0
|
187 |
+
2022-10-24,7.0
|
188 |
+
2022-10-25,4.0
|
189 |
+
2022-10-26,8.0
|
190 |
+
2022-10-27,10.0
|
191 |
+
2022-10-28,15.0
|
192 |
+
2022-10-29,5.0
|
193 |
+
2022-10-30,4.0
|
194 |
+
2022-10-31,15.0
|
195 |
+
2022-11-01,7.0
|
196 |
+
2022-11-02,8.0
|
197 |
+
2022-11-03,6.0
|
198 |
+
2022-11-04,10.0
|
199 |
+
2022-11-05,3.0
|
200 |
+
2022-11-06,3.0
|
201 |
+
2022-11-07,9.0
|
202 |
+
2022-11-08,7.0
|
203 |
+
2022-11-09,3.0
|
204 |
+
2022-11-10,4.0
|
205 |
+
2022-11-11,2.0
|
206 |
+
2022-11-12,3.0
|
207 |
+
2022-11-13,3.0
|
208 |
+
2022-11-14,8.0
|
209 |
+
2022-11-15,7.0
|
210 |
+
2022-11-16,5.0
|
211 |
+
2022-11-17,6.0
|
212 |
+
2022-11-18,2.0
|
213 |
+
2022-11-19,2.0
|
214 |
+
2022-11-20,0.0
|
215 |
+
2022-11-21,0.0
|
216 |
+
2022-11-22,4.0
|
217 |
+
2022-11-23,7.0
|
218 |
+
2022-11-24,11.0
|
219 |
+
2022-11-25,1.0
|
220 |
+
2022-11-26,3.0
|
221 |
+
2022-11-27,0.0
|
222 |
+
2022-11-28,14.0
|
223 |
+
2022-11-29,16.0
|
224 |
+
2022-11-30,4.0
|
225 |
+
2022-12-01,11.0
|
226 |
+
2022-12-02,5.0
|
227 |
+
2022-12-03,2.0
|
228 |
+
2022-12-04,1.0
|
229 |
+
2022-12-05,6.0
|
230 |
+
2022-12-06,6.0
|
231 |
+
2022-12-07,8.0
|
232 |
+
2022-12-08,7.0
|
233 |
+
2022-12-09,7.0
|
234 |
+
2022-12-10,3.0
|
235 |
+
2022-12-11,3.0
|
236 |
+
2022-12-12,4.0
|
237 |
+
2022-12-13,5.0
|
238 |
+
2022-12-14,6.0
|
239 |
+
2022-12-15,2.0
|
240 |
+
2022-12-16,4.0
|
241 |
+
2022-12-17,3.0
|
242 |
+
2022-12-18,2.0
|
243 |
+
2022-12-19,2.0
|
244 |
+
2022-12-20,4.0
|
245 |
+
2022-12-21,12.0
|
246 |
+
2022-12-22,16.0
|
247 |
+
2022-12-23,9.0
|
248 |
+
2022-12-24,4.0
|
249 |
+
2022-12-25,3.0
|
250 |
+
2022-12-26,4.0
|
251 |
+
2022-12-27,4.0
|
252 |
+
2022-12-28,5.0
|
253 |
+
2022-12-29,14.0
|
254 |
+
2022-12-30,13.0
|
255 |
+
2022-12-31,4.0
|
256 |
+
2023-01-01,3.0
|
257 |
+
2023-01-02,2.0
|
258 |
+
2023-01-03,10.0
|
259 |
+
2023-01-04,10.0
|
260 |
+
2023-01-05,19.0
|
261 |
+
2023-01-06,7.0
|
262 |
+
2023-01-07,3.0
|
263 |
+
2023-01-08,10.0
|
264 |
+
2023-01-09,12.0
|
265 |
+
2023-01-10,17.0
|
266 |
+
2023-01-11,14.0
|
267 |
+
2023-01-12,15.0
|
268 |
+
2023-01-13,10.0
|
269 |
+
2023-01-14,15.0
|
270 |
+
2023-01-15,1.0
|
271 |
+
2023-01-16,9.0
|
272 |
+
2023-01-17,6.0
|
273 |
+
2023-01-18,4.0
|
274 |
+
2023-01-19,2.0
|
275 |
+
2023-01-20,4.0
|
276 |
+
2023-01-21,2.0
|
277 |
+
2023-01-22,3.0
|
278 |
+
2023-01-23,1.0
|
279 |
+
2023-01-24,4.0
|
280 |
+
2023-01-25,7.0
|
281 |
+
2023-01-26,5.0
|
282 |
+
2023-01-27,1.0
|
283 |
+
2023-01-28,0.0
|
284 |
+
2023-01-29,2.0
|
285 |
+
2023-01-30,3.0
|
286 |
+
2023-01-31,6.0
|
287 |
+
2023-02-01,8.0
|
288 |
+
2023-02-02,7.0
|
289 |
+
2023-02-03,4.0
|
290 |
+
2023-02-04,3.0
|
291 |
+
2023-02-05,1.0
|
292 |
+
2023-02-06,13.0
|
293 |
+
2023-02-07,4.0
|
294 |
+
2023-02-08,6.0
|
295 |
+
2023-02-09,3.0
|
296 |
+
2023-02-10,7.0
|
297 |
+
2023-02-11,3.0
|
298 |
+
2023-02-12,0.0
|
299 |
+
2023-02-13,3.0
|
models/lstm_forec_40_11_06.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df812ddfa22e1d13ed44b8645a69b5a5c6d00ccd3008975f59dca3b9e3637b5c
|
3 |
+
size 19807
|
requirements.txt
CHANGED
@@ -1,2 +1,5 @@
|
|
1 |
gradio==5.29.0
|
2 |
-
|
|
|
|
|
|
|
|
1 |
gradio==5.29.0
|
2 |
+
torch==2.7.0
|
3 |
+
scikit-learn==1.6.1
|
4 |
+
plotly==6.0.1
|
5 |
+
statsmodels==0.14.4
|
sections/try_it_yourself.html
CHANGED
@@ -4,7 +4,52 @@
|
|
4 |
<li>📈 Upload a CSV file with two columns: dates in the first column and disease mention counts in the second</li>
|
5 |
<li>🎯 Click "Plot Time Series" to visualize your data</li>
|
6 |
<li>🔍 Select an anomaly detection method from the dropdown</li>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
<li>⚡ Click "Detect Anomalies" to identify unusual patterns in your time series</li>
|
8 |
</ul>
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
</section>
|
|
|
4 |
<li>📈 Upload a CSV file with two columns: dates in the first column and disease mention counts in the second</li>
|
5 |
<li>🎯 Click "Plot Time Series" to visualize your data</li>
|
6 |
<li>🔍 Select an anomaly detection method from the dropdown</li>
|
7 |
+
<li>⚙️ Configure the detection parameters:
|
8 |
+
<ul>
|
9 |
+
<li><strong>For LSTM method:</strong>
|
10 |
+
<ul>
|
11 |
+
<li><em>k</em>: Controls sensitivity (1-3). Higher values mean stricter anomaly detection.</li>
|
12 |
+
<li><em>Percentile</em>: Threshold percentile for anomaly detection (0-100).</li>
|
13 |
+
<li><em>Threshold Method</em>: Choose how to calculate anomaly thresholds:
|
14 |
+
<ul>
|
15 |
+
<li>IQR-based methods: Compare predictions with actual values using different metrics</li>
|
16 |
+
<li>Percentile-based methods: Use statistical thresholds on prediction errors</li>
|
17 |
+
</ul>
|
18 |
+
</li>
|
19 |
+
</ul>
|
20 |
+
</li>
|
21 |
+
<li><strong>For ARIMA method:</strong>
|
22 |
+
<ul>
|
23 |
+
<li><em>k</em>: Sensitivity multiplier for standard deviation-based thresholds (1-3).</li>
|
24 |
+
</ul>
|
25 |
+
</li>
|
26 |
+
<li><strong>For IQR method:</strong>
|
27 |
+
<ul>
|
28 |
+
<li><em>k</em>: IQR multiplier (1-3). Higher values detect more extreme outliers.</li>
|
29 |
+
</ul>
|
30 |
+
</li>
|
31 |
+
</ul>
|
32 |
+
</li>
|
33 |
<li>⚡ Click "Detect Anomalies" to identify unusual patterns in your time series</li>
|
34 |
</ul>
|
35 |
+
|
36 |
+
<div class="example-section">
|
37 |
+
<h3>📋 Example Dataset</h3>
|
38 |
+
<p>Try out the tool with our sample dataset:</p>
|
39 |
+
<ul>
|
40 |
+
<li><strong>Dataset:</strong> <code>mpox.csv</code> - News coverage time series for Monkeypox/Mpox outbreak</li>
|
41 |
+
<li><strong>Time Period:</strong> Daily counts from early 2022</li>
|
42 |
+
<li><strong>Recommended Settings:</strong>
|
43 |
+
<ul>
|
44 |
+
<li>Method: LSTM</li>
|
45 |
+
<li>k: 1.5</li>
|
46 |
+
<li>Percentile: 95</li>
|
47 |
+
<li>Threshold Method: "IQR on |ground truth - forecast|"</li>
|
48 |
+
</ul>
|
49 |
+
</li>
|
50 |
+
<li><strong>Expected Results:</strong> The analysis should identify significant spikes in news coverage that corresponded to major outbreak events and public health announcements during the 2022 Mpox outbreak.</li>
|
51 |
+
</ul>
|
52 |
+
</div>
|
53 |
+
|
54 |
+
<p>This tool combines time series analysis and anomaly detection to help identify potential disease outbreaks based on news coverage patterns. The results can be used to alert public health officials about emerging health concerns. 💡</p>
|
55 |
</section>
|
src/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .plotting.visualization import plot_time_series, plot_anomalies
|
2 |
+
from .anomaly_detection import detect_anomalies
|
3 |
+
from .outbreak_detection import (
|
4 |
+
LSTMforOutbreakDetection,
|
5 |
+
ARIMAforOutbreakDetection,
|
6 |
+
IQRforOutbreakDetection
|
7 |
+
)
|
8 |
+
from .utils import (
|
9 |
+
timestamp_wise_evaluation,
|
10 |
+
tolerance_based_evaluation
|
11 |
+
)
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
'plot_time_series',
|
15 |
+
'plot_anomalies',
|
16 |
+
'detect_anomalies',
|
17 |
+
'LSTMforOutbreakDetection',
|
18 |
+
'ARIMAforOutbreakDetection',
|
19 |
+
'IQRforOutbreakDetection',
|
20 |
+
'timestamp_wise_evaluation',
|
21 |
+
'tolerance_based_evaluation'
|
22 |
+
'prepare_time_series_dataframe',
|
23 |
+
'tolerance_based_evaluation',
|
24 |
+
]
|
src/anomaly_detection.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from .outbreak_detection import (
|
3 |
+
LSTMforOutbreakDetection,
|
4 |
+
ARIMAforOutbreakDetection,
|
5 |
+
IQRforOutbreakDetection
|
6 |
+
)
|
7 |
+
from .plotting.visualization import plot_anomalies
|
8 |
+
from .utils import prepare_time_series_dataframe
|
9 |
+
|
10 |
+
|
11 |
+
THRESHOLD_METHODS = {
|
12 |
+
"IQR on (ground truth - forecast)": 0,
|
13 |
+
"IQR on |ground truth - forecast|": 1,
|
14 |
+
"IQR on |ground truth - forecast|/forecast": 2,
|
15 |
+
"Percentile threshold on absolute loss": 3,
|
16 |
+
"Percentile threshold on raw loss": 4
|
17 |
+
}
|
18 |
+
|
19 |
+
def detect_anomalies(file_path: str, method: str, k: int, percentile: float, threshold_method: int):
|
20 |
+
"""
|
21 |
+
Detects anomalies in time series data using various detection methods.
|
22 |
+
Args:
|
23 |
+
file_path (str): Path to the CSV file containing time series data
|
24 |
+
method (str): Detection method to use ('LSTM', 'ARIMA', or 'IQR')
|
25 |
+
k (int): Number of neighbors or window size (method-dependent parameter)
|
26 |
+
percentile (float): Percentile threshold for anomaly detection
|
27 |
+
threshold_method (int): Method to determine threshold for anomaly detection
|
28 |
+
Returns:
|
29 |
+
plotly.graph_objects.Figure: Plotly figure containing the time series with highlighted anomalies
|
30 |
+
"""
|
31 |
+
df = pd.read_csv(file_path)
|
32 |
+
df = prepare_time_series_dataframe(df)
|
33 |
+
|
34 |
+
# Map threshold methods to their descriptions for better readability
|
35 |
+
|
36 |
+
detectors = {
|
37 |
+
'LSTM': LSTMforOutbreakDetection(
|
38 |
+
checkpoint_path='models/lstm_forec_40_11_06.pth',
|
39 |
+
k=k,
|
40 |
+
percentile=percentile,
|
41 |
+
threshold_method=THRESHOLD_METHODS[threshold_method]
|
42 |
+
),
|
43 |
+
'ARIMA': ARIMAforOutbreakDetection(k=k),
|
44 |
+
'IQR': IQRforOutbreakDetection(k=k)
|
45 |
+
}
|
46 |
+
|
47 |
+
detector = detectors[method]
|
48 |
+
test, new_label = detector.detect_anomalies(df)
|
49 |
+
return plot_anomalies(test, anomaly_col=new_label)
|
50 |
+
|
51 |
+
|
src/outbreak_detection/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .lstm import LSTMforOutbreakDetection
|
2 |
+
from .arima import ARIMAforOutbreakDetection
|
3 |
+
from .iqr import IQRforOutbreakDetection
|
4 |
+
from .lstm_model import LstmModel, testing
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'LSTMforOutbreakDetection',
|
8 |
+
'ARIMAforOutbreakDetection',
|
9 |
+
'IQRforOutbreakDetection',
|
10 |
+
'LstmModel',
|
11 |
+
'testing'
|
12 |
+
]
|
src/outbreak_detection/arima.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from statsmodels.tsa.stattools import adfuller, acf, pacf
|
4 |
+
from statsmodels.tsa.arima.model import ARIMA
|
5 |
+
|
6 |
+
|
7 |
+
NEW_ANOMALY_COLUMN_NAME = 'anomaly'
|
8 |
+
|
9 |
+
class ARIMAforOutbreakDetection:
|
10 |
+
def __init__(self, window_size=7, stride=1, k=1.5, significance=0.05, max_lag=30):
|
11 |
+
self.window_size = window_size
|
12 |
+
self.stride = stride
|
13 |
+
self.k = k
|
14 |
+
self.significance = significance
|
15 |
+
self.max_lag = max_lag
|
16 |
+
|
17 |
+
def test_stationarity(self, ts_data, column=''):
|
18 |
+
if isinstance(ts_data, pd.Series):
|
19 |
+
adf_test = adfuller(ts_data, autolag='AIC')
|
20 |
+
else:
|
21 |
+
adf_test = adfuller(ts_data[column], autolag='AIC')
|
22 |
+
return "Stationary" if adf_test[1] <= self.significance else "Non-Stationary"
|
23 |
+
|
24 |
+
def make_stationary(self, dataframe, column):
|
25 |
+
df_to_return = None
|
26 |
+
result = self.test_stationarity(dataframe, column)
|
27 |
+
|
28 |
+
if result == "Stationary":
|
29 |
+
return dataframe
|
30 |
+
|
31 |
+
diff_series = dataframe.copy()
|
32 |
+
for diff_count in range(5):
|
33 |
+
diff_series = diff_series.diff().fillna(0)
|
34 |
+
if self.test_stationarity(diff_series, column) == "Stationary":
|
35 |
+
return diff_series
|
36 |
+
return diff_series
|
37 |
+
|
38 |
+
def create_windows(self, df):
|
39 |
+
windows, gts = [], []
|
40 |
+
for i in range(0, len(df) - self.window_size, self.stride):
|
41 |
+
end_id = i + self.window_size
|
42 |
+
windows.append(df.iloc[i:end_id, :])
|
43 |
+
gts.append(df.iloc[end_id, :])
|
44 |
+
return np.stack(windows), np.stack(gts)
|
45 |
+
|
46 |
+
def find_p_q(self, series):
|
47 |
+
N = len(series)
|
48 |
+
acf_values, _ = acf(series, nlags=self.max_lag, alpha=self.significance, fft=False)
|
49 |
+
pacf_values, _ = pacf(series, nlags=self.max_lag, alpha=self.significance)
|
50 |
+
threshold = 1.96 / np.sqrt(N)
|
51 |
+
|
52 |
+
def find_last_consecutive_outlier(values):
|
53 |
+
for i in range(1, len(values)):
|
54 |
+
if values[i] < 0 or (values[i] > 0 and abs(values[i]) < threshold):
|
55 |
+
return i
|
56 |
+
return len(values) - 1
|
57 |
+
|
58 |
+
return find_last_consecutive_outlier(pacf_values), find_last_consecutive_outlier(acf_values)
|
59 |
+
|
60 |
+
def detect_anomalies(self, dataset, news_or_cases='news'):
|
61 |
+
stationary_data = self.make_stationary(dataset, news_or_cases)
|
62 |
+
p, q = self.find_p_q(stationary_data[news_or_cases])
|
63 |
+
anomalies, means, stdevs, residuals, predictions, gts = self._train_arima_model(stationary_data, p, q)
|
64 |
+
result_df = self._prepare_resulting_dataframe(
|
65 |
+
residuals, means, stdevs, dataset.iloc[self.window_size:],
|
66 |
+
anomalies, gts, predictions
|
67 |
+
)
|
68 |
+
return self._postprocess_anomalies(result_df, news_or_cases), NEW_ANOMALY_COLUMN_NAME
|
69 |
+
|
70 |
+
def _train_arima_model(self, dataset, p, q):
|
71 |
+
predictions, residuals, means, stdevs, anomalies = [], [], [], [], []
|
72 |
+
windows, gts = self.create_windows(dataset)
|
73 |
+
|
74 |
+
for window, gt in zip(windows, gts):
|
75 |
+
model = ARIMA(window, order=(p, 0, q))
|
76 |
+
model.initialize_approximate_diffuse()
|
77 |
+
fit = model.fit()
|
78 |
+
|
79 |
+
pred = fit.forecast(steps=1)[0]
|
80 |
+
residual = np.abs(gt - pred)
|
81 |
+
mu, std = np.mean(fit.resid), np.std(fit.resid)
|
82 |
+
|
83 |
+
anomalies.append(
|
84 |
+
1 if residual > mu + self.k * std or residual < mu - self.k * std else 0
|
85 |
+
)
|
86 |
+
|
87 |
+
means.append(mu)
|
88 |
+
stdevs.append(std)
|
89 |
+
residuals.append(residual)
|
90 |
+
predictions.append(pred)
|
91 |
+
|
92 |
+
return anomalies, means, stdevs, residuals, predictions, gts
|
93 |
+
|
94 |
+
def _prepare_resulting_dataframe(self, residuals, means, stdevs, original_dataset,
|
95 |
+
anomalies, gts, predictions):
|
96 |
+
result_df = original_dataset.copy()
|
97 |
+
result_df['residuals'] = residuals
|
98 |
+
result_df['mu'] = means
|
99 |
+
result_df['sigma'] = stdevs
|
100 |
+
result_df['anomaly'] = anomalies
|
101 |
+
result_df['gts_diff'] = gts
|
102 |
+
result_df['pred_diff'] = predictions
|
103 |
+
return result_df
|
104 |
+
|
105 |
+
def _postprocess_anomalies(self, dataframe, col_name='news'):
|
106 |
+
dataframe['derivative'] = dataframe[col_name].diff().fillna(0)
|
107 |
+
dataframe['new_anomaly'] = [
|
108 |
+
0 if row.derivative < 0 and row.anomaly == 1 else row.anomaly
|
109 |
+
for _, row in dataframe.iterrows()
|
110 |
+
]
|
111 |
+
return dataframe
|
src/outbreak_detection/iqr.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
NEW_LABEL_COLUMN_NAME = 'new_label'
|
6 |
+
|
7 |
+
class IQRforOutbreakDetection:
|
8 |
+
def __init__(self, window_size=7, stride=1, k=1.5):
|
9 |
+
self.window_size = window_size
|
10 |
+
self.stride = stride
|
11 |
+
self.k = k
|
12 |
+
|
13 |
+
def _iqr_rolling(self, timeseries):
|
14 |
+
q1 = np.percentile(timeseries, 25)
|
15 |
+
q3 = np.percentile(timeseries, 75)
|
16 |
+
iqr = q3 - q1
|
17 |
+
ub = q3 + self.k * iqr
|
18 |
+
lb = q1 - self.k * iqr
|
19 |
+
return ub, lb
|
20 |
+
|
21 |
+
def detect_anomalies(self, df, news_or_cases='news'):
|
22 |
+
""""
|
23 |
+
input methods: k
|
24 |
+
"""
|
25 |
+
if isinstance(df, pd.Series):
|
26 |
+
timeseries = df
|
27 |
+
else:
|
28 |
+
timeseries = df[news_or_cases]
|
29 |
+
|
30 |
+
tot_peaks, final_peaks, _ = self._windowed_iqr(timeseries)
|
31 |
+
result_df = self._prepare_resulting_dataframe(final_peaks, timeseries)
|
32 |
+
processed_df = self._postprocess_anomalies(result_df, news_or_cases)
|
33 |
+
print(processed_df)
|
34 |
+
|
35 |
+
return processed_df, NEW_LABEL_COLUMN_NAME
|
36 |
+
|
37 |
+
def _windowed_iqr(self, df):
|
38 |
+
tot_peaks = {}
|
39 |
+
for i in range(0, len(df) - self.window_size + 1, self.stride):
|
40 |
+
end_id = i + self.window_size
|
41 |
+
window = df[i:end_id]
|
42 |
+
ub, _ = self._iqr_rolling(window)
|
43 |
+
|
44 |
+
for j in window.index:
|
45 |
+
peaks_list = tot_peaks.setdefault(f'{j}', [])
|
46 |
+
peaks_list.append(window.loc[j] > ub)
|
47 |
+
|
48 |
+
final_peaks = {k: True if True in v else False
|
49 |
+
for k, v in tot_peaks.items()}
|
50 |
+
|
51 |
+
return tot_peaks, final_peaks, end_id
|
52 |
+
|
53 |
+
def _prepare_resulting_dataframe(self, peaks_df, news_or_cases_df):
|
54 |
+
final_df_iqr = pd.DataFrame.from_dict(peaks_df, orient='index')
|
55 |
+
dff = pd.DataFrame(news_or_cases_df)
|
56 |
+
dff['peaks'] = final_df_iqr.loc[:, 0].values
|
57 |
+
dff['peaks'] = dff['peaks'].map({True: 1, False: 0})
|
58 |
+
return dff
|
59 |
+
|
60 |
+
def _postprocess_anomalies(self, dataframe, col_name='news'):
|
61 |
+
dataframe['derivative'] = dataframe[col_name].diff().fillna(0)
|
62 |
+
dataframe['new_label'] = [0 if v.derivative < 0 and v.peaks == 1 else v.peaks
|
63 |
+
for _, v in dataframe.iterrows()]
|
64 |
+
return dataframe
|
src/outbreak_detection/lstm.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import torch.utils.data as data_utils
|
5 |
+
from sklearn.preprocessing import MinMaxScaler
|
6 |
+
from .lstm_model import LstmModel, testing
|
7 |
+
|
8 |
+
|
9 |
+
PRETRAINED_MODEL_N_CHANNELS = 1
|
10 |
+
PRETRAINED_MODEL_Z_SIZE = 32
|
11 |
+
|
12 |
+
class LSTMforOutbreakDetection:
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
checkpoint_path=None,
|
16 |
+
n_channels=PRETRAINED_MODEL_N_CHANNELS,
|
17 |
+
z_size=PRETRAINED_MODEL_Z_SIZE,
|
18 |
+
device='cpu',
|
19 |
+
window=7,
|
20 |
+
batch_size=32,
|
21 |
+
k=1.5,
|
22 |
+
percentile=95,
|
23 |
+
threshold_method=0
|
24 |
+
):
|
25 |
+
self.device = torch.device(device)
|
26 |
+
self.window = window
|
27 |
+
self.batch_size = batch_size
|
28 |
+
self.n_channels = n_channels
|
29 |
+
self.z_size = z_size
|
30 |
+
self.scaler = MinMaxScaler(feature_range=(0,1))
|
31 |
+
self.k = k
|
32 |
+
self.percentile = percentile
|
33 |
+
self.threshold_method = threshold_method
|
34 |
+
if checkpoint_path:
|
35 |
+
self.model = self._load_model(checkpoint_path)
|
36 |
+
|
37 |
+
def _load_model(self, checkpoint_path):
|
38 |
+
model = LstmModel(self.n_channels, self.z_size)
|
39 |
+
model = model.to(self.device)
|
40 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
|
41 |
+
return model
|
42 |
+
|
43 |
+
def create_test_sequences(self, dataframe, time_steps, news_or_cases='news'):
|
44 |
+
if news_or_cases not in ['news', 'cases']:
|
45 |
+
raise ValueError("news_or_cases should be either 'news' or 'cases'")
|
46 |
+
|
47 |
+
output, output2 = [], []
|
48 |
+
dataframe[[news_or_cases]] = self.scaler.fit_transform(dataframe[[news_or_cases]])
|
49 |
+
norm = np.array(dataframe[[news_or_cases]]).astype(float)
|
50 |
+
|
51 |
+
for i in range(len(norm)):
|
52 |
+
end_ix = i + time_steps
|
53 |
+
|
54 |
+
if end_ix > len(norm)-1:
|
55 |
+
break
|
56 |
+
|
57 |
+
seq_x, seq_y = norm[i:end_ix, :], norm[end_ix, 0]
|
58 |
+
output.append(seq_x)
|
59 |
+
output2.append(seq_y)
|
60 |
+
|
61 |
+
return np.stack(output), np.stack(output2)
|
62 |
+
|
63 |
+
def prepare_input_dataframe(self, dataframe, news_column_name='news'):
|
64 |
+
X_test, y_test = self.create_test_sequences(dataframe, self.window, news_column_name)
|
65 |
+
test_loader = torch.utils.data.DataLoader(
|
66 |
+
data_utils.TensorDataset(
|
67 |
+
torch.from_numpy(X_test).float(),
|
68 |
+
torch.from_numpy(y_test).float()
|
69 |
+
),
|
70 |
+
batch_size=self.batch_size,
|
71 |
+
shuffle=False,
|
72 |
+
num_workers=0
|
73 |
+
)
|
74 |
+
return test_loader, y_test
|
75 |
+
|
76 |
+
def predict(self, dataframe, news_column_name='news'):
|
77 |
+
test_loader, y_test = self.prepare_input_dataframe(dataframe, news_column_name)
|
78 |
+
results, w = testing(self.model, test_loader, self.device)
|
79 |
+
forecast_test = np.concatenate([
|
80 |
+
torch.stack(w[:-1]).flatten().detach().cpu().numpy(),
|
81 |
+
w[-1].flatten().detach().cpu().numpy()
|
82 |
+
])
|
83 |
+
|
84 |
+
test_df = dataframe[self.window:].copy()
|
85 |
+
test_df['y_test'] = y_test
|
86 |
+
test_df['pred_forec'] = forecast_test
|
87 |
+
test_df['abs_loss'] = np.abs(test_df.y_test - test_df.pred_forec)
|
88 |
+
test_df['rel_loss'] = np.abs((test_df['pred_forec'] - test_df['y_test']) / (1 + test_df['pred_forec']))
|
89 |
+
test_df['diff'] = test_df['y_test'] - test_df['pred_forec']
|
90 |
+
|
91 |
+
return test_df
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
def _iqr_rolling(timeseries, k):
|
95 |
+
q1, q3 = np.percentile(timeseries, [25, 75])
|
96 |
+
iqr = q3 - q1
|
97 |
+
|
98 |
+
return q3 + k * iqr
|
99 |
+
|
100 |
+
def windowed_iqr(self, df, k, type_of_loss='diff'):
|
101 |
+
peaks = {}
|
102 |
+
|
103 |
+
for i in range(len(df)):
|
104 |
+
end_ix = i + self.window
|
105 |
+
|
106 |
+
if end_ix > len(df)-1:
|
107 |
+
break
|
108 |
+
|
109 |
+
seq_x = df.iloc[i:end_ix, :]
|
110 |
+
ub = self._iqr_rolling(seq_x[type_of_loss], k)
|
111 |
+
|
112 |
+
for j in seq_x.index:
|
113 |
+
condition = int(seq_x.loc[j, type_of_loss] > ub)
|
114 |
+
peaks.setdefault(f'{j}', []).append(condition)
|
115 |
+
|
116 |
+
return {k: 1 if sum(v) > 0 else 0 for k, v in peaks.items()}
|
117 |
+
|
118 |
+
def get_perc_threshold(self, test_df, percentile, col='abs_loss'):
|
119 |
+
if col not in ['abs_loss', 'loss']:
|
120 |
+
raise ValueError("col should be either 'abs_loss' or 'loss'")
|
121 |
+
|
122 |
+
test1 = test_df[:-1].copy()
|
123 |
+
anom_perc_loss = {}
|
124 |
+
|
125 |
+
for i in range(len(test_df)):
|
126 |
+
end_ix = i + self.window
|
127 |
+
if end_ix > len(test_df)-1:
|
128 |
+
break
|
129 |
+
|
130 |
+
seq_x = test_df.iloc[i:end_ix, :].copy()
|
131 |
+
mae = seq_x['abs_loss'].values if col == 'abs_loss' else seq_x['y_test'] - seq_x['pred_forec']
|
132 |
+
|
133 |
+
threshold = np.percentile(mae, percentile)
|
134 |
+
seq_x['threshold'] = threshold
|
135 |
+
|
136 |
+
for j in seq_x.index:
|
137 |
+
condition = int(seq_x.loc[j, col] > seq_x.loc[j, 'threshold'])
|
138 |
+
anom_perc_loss.setdefault(f'{j}', []).append(condition)
|
139 |
+
|
140 |
+
final_anom = {k: 1 if sum(v) > 0 else 0 for k, v in anom_perc_loss.items()}
|
141 |
+
new_col = 'anom_perc_abs_loss' if col == 'abs_loss' else 'anom_perc_diff_gt_pred'
|
142 |
+
test1[new_col] = pd.Series(final_anom)
|
143 |
+
|
144 |
+
return test1
|
145 |
+
|
146 |
+
def postprocess_anomalies(self, test_df, new_col, old_col, news_or_cases):
|
147 |
+
test_df = test_df.copy()
|
148 |
+
test_df['derivative'] = test_df[news_or_cases].diff().fillna(0)
|
149 |
+
test_df[new_col] = [0 if v.derivative < 0 and v[old_col] == 1 else v[old_col]
|
150 |
+
for k, v in test_df.iterrows()]
|
151 |
+
|
152 |
+
return test_df
|
153 |
+
|
154 |
+
def detect_anomalies(self, test_df, news_or_cases='news'):
|
155 |
+
"""
|
156 |
+
Detect anomalies using different methods:
|
157 |
+
0: IQR on (ground truth - forecast)
|
158 |
+
1: IQR on |ground truth - forecast|
|
159 |
+
2: IQR on |ground truth - forecast|/forecast
|
160 |
+
3: Percentile threshold on absolute loss
|
161 |
+
4: Percentile threshold on raw loss
|
162 |
+
|
163 |
+
input parameters: k (1-3), threshold_method, percentile
|
164 |
+
"""
|
165 |
+
test_df = test_df.copy()
|
166 |
+
|
167 |
+
test = self.predict(test_df, news_column_name=news_or_cases)
|
168 |
+
|
169 |
+
if self.threshold_method in [0, 1, 2]:
|
170 |
+
loss_type = {0: 'diff', 1: 'abs_loss', 2: 'rel_loss'}[self.threshold_method]
|
171 |
+
iqr_suffix = {0: 'f_iqr', 1: 'abs_iqr', 2: 'rel_iqr'}[self.threshold_method]
|
172 |
+
new_label = {0: 'f_new_label', 1: 'abs_new_label', 2: 'rel_new_label'}[self.threshold_method]
|
173 |
+
|
174 |
+
peaks = self.windowed_iqr(test, self.k, loss_type)
|
175 |
+
peak_series = pd.Series(peaks)
|
176 |
+
peak_series.index = pd.to_datetime(peak_series.index)
|
177 |
+
test[iqr_suffix] = peak_series
|
178 |
+
test = self.postprocess_anomalies(test, new_label, iqr_suffix, news_or_cases)
|
179 |
+
return test, new_label
|
180 |
+
|
181 |
+
elif self.threshold_method in [3, 4]:
|
182 |
+
loss_type = 'abs_loss' if self.threshold_method == 3 else 'loss'
|
183 |
+
new_label = 'new_anom_absl' if self.threshold_method == 3 else 'new_anom_diff'
|
184 |
+
old_label = 'anom_perc_abs_loss' if self.threshold_method == 3 else 'anom_perc_diff_gt_pred'
|
185 |
+
|
186 |
+
test = self.get_perc_threshold(test, self.percentile, loss_type)
|
187 |
+
test = self.postprocess_anomalies(test, new_label, old_label, news_or_cases)
|
188 |
+
return test, new_label
|
189 |
+
|
190 |
+
raise ValueError("threshold_method must be between 0 and 4")
|
191 |
+
|
src/outbreak_detection/lstm_model.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class LstmModel(nn.Module):
|
6 |
+
def __init__(self, in_size: int, latent_size: int) -> None:
|
7 |
+
"""
|
8 |
+
Initialize the LSTM autoencoder model.
|
9 |
+
|
10 |
+
Parameters:
|
11 |
+
-----------
|
12 |
+
in_size : int
|
13 |
+
Number of features in the input (input dimension).
|
14 |
+
latent_size : int
|
15 |
+
Size of the latent space representation in the LSTM.
|
16 |
+
|
17 |
+
Example:
|
18 |
+
--------
|
19 |
+
For in_size = 5, latent_size = 50, the model will:
|
20 |
+
- Take inputs with 5 features
|
21 |
+
- Encode them into a 50-dimensional latent space
|
22 |
+
- Decode back to the original 5 feature dimensions
|
23 |
+
|
24 |
+
Architecture:
|
25 |
+
- LSTM layer for encoding with dropout 0.2
|
26 |
+
- Additional dropout layer (0.2)
|
27 |
+
- ReLU activation
|
28 |
+
- Fully connected layer for decoding
|
29 |
+
"""
|
30 |
+
super().__init__() # Corrected the position of super().__init__()
|
31 |
+
self.lstm = nn.LSTM(
|
32 |
+
input_size=in_size,
|
33 |
+
hidden_size=latent_size,
|
34 |
+
num_layers=1,
|
35 |
+
batch_first=True,
|
36 |
+
dropout=0.2
|
37 |
+
) # input and output tensors are provided as (batch, seq_len, feature(size))
|
38 |
+
self.dropout = nn.Dropout(0.2)
|
39 |
+
self.relu = nn.ReLU()
|
40 |
+
self.fc = nn.Linear(latent_size, in_size)
|
41 |
+
|
42 |
+
def forward(self, w: torch.Tensor) -> torch.Tensor:
|
43 |
+
"""
|
44 |
+
Forward pass through the LSTM model.
|
45 |
+
|
46 |
+
Parameters:
|
47 |
+
-----------
|
48 |
+
w : torch.Tensor
|
49 |
+
Input tensor of shape (batch_size, seq_len, in_size).
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
--------
|
53 |
+
torch.Tensor
|
54 |
+
Output tensor of shape (batch_size, in_size).
|
55 |
+
|
56 |
+
Example:
|
57 |
+
--------
|
58 |
+
If the input tensor w has shape (32, 10, 5), the output will have shape (32, 5).
|
59 |
+
The LSTM processes the input sequence and returns the last output of the sequence.
|
60 |
+
The output is then passed through a ReLU activation function and a dropout layer.
|
61 |
+
Finally, it is passed through a fully connected layer to produce the final output.
|
62 |
+
The output is the reconstructed input sequence.
|
63 |
+
"""
|
64 |
+
z, (h_n, c_n) = self.lstm(w)
|
65 |
+
forecast = z[:, -1, :]
|
66 |
+
forecast = self.relu(forecast)
|
67 |
+
forecast = self.dropout(forecast)
|
68 |
+
output = self.fc(forecast)
|
69 |
+
|
70 |
+
return output
|
71 |
+
|
72 |
+
def testing(model, test_loader, device):
|
73 |
+
results=[]
|
74 |
+
forecast = []
|
75 |
+
with torch.no_grad():
|
76 |
+
for X_batch, y_batch in test_loader:
|
77 |
+
X_batch = X_batch.to(device)
|
78 |
+
y_batch = y_batch.to(device)
|
79 |
+
w=model(X_batch)
|
80 |
+
results.append(torch.mean((y_batch.unsqueeze(1)-w)**2, axis=1))
|
81 |
+
forecast.append(w)
|
82 |
+
|
83 |
+
return results, forecast
|
src/plotting/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .visualization import plot_time_series, plot_anomalies
|
2 |
+
|
3 |
+
__all__ = ['plot_time_series', 'plot_anomalies']
|
src/plotting/visualization.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import plotly.graph_objects as go
|
3 |
+
|
4 |
+
|
5 |
+
NEWS_COUNT_COLUMN = 0
|
6 |
+
|
7 |
+
def plot_time_series(file_path):
|
8 |
+
df = pd.read_csv(file_path)
|
9 |
+
fig = go.Figure()
|
10 |
+
fig.add_trace(go.Scatter(x=df.iloc[:, 0], y=df.iloc[:, 1], mode='lines', name='Time Series'))
|
11 |
+
fig.update_layout(title='Disease Mention Time Series', xaxis_title='Date', yaxis_title='Count')
|
12 |
+
|
13 |
+
return fig
|
14 |
+
|
15 |
+
def plot_anomalies(df, anomaly_col='new_label'):
|
16 |
+
print(df)
|
17 |
+
fig = go.Figure()
|
18 |
+
|
19 |
+
fig.add_trace(go.Scatter(
|
20 |
+
x=df.index,
|
21 |
+
y=df.iloc[:, NEWS_COUNT_COLUMN],
|
22 |
+
mode='lines',
|
23 |
+
name='Time Series',
|
24 |
+
line=dict(color='blue')
|
25 |
+
))
|
26 |
+
|
27 |
+
anomalies = df[df[anomaly_col] == 1]
|
28 |
+
|
29 |
+
fig.add_trace(go.Scatter(
|
30 |
+
x=anomalies.index,
|
31 |
+
y=anomalies.iloc[:, NEWS_COUNT_COLUMN],
|
32 |
+
mode='markers',
|
33 |
+
name='Anomalies',
|
34 |
+
marker=dict(color='red', size=10, symbol='circle')
|
35 |
+
))
|
36 |
+
|
37 |
+
fig.update_layout(
|
38 |
+
title='Disease Mention Time Series with Detected Anomalies',
|
39 |
+
xaxis_title='Date',
|
40 |
+
yaxis_title='Count',
|
41 |
+
showlegend=True
|
42 |
+
)
|
43 |
+
|
44 |
+
return fig
|
src/utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
from sklearn.metrics import classification_report
|
4 |
+
|
5 |
+
def timestamp_wise_evaluation(anomalies_cases, anomalies_news, threshold):
|
6 |
+
print(f"Classification report for threshold {threshold} (timestamp-wise evaluation):")
|
7 |
+
print(classification_report(anomalies_cases, anomalies_news))
|
8 |
+
|
9 |
+
def tolerance_based_evaluation(anomalies_cases, anomalies_news, cases_df, news_df, threshold):
|
10 |
+
Tp = 0
|
11 |
+
Fp = 0
|
12 |
+
Fn = 0
|
13 |
+
Tn = 0
|
14 |
+
for i in range(len(news_df)):
|
15 |
+
news_an = news_df.iloc[i][anomalies_news]
|
16 |
+
if news_an == 1:
|
17 |
+
if i == len(news_df) - 1:
|
18 |
+
if cases_df.iloc[i][anomalies_cases] == 1:
|
19 |
+
Tp += 1
|
20 |
+
else:
|
21 |
+
Fp += 1
|
22 |
+
elif i == len(news_df) - 2:
|
23 |
+
if cases_df.iloc[i][anomalies_cases] == 1 or cases_df.iloc[i+1][anomalies_cases] == 1:
|
24 |
+
Tp += 1
|
25 |
+
else:
|
26 |
+
Fp += 1
|
27 |
+
else:
|
28 |
+
if cases_df.iloc[i][anomalies_cases] == 1 or cases_df.iloc[i+1][anomalies_cases] == 1 or cases_df.iloc[i+2][anomalies_cases] == 1:
|
29 |
+
Tp += 1
|
30 |
+
else:
|
31 |
+
Fp += 1
|
32 |
+
else:
|
33 |
+
if i == len(news_df) - 1:
|
34 |
+
if cases_df.iloc[i][anomalies_cases] == 1:
|
35 |
+
Fn += 1
|
36 |
+
else:
|
37 |
+
Tn += 1
|
38 |
+
elif i == len(news_df) - 2:
|
39 |
+
if cases_df.iloc[i][anomalies_cases] == 1:
|
40 |
+
Fn += 1
|
41 |
+
else:
|
42 |
+
Tn += 1
|
43 |
+
else:
|
44 |
+
if cases_df.iloc[i][anomalies_cases] == 1:
|
45 |
+
Fn += 1
|
46 |
+
else:
|
47 |
+
Tn += 1
|
48 |
+
print(f"Tolerance-based evaluation for method {threshold}:")
|
49 |
+
print(f"True Positives: {Tp}, False Positives: {Fp}, False Negatives: {Fn}, True Negatives: {Tn}")
|
50 |
+
precision = Tp / (Tp + Fp)
|
51 |
+
recall = Tp / (Tp + Fn)
|
52 |
+
f1 = 2 * (precision * recall) / (precision + recall)
|
53 |
+
print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")
|
54 |
+
|
55 |
+
|
56 |
+
def prepare_time_series_dataframe(df):
|
57 |
+
"""Prepare dataframe for time series analysis by setting datetime index and renaming columns"""
|
58 |
+
df.set_index(df.columns[0], inplace=True)
|
59 |
+
|
60 |
+
try:
|
61 |
+
df.index = pd.to_datetime(df.index)
|
62 |
+
except ValueError:
|
63 |
+
raise ValueError("The first column of the CSV file must be a datetime column.")
|
64 |
+
|
65 |
+
df.rename(columns={df.columns[0]: "news"}, inplace=True)
|
66 |
+
return df
|
67 |
+
|
68 |
+
def update_controls(method):
|
69 |
+
"""
|
70 |
+
Updates the interactivity of control elements based on the selected method.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
method (str): The selected anomaly detection method
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
dict: Update configuration for Gradio components
|
77 |
+
"""
|
78 |
+
is_lstm = method == "LSTM"
|
79 |
+
return [
|
80 |
+
gr.update(interactive=is_lstm),
|
81 |
+
gr.update(interactive=is_lstm)
|
82 |
+
]
|
src/visualization.py
DELETED
@@ -1,20 +0,0 @@
|
|
1 |
-
import pandas as pd
|
2 |
-
import plotly.express as px
|
3 |
-
|
4 |
-
|
5 |
-
def plot_time_series(file):
|
6 |
-
"""
|
7 |
-
Plots a time series graph from a CSV file.
|
8 |
-
|
9 |
-
This function reads the CSV file and generates a line plot
|
10 |
-
showing the disease mentions over time.
|
11 |
-
|
12 |
-
"""
|
13 |
-
df = pd.read_csv(file.name)
|
14 |
-
fig = px.line(
|
15 |
-
df,
|
16 |
-
x=df.columns[0],
|
17 |
-
y=df.columns[1],
|
18 |
-
title='Disease Mentions Over Time'
|
19 |
-
)
|
20 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|