Aksel Joonas Reedi committed on
Commit
8db7b4c
·
1 Parent(s): 35377fa
app.py CHANGED
@@ -2,7 +2,6 @@ import altair as alt
2
  import pandas as pd
3
  import plotly.graph_objects as go
4
  import streamlit as st
5
-
6
  from src.helper_functions import custom_metric_box, pollution_box
7
  from src.predict import get_data_and_predictions, update_data_and_predictions
8
 
@@ -45,8 +44,7 @@ col1, col2 = st.columns((1, 3))
45
  # Create a 2-column layout (col1:col2 width ratio 1:3)
46
  with col1:
47
  st.subheader("Current Weather")
48
-
49
-
50
  custom_metric_box(
51
  label="🥵 Temperature",
52
  value=f"{round(today['mean_temp'] * 0.1)} °C",
@@ -76,38 +74,38 @@ with col1:
76
  with col2:
77
  st.subheader("Current Pollution Levels")
78
  sub1, sub2 = st.columns((1, 1))
79
-
80
  # Ozone (O₃) Pollution Box
81
  with sub1:
82
  pollution_box(
83
  label="O<sub>3</sub>",
84
  value=f"{round(today['O3'])} µg/m³",
85
  delta=f"{round(int(today['O3']) - int(previous_day['O3']))} µg/m³",
86
- threshold=120
87
  )
88
  with st.expander("Learn more about O3", expanded=False):
89
  st.markdown(
90
  """
91
  *Ozone (O<sub>3</sub>)*: A harmful gas at ground level that can irritate the respiratory system and aggravate asthma.<br>
92
  **Good/Bad**: "Good" means safe levels for most people, while "Bad" suggests harmful levels, especially for sensitive groups.
93
- """,
94
  unsafe_allow_html=True,
95
  )
96
-
97
  # Nitrogen Dioxide (NO₂) Pollution Box
98
  with sub2:
99
  pollution_box(
100
  label="NO<sub>2</sub>",
101
  value=f"{round(today['NO2'])} µg/m³",
102
  delta=f"{round(int(today['NO2']) - int(previous_day['NO2']))} µg/m³",
103
- threshold=40
104
  )
105
  with st.expander("Learn more about NO2", expanded=False):
106
  st.markdown(
107
  """
108
  *Nitrogen Dioxide (NO<sub>2</sub>)*: A toxic gas that contributes to lung irritation and worsens asthma and other respiratory issues.<br>
109
  **Good/Bad**: "Good" means safe air quality, while "Bad" indicates levels that could cause respiratory problems, especially for vulnerable individuals.
110
- """,
111
  unsafe_allow_html=True,
112
  )
113
 
@@ -118,9 +116,12 @@ with col2:
118
  def get_simple_color_scale(values, threshold):
119
  """Returns green for values below the threshold, orange for values between the threshold and 2x the threshold, and red for values above 2x the threshold."""
120
  return [
121
- "#77C124" if v < threshold else
122
- "#E68B0A" if v < 2 * threshold else
123
- "#E63946" for v in values
 
 
 
124
  ]
125
 
126
  # O3 Bar Plot (threshold: 40)
@@ -142,13 +143,17 @@ with col2:
142
  )
143
 
144
  # Add predicted values with reduced opacity
145
- predicted_o3_colors = get_simple_color_scale(o3_future_values, 40) # Color for future values
 
 
146
  fig_o3.add_trace(
147
  go.Bar(
148
  x=df["Date"][-3:], # Dates for predicted values
149
  y=o3_future_values,
150
  name="O3 Predicted",
151
- marker=dict(color=predicted_o3_colors, opacity=0.5), # Set opacity to 0.5 for predictions
 
 
152
  hovertemplate="%{x|%d-%b-%Y}<br>%{y} µg/m³<extra></extra>",
153
  )
154
  )
@@ -179,7 +184,7 @@ with col2:
179
  tickangle=-45,
180
  tickcolor="gray",
181
  ),
182
- showlegend=False # Disable legend
183
  )
184
 
185
  st.plotly_chart(fig_o3, key="fig_o3")
@@ -204,13 +209,17 @@ with col2:
204
  )
205
 
206
  # Add predicted values with reduced opacity
207
- predicted_no2_colors = get_simple_color_scale(no2_future_values, 120) # Color for future values
 
 
208
  fig_no2.add_trace(
209
  go.Bar(
210
  x=df["Date"][-3:], # Dates for predicted values
211
  y=no2_future_values,
212
  name="NO2 Predicted",
213
- marker=dict(color=predicted_no2_colors, opacity=0.5), # Set opacity to 0.5 for predictions
 
 
214
  hovertemplate="%{x|%d-%b-%Y}<br>%{y} µg/m³<extra></extra>",
215
  )
216
  )
@@ -241,7 +250,7 @@ with col2:
241
  tickangle=-45,
242
  tickcolor="gray",
243
  ),
244
- showlegend=False # Disable legend
245
  )
246
 
247
- st.plotly_chart(fig_no2, key="fig_no2")
 
2
  import pandas as pd
3
  import plotly.graph_objects as go
4
  import streamlit as st
 
5
  from src.helper_functions import custom_metric_box, pollution_box
6
  from src.predict import get_data_and_predictions, update_data_and_predictions
7
 
 
44
  # Create a 2-column layout (col1:col2 width ratio 1:3)
45
  with col1:
46
  st.subheader("Current Weather")
47
+
 
48
  custom_metric_box(
49
  label="🥵 Temperature",
50
  value=f"{round(today['mean_temp'] * 0.1)} °C",
 
74
  with col2:
75
  st.subheader("Current Pollution Levels")
76
  sub1, sub2 = st.columns((1, 1))
77
+
78
  # Ozone (O₃) Pollution Box
79
  with sub1:
80
  pollution_box(
81
  label="O<sub>3</sub>",
82
  value=f"{round(today['O3'])} µg/m³",
83
  delta=f"{round(int(today['O3']) - int(previous_day['O3']))} µg/m³",
84
+ threshold=120,
85
  )
86
  with st.expander("Learn more about O3", expanded=False):
87
  st.markdown(
88
  """
89
  *Ozone (O<sub>3</sub>)*: A harmful gas at ground level that can irritate the respiratory system and aggravate asthma.<br>
90
  **Good/Bad**: "Good" means safe levels for most people, while "Bad" suggests harmful levels, especially for sensitive groups.
91
+ """,
92
  unsafe_allow_html=True,
93
  )
94
+
95
  # Nitrogen Dioxide (NO₂) Pollution Box
96
  with sub2:
97
  pollution_box(
98
  label="NO<sub>2</sub>",
99
  value=f"{round(today['NO2'])} µg/m³",
100
  delta=f"{round(int(today['NO2']) - int(previous_day['NO2']))} µg/m³",
101
+ threshold=40,
102
  )
103
  with st.expander("Learn more about NO2", expanded=False):
104
  st.markdown(
105
  """
106
  *Nitrogen Dioxide (NO<sub>2</sub>)*: A toxic gas that contributes to lung irritation and worsens asthma and other respiratory issues.<br>
107
  **Good/Bad**: "Good" means safe air quality, while "Bad" indicates levels that could cause respiratory problems, especially for vulnerable individuals.
108
+ """,
109
  unsafe_allow_html=True,
110
  )
111
 
 
116
  def get_simple_color_scale(values, threshold):
117
  """Returns green for values below the threshold, orange for values between the threshold and 2x the threshold, and red for values above 2x the threshold."""
118
  return [
119
+ "#77C124"
120
+ if v < threshold
121
+ else "#E68B0A"
122
+ if v < 2 * threshold
123
+ else "#E63946"
124
+ for v in values
125
  ]
126
 
127
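  # O3 Bar Plot (threshold: 40) — NOTE(review): pollution_box uses threshold=120 for O3 and 40 for NO2, while the bar charts use 40 for O3 and 120 for NO2; the chart thresholds look swapped — confirm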
 
143
  )
144
 
145
  # Add predicted values with reduced opacity
146
+ predicted_o3_colors = get_simple_color_scale(
147
+ o3_future_values, 40
148
+ ) # Color for future values
149
  fig_o3.add_trace(
150
  go.Bar(
151
  x=df["Date"][-3:], # Dates for predicted values
152
  y=o3_future_values,
153
  name="O3 Predicted",
154
+ marker=dict(
155
+ color=predicted_o3_colors, opacity=0.5
156
+ ), # Set opacity to 0.5 for predictions
157
  hovertemplate="%{x|%d-%b-%Y}<br>%{y} µg/m³<extra></extra>",
158
  )
159
  )
 
184
  tickangle=-45,
185
  tickcolor="gray",
186
  ),
187
+ showlegend=False, # Disable legend
188
  )
189
 
190
  st.plotly_chart(fig_o3, key="fig_o3")
 
209
  )
210
 
211
  # Add predicted values with reduced opacity
212
+ predicted_no2_colors = get_simple_color_scale(
213
+ no2_future_values, 120
214
+ ) # Color for future values
215
  fig_no2.add_trace(
216
  go.Bar(
217
  x=df["Date"][-3:], # Dates for predicted values
218
  y=no2_future_values,
219
  name="NO2 Predicted",
220
+ marker=dict(
221
+ color=predicted_no2_colors, opacity=0.5
222
+ ), # Set opacity to 0.5 for predictions
223
  hovertemplate="%{x|%d-%b-%Y}<br>%{y} µg/m³<extra></extra>",
224
  )
225
  )
 
250
  tickangle=-45,
251
  tickcolor="gray",
252
  ),
253
+ showlegend=False, # Disable legend
254
  )
255
 
256
+ st.plotly_chart(fig_no2, key="fig_no2")
src/data_api_calls.py CHANGED
@@ -14,7 +14,11 @@ WEATHER_DATA_FILE = "weather_data.csv"
14
  POLLUTION_DATA_FILE = "pollution_data.csv"
15
 
16
 
17
- def update_weather_data():
 
 
 
 
18
  today = date.today().isoformat()
19
 
20
  if os.path.exists(WEATHER_DATA_FILE):
@@ -50,7 +54,11 @@ def update_weather_data():
50
  sys.exit()
51
 
52
 
53
- def update_pollution_data():
 
 
 
 
54
  O3 = []
55
  NO2 = []
56
  particles = ["NO2", "O3"]
@@ -113,14 +121,21 @@ def update_pollution_data():
113
  updated_data.to_csv(POLLUTION_DATA_FILE, index=False)
114
 
115
 
116
- def get_combined_data():
 
 
117
 
 
 
 
118
  weather_df = pd.read_csv(WEATHER_DATA_FILE)
119
-
120
  today = pd.Timestamp.now().normalize()
121
  seven_days_ago = today - pd.Timedelta(days=7)
122
  weather_df["date"] = pd.to_datetime(weather_df["date"])
123
- weather_df = weather_df[(weather_df["date"] >= seven_days_ago) & (weather_df["date"] <= today)]
 
 
124
 
125
  weather_df.insert(1, "NO2", None)
126
  weather_df.insert(2, "O3", None)
@@ -168,7 +183,9 @@ def get_combined_data():
168
  pollution_df = pd.read_csv(POLLUTION_DATA_FILE)
169
 
170
  pollution_df["date"] = pd.to_datetime(pollution_df["date"])
171
- pollution_df = pollution_df[(pollution_df["date"] >= seven_days_ago) & (pollution_df["date"] <= today)]
 
 
172
 
173
  combined_df["NO2"] = pollution_df["NO2"]
174
  combined_df["O3"] = pollution_df["O3"]
 
14
  POLLUTION_DATA_FILE = "pollution_data.csv"
15
 
16
 
17
+ def update_weather_data() -> None:
18
+ """
19
+ Updates weather data by fetching data.
20
+ If the data file exists, it appends new data. If not, it creates a new file.
21
+ """
22
  today = date.today().isoformat()
23
 
24
  if os.path.exists(WEATHER_DATA_FILE):
 
54
  sys.exit()
55
 
56
 
57
+ def update_pollution_data() -> None:
58
+ """
59
+ Updates pollution data for NO2 and O3.
60
+ The new data is appended to the existing pollution data file.
61
+ """
62
  O3 = []
63
  NO2 = []
64
  particles = ["NO2", "O3"]
 
121
  updated_data.to_csv(POLLUTION_DATA_FILE, index=False)
122
 
123
 
124
+ def get_combined_data() -> pd.DataFrame:
125
+ """
126
+ Combines weather and pollution data for the last 7 days.
127
 
128
+ Returns:
129
+ pd.DataFrame: A DataFrame containing the combined weather and pollution data.
130
+ """
131
  weather_df = pd.read_csv(WEATHER_DATA_FILE)
132
+
133
  today = pd.Timestamp.now().normalize()
134
  seven_days_ago = today - pd.Timedelta(days=7)
135
  weather_df["date"] = pd.to_datetime(weather_df["date"])
136
+ weather_df = weather_df[
137
+ (weather_df["date"] >= seven_days_ago) & (weather_df["date"] <= today)
138
+ ]
139
 
140
  weather_df.insert(1, "NO2", None)
141
  weather_df.insert(2, "O3", None)
 
183
  pollution_df = pd.read_csv(POLLUTION_DATA_FILE)
184
 
185
  pollution_df["date"] = pd.to_datetime(pollution_df["date"])
186
+ pollution_df = pollution_df[
187
+ (pollution_df["date"] >= seven_days_ago) & (pollution_df["date"] <= today)
188
+ ]
189
 
190
  combined_df["NO2"] = pollution_df["NO2"]
191
  combined_df["O3"] = pollution_df["O3"]
src/features_pipeline.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
6
  import pandas as pd
7
  from dotenv import load_dotenv
8
  from huggingface_hub import hf_hub_download, login
9
-
10
  from src.past_data_api_calls import get_past_combined_data
11
 
12
  warnings.filterwarnings("ignore")
@@ -16,11 +15,44 @@ login(token=os.getenv("HUGGINGFACE_DOWNLOAD_TOKEN"))
16
 
17
 
18
  def create_features(
19
- data,
20
- target_particle, # Added this parameter
21
- lag_days=7,
22
- sma_days=7,
23
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  lag_features = [
25
  "NO2",
26
  "O3",
 
6
  import pandas as pd
7
  from dotenv import load_dotenv
8
  from huggingface_hub import hf_hub_download, login
 
9
  from src.past_data_api_calls import get_past_combined_data
10
 
11
  warnings.filterwarnings("ignore")
 
15
 
16
 
17
  def create_features(
18
+ data: pd.DataFrame,
19
+ target_particle: str, # Added this parameter
20
+ lag_days: int = 7,
21
+ sma_days: int = 7,
22
+ ) -> pd.DataFrame:
23
+ """
24
+ Create features for predicting air quality particles (NO2 or O3) based on historical weather data.
25
+
26
+ This function performs several feature engineering tasks, including:
27
+ - Creating lagged features for specified pollutants.
28
+ - Calculating rolling mean (SMA) features.
29
+ - Adding sine and cosine transformations of the weekday and month.
30
+ - Incorporating historical data for the same date in the previous year.
31
+
32
+ Parameters:
33
+ ----------
34
+ data : pd.DataFrame
35
+ A DataFrame containing historical weather and air quality data with a 'date' column.
36
+
37
+ target_particle : str
38
+ The target particle for prediction, must be either 'O3' or 'NO2'.
39
+
40
+ lag_days : int, optional
41
+ The number of days for which lagged features will be created. Default is 7.
42
+
43
+ sma_days : int, optional
44
+ The window size for calculating the simple moving average (SMA). Default is 7.
45
+
46
+ Returns:
47
+ -------
48
+ pd.DataFrame
49
+ A DataFrame containing the transformed features, ready for modeling.
50
+
51
+ Raises:
52
+ ------
53
+ ValueError
54
+ If target_particle is not 'O3' or 'NO2'.
55
+ """
56
  lag_features = [
57
  "NO2",
58
  "O3",
src/helper_functions.py CHANGED
@@ -1,9 +1,26 @@
1
  import streamlit as st
2
 
3
 
4
- # Custom function to create styled metric boxes with compact layout
5
- def custom_metric_box(label, value):
6
- st.markdown(f"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  <div style="
8
  padding: 5px;
9
  margin-bottom: 5px;
@@ -19,17 +36,42 @@ def custom_metric_box(label, value):
19
  <p style="font-size: 18px; font-weight: bold; margin: 0;">{value}</p> <!-- Smaller metric -->
20
  </div>
21
  </div>
22
- """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # Custom function to create pollution metric boxes with side-by-side layout for label and value
25
- # Custom function to create pollution metric boxes with side-by-side layout and fixed width
26
- def pollution_box(label, value, delta, threshold):
 
27
  # Determine if the pollution level is "Good" or "Bad"
28
  status = "Good" if float(value.split()[0]) < threshold else "Bad"
29
  status_color = "#77C124" if status == "Good" else "#E68B0A"
30
 
31
  # Render the pollution box
32
- st.markdown(f"""
 
33
  <div style="
34
  background: rgba(255, 255, 255, 0.05);
35
  border-radius: 16px;
@@ -44,4 +86,6 @@ def pollution_box(label, value, delta, threshold):
44
  <p style="font-size: 36px; font-weight: bold; color: {status_color}; margin: 0;">{status}</p> <!-- Good/Bad with color -->
45
  <p style="font-size: 18px; margin: 0;">{value}</p> <!-- Smaller value where delta used to be -->
46
  </div>
47
- """, unsafe_allow_html=True)
 
 
 
1
  import streamlit as st
2
 
3
 
4
+ def custom_metric_box(label: str, value: str) -> None:
5
+ """
6
+ Create a styled metric box with a compact layout.
7
+
8
+ This function generates a styled markdown box displaying a label and its corresponding value.
9
+
10
+ Parameters:
11
+ ----------
12
+ label : str
13
+ The text label to display in the metric box.
14
+
15
+ value : str
16
+ The value to be displayed in the metric box, typically representing a metric.
17
+
18
+ Returns:
19
+ -------
20
+ None
21
+ """
22
+ st.markdown(
23
+ f"""
24
  <div style="
25
  padding: 5px;
26
  margin-bottom: 5px;
 
36
  <p style="font-size: 18px; font-weight: bold; margin: 0;">{value}</p> <!-- Smaller metric -->
37
  </div>
38
  </div>
39
+ """,
40
+ unsafe_allow_html=True,
41
+ )
42
+
43
+
44
+ def pollution_box(label: str, value: str, delta: str, threshold: float) -> None:
45
+ """
46
+ Create a pollution metric box with a side-by-side layout and fixed width.
47
+
48
+ This function generates a styled markdown box displaying pollution level status, value, and other related information.
49
+
50
+ Parameters:
51
+ ----------
52
+ label : str
53
+ The text label representing the type of pollution or metric.
54
+
55
+ value : str
56
+ The value of the pollution metric, typically a string that can be converted to a float.
57
+
58
+ delta : str
59
+ A string representing the change in pollution level, though not currently used in the rendering.
60
+
61
+ threshold : float
62
+ The threshold value to determine if the pollution level is "Good" or "Bad".
63
 
64
+ Returns:
65
+ -------
66
+ None
67
+ """
68
  # Determine if the pollution level is "Good" or "Bad"
69
  status = "Good" if float(value.split()[0]) < threshold else "Bad"
70
  status_color = "#77C124" if status == "Good" else "#E68B0A"
71
 
72
  # Render the pollution box
73
+ st.markdown(
74
+ f"""
75
  <div style="
76
  background: rgba(255, 255, 255, 0.05);
77
  border-radius: 16px;
 
86
  <p style="font-size: 36px; font-weight: bold; color: {status_color}; margin: 0;">{status}</p> <!-- Good/Bad with color -->
87
  <p style="font-size: 18px; margin: 0;">{value}</p> <!-- Smaller value where delta used to be -->
88
  </div>
89
+ """,
90
+ unsafe_allow_html=True,
91
+ )
src/past_data_api_calls.py CHANGED
@@ -14,7 +14,11 @@ PAST_WEATHER_DATA_FILE = "past_weather_data.csv"
14
  PAST_POLLUTION_DATA_FILE = "past_pollution_data.csv"
15
 
16
 
17
- def update_past_weather_data():
 
 
 
 
18
  last_year_date = date.today() - timedelta(days=365)
19
 
20
  if os.path.exists(PAST_WEATHER_DATA_FILE):
@@ -51,7 +55,13 @@ def update_past_weather_data():
51
  sys.exit()
52
 
53
 
54
- def update_past_pollution_data():
 
 
 
 
 
 
55
  O3 = []
56
  NO2 = []
57
  particles = ["NO2", "O3"]
@@ -65,7 +75,7 @@ def update_past_pollution_data():
65
  last_date = pd.to_datetime(existing_data["date"]).max()
66
  if last_date >= pd.to_datetime(last_year_date):
67
  print("Data is already up to date.")
68
- return
69
  else:
70
  start_date = last_date.date()
71
  end_date = last_year_date + timedelta(days=3)
@@ -129,7 +139,13 @@ def update_past_pollution_data():
129
  return NO2, O3
130
 
131
 
132
- def get_past_combined_data():
 
 
 
 
 
 
133
  update_past_weather_data()
134
  update_past_pollution_data()
135
 
 
14
  PAST_POLLUTION_DATA_FILE = "past_pollution_data.csv"
15
 
16
 
17
+ def update_past_weather_data() -> None:
18
+ """
19
+ Updates past weather data.
20
+ The data is saved to a CSV file. If the file already exists, new data is appended.
21
+ """
22
  last_year_date = date.today() - timedelta(days=365)
23
 
24
  if os.path.exists(PAST_WEATHER_DATA_FILE):
 
55
  sys.exit()
56
 
57
 
58
+ def update_past_pollution_data() -> tuple[list[float], list[float]]:
59
+ """
60
+ Updates past pollution data for NO2 and O3.
61
+
62
+ Returns:
63
+ tuple: A tuple containing two lists with NO2 and O3 average values.
64
+ """
65
  O3 = []
66
  NO2 = []
67
  particles = ["NO2", "O3"]
 
75
  last_date = pd.to_datetime(existing_data["date"]).max()
76
  if last_date >= pd.to_datetime(last_year_date):
77
  print("Data is already up to date.")
78
+ return [], []
79
  else:
80
  start_date = last_date.date()
81
  end_date = last_year_date + timedelta(days=3)
 
139
  return NO2, O3
140
 
141
 
142
+ def get_past_combined_data() -> pd.DataFrame:
143
+ """
144
+ Retrieves and combines past weather and pollution data.
145
+
146
+ Returns:
147
+ pd.DataFrame: A DataFrame containing the combined past weather and pollution data.
148
+ """
149
  update_past_weather_data()
150
  update_past_pollution_data()
151
 
src/predict.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
6
  import torch
7
  from dotenv import load_dotenv
8
  from huggingface_hub import hf_hub_download, login
9
-
10
  from src.data_api_calls import (
11
  get_combined_data,
12
  update_pollution_data,
@@ -18,12 +17,18 @@ load_dotenv()
18
  login(token=os.getenv("HUGGINGFACE_DOWNLOAD_TOKEN"))
19
 
20
 
21
- def load_nn():
 
 
 
 
 
 
22
  import torch.nn as nn
23
  from huggingface_hub import PyTorchModelHubMixin
24
 
25
  class AirPollutionNet(nn.Module, PyTorchModelHubMixin):
26
- def __init__(self, input_size, layers, dropout_rate):
27
  super(AirPollutionNet, self).__init__()
28
  self.layers_list = nn.ModuleList()
29
  in_features = input_size
@@ -36,7 +41,16 @@ def load_nn():
36
 
37
  self.output = nn.Linear(in_features, 3) # Output size is 3 for next 3 days
38
 
39
- def forward(self, x):
 
 
 
 
 
 
 
 
 
40
  for layer in self.layers_list:
41
  x = layer(x)
42
  x = self.output(x)
@@ -48,7 +62,16 @@ def load_nn():
48
  return model
49
 
50
 
51
- def load_model(particle):
 
 
 
 
 
 
 
 
 
52
  repo_id = f"elisaklunder/Utrecht-{particle}-Forecasting-Model"
53
  if particle == "O3":
54
  file_name = "O3_svr_model.pkl"
@@ -60,7 +83,17 @@ def load_model(particle):
60
  return model
61
 
62
 
63
- def run_model(particle, data):
 
 
 
 
 
 
 
 
 
 
64
  input_data = create_features(data=data, target_particle=particle)
65
  model = load_model(particle)
66
 
@@ -83,7 +116,11 @@ def run_model(particle, data):
83
  return prediction
84
 
85
 
86
- def update_data_and_predictions():
 
 
 
 
87
  update_weather_data()
88
  update_pollution_data()
89
 
@@ -129,7 +166,16 @@ def update_data_and_predictions():
129
  combined_data.to_csv(PREDICTIONS_FILE, index=False)
130
 
131
 
132
- def get_data_and_predictions():
 
 
 
 
 
 
 
 
 
133
  week_data = get_combined_data()
134
 
135
  PREDICTIONS_FILE = "predictions_history.csv"
@@ -148,5 +194,6 @@ def get_data_and_predictions():
148
 
149
  return week_data, [o3_predictions], [no2_predictions]
150
 
151
- if __name__=="__main__":
152
- update_data_and_predictions()
 
 
6
  import torch
7
  from dotenv import load_dotenv
8
  from huggingface_hub import hf_hub_download, login
 
9
  from src.data_api_calls import (
10
  get_combined_data,
11
  update_pollution_data,
 
17
  login(token=os.getenv("HUGGINGFACE_DOWNLOAD_TOKEN"))
18
 
19
 
20
+ def load_nn() -> torch.nn.Module:
21
+ """
22
+ Loads the neural network model for air pollution forecasting.
23
+
24
+ Returns:
25
+ torch.nn.Module: The loaded neural network model.
26
+ """
27
  import torch.nn as nn
28
  from huggingface_hub import PyTorchModelHubMixin
29
 
30
  class AirPollutionNet(nn.Module, PyTorchModelHubMixin):
31
+ def __init__(self, input_size: int, layers: list[int], dropout_rate: float):
32
  super(AirPollutionNet, self).__init__()
33
  self.layers_list = nn.ModuleList()
34
  in_features = input_size
 
41
 
42
  self.output = nn.Linear(in_features, 3) # Output size is 3 for next 3 days
43
 
44
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
45
+ """
46
+ Forward pass of the neural network.
47
+
48
+ Args:
49
+ x (torch.Tensor): Input tensor.
50
+
51
+ Returns:
52
+ torch.Tensor: Output tensor after passing through the network.
53
+ """
54
  for layer in self.layers_list:
55
  x = layer(x)
56
  x = self.output(x)
 
62
  return model
63
 
64
 
65
+ def load_model(particle: str) -> object:
66
+ """
67
+ Loads the forecasting model based on the specified particle.
68
+
69
+ Args:
70
+ particle (str): The type of particle ("O3" or "NO2").
71
+
72
+ Returns:
73
+ object: The loaded model (either a neural network or a support vector regression model).
74
+ """
75
  repo_id = f"elisaklunder/Utrecht-{particle}-Forecasting-Model"
76
  if particle == "O3":
77
  file_name = "O3_svr_model.pkl"
 
83
  return model
84
 
85
 
86
+ def run_model(particle: str, data: pd.DataFrame) -> list:
87
+ """
88
+ Runs the model for the specified particle and makes predictions based on the input data.
89
+
90
+ Args:
91
+ particle (str): The type of particle ("O3" or "NO2").
92
+ data (pd.DataFrame): The input data for making predictions.
93
+
94
+ Returns:
95
+ list: The predictions for the specified particle.
96
+ """
97
  input_data = create_features(data=data, target_particle=particle)
98
  model = load_model(particle)
99
 
 
116
  return prediction
117
 
118
 
119
+ def update_data_and_predictions() -> None:
120
+ """
121
+ Updates the weather and pollution data, makes predictions for O3 and NO2,
122
+ and stores them in a CSV file.
123
+ """
124
  update_weather_data()
125
  update_pollution_data()
126
 
 
166
  combined_data.to_csv(PREDICTIONS_FILE, index=False)
167
 
168
 
169
+ def get_data_and_predictions() -> tuple[pd.DataFrame, list, list]:
170
+ """
171
+ Retrieves combined data and today's predictions for O3 and NO2.
172
+
173
+ Returns:
174
+ tuple: A tuple containing:
175
+ - week_data (pd.DataFrame): The combined data for the week.
176
+ - list: Predictions for O3.
177
+ - list: Predictions for NO2.
178
+ """
179
  week_data = get_combined_data()
180
 
181
  PREDICTIONS_FILE = "predictions_history.csv"
 
194
 
195
  return week_data, [o3_predictions], [no2_predictions]
196
 
197
+
198
+ if __name__ == "__main__":
199
+ update_data_and_predictions()