Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ import torch
|
|
6 |
from chronos import ChronosPipeline
|
7 |
|
8 |
class TimeSeriesForecaster:
|
9 |
-
def __init__(self, model_name="amazon/chronos-t5-
|
10 |
self.pipeline = ChronosPipeline.from_pretrained(
|
11 |
model_name,
|
12 |
device_map="cuda" if torch.cuda.is_available() else "cpu",
|
@@ -15,7 +15,7 @@ class TimeSeriesForecaster:
|
|
15 |
self.original_series = None
|
16 |
self.context = None
|
17 |
|
18 |
-
def preprocess_data(self, df, date_column, value_column, context_length=
|
19 |
"""
|
20 |
Prepare time series data from DataFrame
|
21 |
"""
|
@@ -36,151 +36,168 @@ class TimeSeriesForecaster:
|
|
36 |
|
37 |
return self.context, context_length
|
38 |
|
39 |
-
def forecast(self, context, prediction_length=
|
40 |
"""
|
41 |
Perform time series forecasting
|
42 |
"""
|
43 |
forecasts = self.pipeline.predict(context, prediction_length, num_samples=num_samples)
|
44 |
return forecasts
|
45 |
|
46 |
-
def visualize_forecast(self,
|
47 |
"""
|
48 |
-
Create visualization of predictions
|
49 |
"""
|
50 |
-
plt.figure(figsize=(
|
51 |
-
|
52 |
-
# Plot original series
|
53 |
-
plt.plot(range(len(self.original_series)), self.original_series, label='Historical Data', color='#1E88E5', linewidth=2)
|
54 |
|
55 |
# Calculate forecast statistics
|
56 |
forecast_np = forecasts[0].numpy()
|
57 |
low, median, high = np.quantile(forecast_np, [0.1, 0.5, 0.9], axis=0)
|
58 |
|
59 |
-
# Plot
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
|
65 |
-
plt.
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
return plt
|
71 |
|
72 |
def main():
|
73 |
-
#
|
74 |
st.set_page_config(
|
75 |
-
page_title="
|
76 |
-
page_icon="
|
77 |
-
layout="wide"
|
78 |
-
initial_sidebar_state="collapsed"
|
79 |
)
|
80 |
|
81 |
-
#
|
82 |
st.markdown("""
|
83 |
<style>
|
84 |
-
/* Modern, clean design */
|
85 |
.stApp {
|
86 |
-
background-color: #
|
87 |
-
font-family: 'Inter', sans-serif;
|
88 |
}
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
border
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
|
95 |
-
padding: 20px;
|
96 |
-
margin-bottom: 20px;
|
97 |
}
|
98 |
|
99 |
-
/*
|
100 |
.stFileUploader {
|
101 |
-
background-color: #
|
102 |
-
border: 2px
|
103 |
border-radius: 10px;
|
104 |
padding: 20px;
|
105 |
text-align: center;
|
106 |
-
}
|
107 |
-
|
108 |
-
/* Buttons */
|
109 |
-
.stButton>button {
|
110 |
-
background-color: #1E88E5;
|
111 |
-
color: white;
|
112 |
-
border-radius: 6px;
|
113 |
transition: all 0.3s ease;
|
114 |
}
|
115 |
-
.
|
116 |
-
|
117 |
-
|
118 |
}
|
119 |
</style>
|
120 |
""", unsafe_allow_html=True)
|
121 |
|
122 |
-
#
|
123 |
-
st.markdown(
|
124 |
-
|
|
|
|
|
125 |
|
126 |
-
# File upload
|
127 |
-
st.markdown("<div class='card'>", unsafe_allow_html=True)
|
128 |
uploaded_file = st.file_uploader(
|
129 |
-
"Upload
|
130 |
type=['csv'],
|
131 |
-
help="
|
132 |
-
label_visibility="collapsed"
|
133 |
)
|
134 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
135 |
|
136 |
-
# Data and forecast configuration
|
137 |
if uploaded_file is not None:
|
138 |
# Read CSV
|
139 |
df = pd.read_csv(uploaded_file)
|
140 |
|
141 |
-
#
|
142 |
-
st.markdown("<div class='card'>", unsafe_allow_html=True)
|
143 |
col1, col2 = st.columns(2)
|
144 |
|
145 |
with col1:
|
146 |
-
date_column = st.selectbox(
|
147 |
-
'Select Date Column',
|
148 |
-
options=df.columns,
|
149 |
-
help="Choose the column representing timestamps"
|
150 |
-
)
|
151 |
|
152 |
with col2:
|
153 |
value_column = st.selectbox(
|
154 |
-
'
|
155 |
-
options=[col for col in df.columns if col != date_column]
|
156 |
-
help="Choose the numeric column to forecast"
|
157 |
)
|
158 |
|
159 |
-
#
|
160 |
col3, col4 = st.columns(2)
|
161 |
|
162 |
with col3:
|
163 |
context_length = st.slider(
|
164 |
-
'Context
|
165 |
-
min_value=
|
166 |
-
max_value=
|
167 |
-
value=
|
168 |
-
help="Number of
|
169 |
)
|
170 |
|
171 |
with col4:
|
172 |
prediction_length = st.slider(
|
173 |
-
'
|
174 |
-
min_value=
|
175 |
-
max_value=
|
176 |
-
value=
|
177 |
help="Number of future time steps to predict"
|
178 |
)
|
179 |
|
180 |
-
# Forecast
|
181 |
if st.button('Generate Forecast'):
|
182 |
try:
|
183 |
-
# Initialize forecaster
|
184 |
forecaster = TimeSeriesForecaster()
|
185 |
|
186 |
# Preprocess data
|
@@ -195,22 +212,17 @@ def main():
|
|
195 |
# Perform forecasting
|
196 |
forecasts = forecaster.forecast(context, prediction_length)
|
197 |
|
198 |
-
# Visualization
|
199 |
-
|
200 |
-
st.subheader('Forecast Visualization')
|
201 |
-
plt = forecaster.visualize_forecast(context, forecasts)
|
202 |
st.pyplot(plt)
|
203 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
204 |
-
|
205 |
-
# Forecast details card
|
206 |
-
st.markdown("<div class='card'>", unsafe_allow_html=True)
|
207 |
-
st.subheader('Forecast Details')
|
208 |
|
|
|
209 |
forecast_np = forecasts[0].numpy()
|
210 |
forecast_mean = forecast_np.mean(axis=0)
|
211 |
forecast_lower = np.percentile(forecast_np, 10, axis=0)
|
212 |
forecast_upper = np.percentile(forecast_np, 90, axis=0)
|
213 |
|
|
|
214 |
prediction_df = pd.DataFrame({
|
215 |
'Mean Forecast': forecast_mean,
|
216 |
'Lower Bound (10%)': forecast_lower,
|
@@ -218,12 +230,9 @@ def main():
|
|
218 |
})
|
219 |
|
220 |
st.dataframe(prediction_df)
|
221 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
222 |
|
223 |
except Exception as e:
|
224 |
-
st.error(f"
|
225 |
-
|
226 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
227 |
|
228 |
if __name__ == '__main__':
|
229 |
main()
|
|
|
6 |
from chronos import ChronosPipeline
|
7 |
|
8 |
class TimeSeriesForecaster:
|
9 |
+
def __init__(self, model_name="amazon/chronos-t5-large"):
|
10 |
self.pipeline = ChronosPipeline.from_pretrained(
|
11 |
model_name,
|
12 |
device_map="cuda" if torch.cuda.is_available() else "cpu",
|
|
|
15 |
self.original_series = None
|
16 |
self.context = None
|
17 |
|
18 |
+
def preprocess_data(self, df, date_column, value_column, context_length=100, prediction_length=365):
|
19 |
"""
|
20 |
Prepare time series data from DataFrame
|
21 |
"""
|
|
|
36 |
|
37 |
return self.context, context_length
|
38 |
|
39 |
+
def forecast(self, context, prediction_length=365, num_samples=100):
|
40 |
"""
|
41 |
Perform time series forecasting
|
42 |
"""
|
43 |
forecasts = self.pipeline.predict(context, prediction_length, num_samples=num_samples)
|
44 |
return forecasts
|
45 |
|
46 |
+
def visualize_forecast(self, forecasts, original_series):
|
47 |
"""
|
48 |
+
Create comprehensive visualization of predictions
|
49 |
"""
|
50 |
+
plt.figure(figsize=(16, 8), dpi=100, facecolor='white')
|
|
|
|
|
|
|
51 |
|
52 |
# Calculate forecast statistics
|
53 |
forecast_np = forecasts[0].numpy()
|
54 |
low, median, high = np.quantile(forecast_np, [0.1, 0.5, 0.9], axis=0)
|
55 |
|
56 |
+
# Plot original series
|
57 |
+
plt.plot(
|
58 |
+
range(len(original_series)),
|
59 |
+
original_series,
|
60 |
+
label='Historical Data',
|
61 |
+
color='#2C3E50',
|
62 |
+
linewidth=2,
|
63 |
+
alpha=0.7
|
64 |
+
)
|
65 |
+
|
66 |
+
# Forecast index
|
67 |
+
forecast_index = range(len(original_series), len(original_series) + len(median))
|
68 |
+
|
69 |
+
# Plot median forecast
|
70 |
+
plt.plot(
|
71 |
+
forecast_index,
|
72 |
+
median,
|
73 |
+
color='#3498DB',
|
74 |
+
linewidth=3,
|
75 |
+
label='Median Forecast'
|
76 |
+
)
|
77 |
|
78 |
+
# Plot prediction interval
|
79 |
+
plt.fill_between(
|
80 |
+
forecast_index,
|
81 |
+
low,
|
82 |
+
high,
|
83 |
+
color='#3498DB',
|
84 |
+
alpha=0.2,
|
85 |
+
label='90% Prediction Interval'
|
86 |
+
)
|
87 |
+
|
88 |
+
plt.title('Advanced Time Series Forecast', fontsize=18, fontweight='bold', color='#2C3E50')
|
89 |
+
plt.xlabel('Time Steps', fontsize=12, color='#34495E')
|
90 |
+
plt.ylabel('Value', fontsize=12, color='#34495E')
|
91 |
+
plt.legend(frameon=False)
|
92 |
+
plt.grid(True, linestyle='--', color='#BDC3C7', alpha=0.5)
|
93 |
+
|
94 |
+
# Sophisticated styling
|
95 |
+
plt.tight_layout()
|
96 |
|
97 |
return plt
|
98 |
|
99 |
def main():
|
100 |
+
# Page configuration
|
101 |
st.set_page_config(
|
102 |
+
page_title="Forecast Maestro",
|
103 |
+
page_icon="📊",
|
104 |
+
layout="wide"
|
|
|
105 |
)
|
106 |
|
107 |
+
# Modern, minimalist styling
|
108 |
st.markdown("""
|
109 |
<style>
|
|
|
110 |
.stApp {
|
111 |
+
background-color: #FFFFFF;
|
112 |
+
font-family: 'Inter', 'Roboto', sans-serif;
|
113 |
}
|
114 |
|
115 |
+
.stButton>button {
|
116 |
+
background-color: #3498DB;
|
117 |
+
color: white;
|
118 |
+
border: none;
|
119 |
+
border-radius: 8px;
|
120 |
+
padding: 10px 20px;
|
121 |
+
transition: all 0.3s ease;
|
122 |
+
font-weight: 600;
|
123 |
+
}
|
124 |
+
|
125 |
+
.stButton>button:hover {
|
126 |
+
background-color: #2980B9;
|
127 |
+
transform: scale(1.05);
|
128 |
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
|
|
|
|
|
129 |
}
|
130 |
|
131 |
+
/* Sleek file uploader */
|
132 |
.stFileUploader {
|
133 |
+
background-color: #F8F9FA;
|
134 |
+
border: 2px solid #3498DB;
|
135 |
border-radius: 10px;
|
136 |
padding: 20px;
|
137 |
text-align: center;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
transition: all 0.3s ease;
|
139 |
}
|
140 |
+
.stFileUploader:hover {
|
141 |
+
border-color: #2980B9;
|
142 |
+
background-color: #EAF2F8;
|
143 |
}
|
144 |
</style>
|
145 |
""", unsafe_allow_html=True)
|
146 |
|
147 |
+
# Elegant title
|
148 |
+
st.markdown(
|
149 |
+
"<h1 style='text-align: center; color: #3498DB; margin-bottom: 30px;'>🔮 Forecast Maestro</h1>",
|
150 |
+
unsafe_allow_html=True
|
151 |
+
)
|
152 |
|
153 |
+
# File upload
|
|
|
154 |
uploaded_file = st.file_uploader(
|
155 |
+
"Upload Time Series Data",
|
156 |
type=['csv'],
|
157 |
+
help="CSV with timestamp and numeric columns"
|
|
|
158 |
)
|
|
|
159 |
|
|
|
160 |
if uploaded_file is not None:
|
161 |
# Read CSV
|
162 |
df = pd.read_csv(uploaded_file)
|
163 |
|
164 |
+
# Column selection
|
|
|
165 |
col1, col2 = st.columns(2)
|
166 |
|
167 |
with col1:
|
168 |
+
date_column = st.selectbox('Date Column', options=df.columns)
|
|
|
|
|
|
|
|
|
169 |
|
170 |
with col2:
|
171 |
value_column = st.selectbox(
|
172 |
+
'Value Column',
|
173 |
+
options=[col for col in df.columns if col != date_column]
|
|
|
174 |
)
|
175 |
|
176 |
+
# Advanced prediction settings
|
177 |
col3, col4 = st.columns(2)
|
178 |
|
179 |
with col3:
|
180 |
context_length = st.slider(
|
181 |
+
'Historical Context',
|
182 |
+
min_value=30,
|
183 |
+
max_value=500,
|
184 |
+
value=100,
|
185 |
+
help="Number of past data points to analyze"
|
186 |
)
|
187 |
|
188 |
with col4:
|
189 |
prediction_length = st.slider(
|
190 |
+
'Forecast Horizon',
|
191 |
+
min_value=30,
|
192 |
+
max_value=1000,
|
193 |
+
value=365,
|
194 |
help="Number of future time steps to predict"
|
195 |
)
|
196 |
|
197 |
+
# Forecast generation
|
198 |
if st.button('Generate Forecast'):
|
199 |
try:
|
200 |
+
# Initialize and run forecaster
|
201 |
forecaster = TimeSeriesForecaster()
|
202 |
|
203 |
# Preprocess data
|
|
|
212 |
# Perform forecasting
|
213 |
forecasts = forecaster.forecast(context, prediction_length)
|
214 |
|
215 |
+
# Visualization
|
216 |
+
plt = forecaster.visualize_forecast(forecasts, forecaster.original_series)
|
|
|
|
|
217 |
st.pyplot(plt)
|
|
|
|
|
|
|
|
|
|
|
218 |
|
219 |
+
# Forecast details
|
220 |
forecast_np = forecasts[0].numpy()
|
221 |
forecast_mean = forecast_np.mean(axis=0)
|
222 |
forecast_lower = np.percentile(forecast_np, 10, axis=0)
|
223 |
forecast_upper = np.percentile(forecast_np, 90, axis=0)
|
224 |
|
225 |
+
st.subheader('Forecast Insights')
|
226 |
prediction_df = pd.DataFrame({
|
227 |
'Mean Forecast': forecast_mean,
|
228 |
'Lower Bound (10%)': forecast_lower,
|
|
|
230 |
})
|
231 |
|
232 |
st.dataframe(prediction_df)
|
|
|
233 |
|
234 |
except Exception as e:
|
235 |
+
st.error(f"Forecast generation error: {str(e)}")
|
|
|
|
|
236 |
|
237 |
if __name__ == '__main__':
|
238 |
main()
|