vincentiusyoshuac commited on
Commit
58edeb6
·
verified ·
1 Parent(s): f08d236

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -91
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from chronos import ChronosPipeline
7
 
8
  class TimeSeriesForecaster:
9
- def __init__(self, model_name="amazon/chronos-t5-small"):
10
  self.pipeline = ChronosPipeline.from_pretrained(
11
  model_name,
12
  device_map="cuda" if torch.cuda.is_available() else "cpu",
@@ -15,7 +15,7 @@ class TimeSeriesForecaster:
15
  self.original_series = None
16
  self.context = None
17
 
18
- def preprocess_data(self, df, date_column, value_column, context_length=30, prediction_length=7):
19
  """
20
  Prepare time series data from DataFrame
21
  """
@@ -36,151 +36,168 @@ class TimeSeriesForecaster:
36
 
37
  return self.context, context_length
38
 
39
- def forecast(self, context, prediction_length=7, num_samples=100):
40
  """
41
  Perform time series forecasting
42
  """
43
  forecasts = self.pipeline.predict(context, prediction_length, num_samples=num_samples)
44
  return forecasts
45
 
46
- def visualize_forecast(self, context, forecasts):
47
  """
48
- Create visualization of predictions
49
  """
50
- plt.figure(figsize=(12, 6), facecolor='#f0f2f6')
51
-
52
- # Plot original series
53
- plt.plot(range(len(self.original_series)), self.original_series, label='Historical Data', color='#1E88E5', linewidth=2)
54
 
55
  # Calculate forecast statistics
56
  forecast_np = forecasts[0].numpy()
57
  low, median, high = np.quantile(forecast_np, [0.1, 0.5, 0.9], axis=0)
58
 
59
- # Plot forecast
60
- forecast_index = range(len(self.original_series), len(self.original_series) + len(median))
61
- plt.plot(forecast_index, median, color='#D81B60', linewidth=2, label='Median Forecast')
62
- plt.fill_between(forecast_index, low, high, color='#D81B60', alpha=0.3, label='80% Prediction Interval')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- plt.title('Time Series Forecasting', fontsize=16, fontweight='bold')
65
- plt.xlabel('Time Index', fontsize=12)
66
- plt.ylabel('Value', fontsize=12)
67
- plt.legend(frameon=True)
68
- plt.grid(True, linestyle='--', alpha=0.7)
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  return plt
71
 
72
  def main():
73
- # Set page configuration
74
  st.set_page_config(
75
- page_title="Time Series Forecaster",
76
- page_icon="📈",
77
- layout="wide",
78
- initial_sidebar_state="collapsed"
79
  )
80
 
81
- # Custom CSS for modern look
82
  st.markdown("""
83
  <style>
84
- /* Modern, clean design */
85
  .stApp {
86
- background-color: #f0f2f6;
87
- font-family: 'Inter', sans-serif;
88
  }
89
 
90
- /* Card-like containers */
91
- .card {
92
- background-color: white;
93
- border-radius: 10px;
 
 
 
 
 
 
 
 
 
94
  box-shadow: 0 4px 6px rgba(0,0,0,0.1);
95
- padding: 20px;
96
- margin-bottom: 20px;
97
  }
98
 
99
- /* Stylish file uploader */
100
  .stFileUploader {
101
- background-color: #f8f9fa;
102
- border: 2px dashed #6c757d;
103
  border-radius: 10px;
104
  padding: 20px;
105
  text-align: center;
106
- }
107
-
108
- /* Buttons */
109
- .stButton>button {
110
- background-color: #1E88E5;
111
- color: white;
112
- border-radius: 6px;
113
  transition: all 0.3s ease;
114
  }
115
- .stButton>button:hover {
116
- background-color: #1565c0;
117
- transform: scale(1.05);
118
  }
119
  </style>
120
  """, unsafe_allow_html=True)
121
 
122
- # Title and description
123
- st.markdown("<h1 style='text-align: center; color: #1E88E5;'>🕰️ Time Series Forecaster</h1>", unsafe_allow_html=True)
124
- st.markdown("<p style='text-align: center; color: #6c757d;'>Predict future trends with advanced machine learning</p>", unsafe_allow_html=True)
 
 
125
 
126
- # File upload section
127
- st.markdown("<div class='card'>", unsafe_allow_html=True)
128
  uploaded_file = st.file_uploader(
129
- "Upload CSV File",
130
  type=['csv'],
131
- help="Upload a CSV file with time series data",
132
- label_visibility="collapsed"
133
  )
134
- st.markdown("</div>", unsafe_allow_html=True)
135
 
136
- # Data and forecast configuration
137
  if uploaded_file is not None:
138
  # Read CSV
139
  df = pd.read_csv(uploaded_file)
140
 
141
- # Configuration card
142
- st.markdown("<div class='card'>", unsafe_allow_html=True)
143
  col1, col2 = st.columns(2)
144
 
145
  with col1:
146
- date_column = st.selectbox(
147
- 'Select Date Column',
148
- options=df.columns,
149
- help="Choose the column representing timestamps"
150
- )
151
 
152
  with col2:
153
  value_column = st.selectbox(
154
- 'Select Value Column',
155
- options=[col for col in df.columns if col != date_column],
156
- help="Choose the numeric column to forecast"
157
  )
158
 
159
- # Prediction parameters
160
  col3, col4 = st.columns(2)
161
 
162
  with col3:
163
  context_length = st.slider(
164
- 'Context Length',
165
- min_value=10,
166
- max_value=100,
167
- value=30,
168
- help="Number of historical data points to use for prediction"
169
  )
170
 
171
  with col4:
172
  prediction_length = st.slider(
173
- 'Prediction Length',
174
- min_value=1,
175
- max_value=30,
176
- value=7,
177
  help="Number of future time steps to predict"
178
  )
179
 
180
- # Forecast button
181
  if st.button('Generate Forecast'):
182
  try:
183
- # Initialize forecaster
184
  forecaster = TimeSeriesForecaster()
185
 
186
  # Preprocess data
@@ -195,22 +212,17 @@ def main():
195
  # Perform forecasting
196
  forecasts = forecaster.forecast(context, prediction_length)
197
 
198
- # Visualization card
199
- st.markdown("<div class='card'>", unsafe_allow_html=True)
200
- st.subheader('Forecast Visualization')
201
- plt = forecaster.visualize_forecast(context, forecasts)
202
  st.pyplot(plt)
203
- st.markdown("</div>", unsafe_allow_html=True)
204
-
205
- # Forecast details card
206
- st.markdown("<div class='card'>", unsafe_allow_html=True)
207
- st.subheader('Forecast Details')
208
 
 
209
  forecast_np = forecasts[0].numpy()
210
  forecast_mean = forecast_np.mean(axis=0)
211
  forecast_lower = np.percentile(forecast_np, 10, axis=0)
212
  forecast_upper = np.percentile(forecast_np, 90, axis=0)
213
 
 
214
  prediction_df = pd.DataFrame({
215
  'Mean Forecast': forecast_mean,
216
  'Lower Bound (10%)': forecast_lower,
@@ -218,12 +230,9 @@ def main():
218
  })
219
 
220
  st.dataframe(prediction_df)
221
- st.markdown("</div>", unsafe_allow_html=True)
222
 
223
  except Exception as e:
224
- st.error(f"An error occurred: {str(e)}")
225
-
226
- st.markdown("</div>", unsafe_allow_html=True)
227
 
228
  if __name__ == '__main__':
229
  main()
 
6
  from chronos import ChronosPipeline
7
 
8
  class TimeSeriesForecaster:
9
+ def __init__(self, model_name="amazon/chronos-t5-large"):
10
  self.pipeline = ChronosPipeline.from_pretrained(
11
  model_name,
12
  device_map="cuda" if torch.cuda.is_available() else "cpu",
 
15
  self.original_series = None
16
  self.context = None
17
 
18
+ def preprocess_data(self, df, date_column, value_column, context_length=100, prediction_length=365):
19
  """
20
  Prepare time series data from DataFrame
21
  """
 
36
 
37
  return self.context, context_length
38
 
39
+ def forecast(self, context, prediction_length=365, num_samples=100):
40
  """
41
  Perform time series forecasting
42
  """
43
  forecasts = self.pipeline.predict(context, prediction_length, num_samples=num_samples)
44
  return forecasts
45
 
46
+ def visualize_forecast(self, forecasts, original_series):
47
  """
48
+ Create comprehensive visualization of predictions
49
  """
50
+ plt.figure(figsize=(16, 8), dpi=100, facecolor='white')
 
 
 
51
 
52
  # Calculate forecast statistics
53
  forecast_np = forecasts[0].numpy()
54
  low, median, high = np.quantile(forecast_np, [0.1, 0.5, 0.9], axis=0)
55
 
56
+ # Plot original series
57
+ plt.plot(
58
+ range(len(original_series)),
59
+ original_series,
60
+ label='Historical Data',
61
+ color='#2C3E50',
62
+ linewidth=2,
63
+ alpha=0.7
64
+ )
65
+
66
+ # Forecast index
67
+ forecast_index = range(len(original_series), len(original_series) + len(median))
68
+
69
+ # Plot median forecast
70
+ plt.plot(
71
+ forecast_index,
72
+ median,
73
+ color='#3498DB',
74
+ linewidth=3,
75
+ label='Median Forecast'
76
+ )
77
 
78
+ # Plot prediction interval
79
+ plt.fill_between(
80
+ forecast_index,
81
+ low,
82
+ high,
83
+ color='#3498DB',
84
+ alpha=0.2,
85
+ label='90% Prediction Interval'
86
+ )
87
+
88
+ plt.title('Advanced Time Series Forecast', fontsize=18, fontweight='bold', color='#2C3E50')
89
+ plt.xlabel('Time Steps', fontsize=12, color='#34495E')
90
+ plt.ylabel('Value', fontsize=12, color='#34495E')
91
+ plt.legend(frameon=False)
92
+ plt.grid(True, linestyle='--', color='#BDC3C7', alpha=0.5)
93
+
94
+ # Sophisticated styling
95
+ plt.tight_layout()
96
 
97
  return plt
98
 
99
  def main():
100
+ # Page configuration
101
  st.set_page_config(
102
+ page_title="Forecast Maestro",
103
+ page_icon="📊",
104
+ layout="wide"
 
105
  )
106
 
107
+ # Modern, minimalist styling
108
  st.markdown("""
109
  <style>
 
110
  .stApp {
111
+ background-color: #FFFFFF;
112
+ font-family: 'Inter', 'Roboto', sans-serif;
113
  }
114
 
115
+ .stButton>button {
116
+ background-color: #3498DB;
117
+ color: white;
118
+ border: none;
119
+ border-radius: 8px;
120
+ padding: 10px 20px;
121
+ transition: all 0.3s ease;
122
+ font-weight: 600;
123
+ }
124
+
125
+ .stButton>button:hover {
126
+ background-color: #2980B9;
127
+ transform: scale(1.05);
128
  box-shadow: 0 4px 6px rgba(0,0,0,0.1);
 
 
129
  }
130
 
131
+ /* Sleek file uploader */
132
  .stFileUploader {
133
+ background-color: #F8F9FA;
134
+ border: 2px solid #3498DB;
135
  border-radius: 10px;
136
  padding: 20px;
137
  text-align: center;
 
 
 
 
 
 
 
138
  transition: all 0.3s ease;
139
  }
140
+ .stFileUploader:hover {
141
+ border-color: #2980B9;
142
+ background-color: #EAF2F8;
143
  }
144
  </style>
145
  """, unsafe_allow_html=True)
146
 
147
+ # Elegant title
148
+ st.markdown(
149
+ "<h1 style='text-align: center; color: #3498DB; margin-bottom: 30px;'>🔮 Forecast Maestro</h1>",
150
+ unsafe_allow_html=True
151
+ )
152
 
153
+ # File upload
 
154
  uploaded_file = st.file_uploader(
155
+ "Upload Time Series Data",
156
  type=['csv'],
157
+ help="CSV with timestamp and numeric columns"
 
158
  )
 
159
 
 
160
  if uploaded_file is not None:
161
  # Read CSV
162
  df = pd.read_csv(uploaded_file)
163
 
164
+ # Column selection
 
165
  col1, col2 = st.columns(2)
166
 
167
  with col1:
168
+ date_column = st.selectbox('Date Column', options=df.columns)
 
 
 
 
169
 
170
  with col2:
171
  value_column = st.selectbox(
172
+ 'Value Column',
173
+ options=[col for col in df.columns if col != date_column]
 
174
  )
175
 
176
+ # Advanced prediction settings
177
  col3, col4 = st.columns(2)
178
 
179
  with col3:
180
  context_length = st.slider(
181
+ 'Historical Context',
182
+ min_value=30,
183
+ max_value=500,
184
+ value=100,
185
+ help="Number of past data points to analyze"
186
  )
187
 
188
  with col4:
189
  prediction_length = st.slider(
190
+ 'Forecast Horizon',
191
+ min_value=30,
192
+ max_value=1000,
193
+ value=365,
194
  help="Number of future time steps to predict"
195
  )
196
 
197
+ # Forecast generation
198
  if st.button('Generate Forecast'):
199
  try:
200
+ # Initialize and run forecaster
201
  forecaster = TimeSeriesForecaster()
202
 
203
  # Preprocess data
 
212
  # Perform forecasting
213
  forecasts = forecaster.forecast(context, prediction_length)
214
 
215
+ # Visualization
216
+ plt = forecaster.visualize_forecast(forecasts, forecaster.original_series)
 
 
217
  st.pyplot(plt)
 
 
 
 
 
218
 
219
+ # Forecast details
220
  forecast_np = forecasts[0].numpy()
221
  forecast_mean = forecast_np.mean(axis=0)
222
  forecast_lower = np.percentile(forecast_np, 10, axis=0)
223
  forecast_upper = np.percentile(forecast_np, 90, axis=0)
224
 
225
+ st.subheader('Forecast Insights')
226
  prediction_df = pd.DataFrame({
227
  'Mean Forecast': forecast_mean,
228
  'Lower Bound (10%)': forecast_lower,
 
230
  })
231
 
232
  st.dataframe(prediction_df)
 
233
 
234
  except Exception as e:
235
+ st.error(f"Forecast generation error: {str(e)}")
 
 
236
 
237
  if __name__ == '__main__':
238
  main()