Derek Thomas
committed on
Commit
·
97ab62b
0
Parent(s):
Duplicate from derek-thomas/probabilistic-forecast
Browse files- .gitattributes +34 -0
- .gitignore +2 -0
- AirPassengers.csv +1 -0
- README.md +14 -0
- app.py +74 -0
- make_plot.py +114 -0
- packages.txt +0 -0
- requirements.txt +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.idea
|
| 2 |
+
lightning_logs
|
AirPassengers.csv
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Month,#Passengers
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Probabilistic Forecasting
|
| 3 |
+
emoji: 🐨
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.27.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
duplicated_from: derek-thomas/probabilistic-forecast
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from gluonts.dataset.pandas import PandasDataset
|
| 4 |
+
from gluonts.dataset.split import split
|
| 5 |
+
from gluonts.torch.model.deepar import DeepAREstimator
|
| 6 |
+
|
| 7 |
+
from make_plot import plot_forecast, plot_train_test
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def offset_calculation(prediction_length, rolling_windows, length):
    """Compute the negative row offset separating training rows from test rows.

    Args:
        prediction_length: Number of time steps forecast in each window.
        rolling_windows: Number of forecast windows held out at the end.
        length: Total number of rows in the series.

    Returns:
        ``-(prediction_length * rolling_windows)``, a negative offset usable
        as ``df.iloc[:offset]`` (train) and ``df.iloc[offset:]`` (test).

    Raises:
        gr.Error: If either hyperparameter is missing or not positive, or if
            the held-out span exceeds 95% of ``length``.
    """
    if (
        prediction_length is None
        or rolling_windows is None
        or prediction_length <= 0
        or rolling_windows <= 0
    ):
        # A zero offset makes ``iloc[:0]`` an empty training set, and a
        # negative input flips the sign of the offset; reject both up front
        # instead of failing later with an obscure error.
        raise gr.Error("prediction_length and rolling_windows must be positive")

    row_offset = -1 * prediction_length * rolling_windows
    # Keep at least ~5% of the series for training.
    if abs(row_offset) > 0.95 * length:
        raise gr.Error("Reduce prediction_length * rolling_windows")
    return row_offset
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def preprocess(input_data, prediction_length, rolling_windows, progress=gr.Progress(track_tqdm=True)):
    """Load the uploaded CSV and plot how it is partitioned into train and test.

    Args:
        input_data: Uploaded file object; its ``name`` attribute is read as
            the CSV path.
        prediction_length: Number of time steps per forecast window.
        rolling_windows: Number of forecast windows held out.
        progress: Gradio progress tracker. Not referenced in the body;
            presumably present so Gradio can display progress (confirm).

    Returns:
        The figure produced by ``plot_train_test``.
    """
    # First CSV column becomes a parsed datetime index.
    series = pd.read_csv(input_data.name, index_col=0, parse_dates=True)
    cut = offset_calculation(prediction_length, rolling_windows, len(series))

    # ``cut`` is negative: everything before it is train, the tail is test.
    train_part = series.iloc[:cut]
    test_part = series.iloc[cut:]
    return plot_train_test(train_part, test_part)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def train_and_forecast(input_data, prediction_length, rolling_windows, epochs, progress=gr.Progress(track_tqdm=True)):
    """Train a DeepAR estimator on the uploaded series and plot its forecasts.

    Args:
        input_data: Uploaded file object; its ``name`` attribute is read as
            the CSV path.
        prediction_length: Number of time steps per forecast window.
        rolling_windows: Number of forecast windows generated from the
            held-out tail.
        epochs: Passed to the trainer as ``max_epochs``.
        progress: Gradio progress tracker. Not referenced in the body;
            presumably present so Gradio can display progress (confirm).

    Returns:
        The figure produced by ``plot_forecast``.

    Raises:
        gr.Error: If no file was uploaded, or if the held-out span is
            rejected by ``offset_calculation``.
    """
    missing_file_msg = "Upload a file with the Upload button"
    if not input_data:
        raise gr.Error(missing_file_msg)
    try:
        df = pd.read_csv(input_data.name, index_col=0, parse_dates=True)
    except AttributeError:
        # ``input_data`` did not expose ``.name``.
        raise gr.Error(missing_file_msg)

    row_offset = offset_calculation(prediction_length, rolling_windows, len(df))

    # The first data column is the forecasting target.
    dataset = PandasDataset(df, target=df.columns[0])
    training_data, test_template = split(dataset, offset=row_offset)

    estimator = DeepAREstimator(
        prediction_length=prediction_length,
        freq=dataset.freq,
        trainer_kwargs={"max_epochs": epochs},
    )
    predictor = estimator.train(training_data=training_data)

    test_instances = test_template.generate_instances(
        prediction_length=prediction_length,
        windows=rolling_windows,
    )
    forecasts = list(predictor.predict(test_instances.input))
    return plot_forecast(df, forecasts)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Gradio UI: hyperparameter inputs, upload/train buttons, and a shared plot.
with gr.Blocks() as demo:
    gr.Markdown("""
    # How to use
    Upload a univariate csv with the first column showing your dates and the second column having your data

    # How it works
    1. Click **Upload** to upload your data
    2. Click **Run**
      - This app will visualize your data and then train an estimator and show its predictions
    """)
    with gr.Accordion(label='Hyperparameters'):
        with gr.Row():
            prediction_length = gr.Number(value=12, label='Prediction Length', precision=0)
            windows = gr.Number(value=3, label='Number of Windows', precision=0)
            epochs = gr.Number(value=10, label='Number of Epochs', precision=0)
    with gr.Row():
        upload_btn = gr.UploadButton(label="Upload")
        train_btn = gr.Button(label="Train and Forecast")
    plot = gr.Plot()

    # Uploading a file immediately shows the train/test partition.
    upload_btn.upload(fn=preprocess, inputs=[upload_btn, prediction_length, windows], outputs=plot)
    # Inputs are bound positionally, so they must follow the signature
    # train_and_forecast(input_data, prediction_length, rolling_windows, epochs).
    # The previous order passed epochs as rolling_windows and vice versa.
    train_btn.click(fn=train_and_forecast, inputs=[upload_btn, prediction_length, windows, epochs], outputs=plot)

if __name__ == "__main__":
    demo.queue().launch()
|
make_plot.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Draw the training and test portions of a series as two line traces.

    Args:
        df1 (pd.DataFrame): Train dataset; its first column is plotted against its index.
        df2 (pd.DataFrame): Test dataset; its first column is plotted against its index.

    Returns:
        go.Figure: Figure containing both traces.
    """
    fig = go.Figure()

    # (frame, legend label, colour) for each trace, in drawing order.
    series_specs = (
        (df1, 'Training Data', 'steelblue'),
        (df2, 'Test Data', 'gold'),
    )
    for frame, label, colour in series_specs:
        fig.add_trace(go.Scatter(
            x=frame.index,
            y=frame.iloc[:, 0],
            mode='lines',
            name=label,
            line=dict(color=colour),
            marker=dict(color=colour)
        ))

    # Titles, legend and theme.
    fig.update_layout(
        title='Univariate Time Series',
        xaxis=dict(title='Date'),
        yaxis=dict(title='Value'),
        showlegend=True,
        template='plotly_white'
    )
    return fig
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
    """
    Plot the true values and forecasts using Plotly.

    Args:
        df (pd.DataFrame): DataFrame with the true values; the first column is
            plotted against the index.
        forecasts (List[pd.DataFrame]): Forecast objects, one per window. Each
            must expose ``.samples`` (iterable of sample paths) and ``.index``
            with a ``to_timestamp()`` method. NOTE(review): these look like
            GluonTS sample forecasts rather than DataFrames; confirm and
            correct the annotation.

    Returns:
        go.Figure: Plotly figure object.
    """

    # Create a Plotly figure
    fig = go.Figure()

    # Add the true values trace
    fig.add_trace(go.Scatter(
        x=pd.to_datetime(df.index),
        y=df.iloc[:, 0],
        mode='lines',
        name='True values',
        line=dict(color='black')
    ))

    # Add the forecast traces
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        # Cycle through the palette so more than len(colors) windows do not
        # raise IndexError.
        color = colors[i % len(colors)]
        # The x-axis is identical for every sample path; compute it once.
        timestamps = forecast.index.to_timestamp()
        for sample in forecast.samples:
            fig.add_trace(go.Scatter(
                x=timestamps,
                y=sample,
                mode='lines',
                opacity=0.15,  # Adjust opacity to control visibility of individual samples
                name=f'Forecast {i + 1}',
                showlegend=False,  # Hide the individual forecast series from the legend
                hoverinfo='none',  # Disable hover information for the forecast series
                line=dict(color=color)
            ))
        # Add the average
        mean_forecast = np.mean(forecast.samples, axis=0)
        fig.add_trace(go.Scatter(
            x=timestamps,
            y=mean_forecast,
            mode='lines',
            name='Mean Forecast',
            line=dict(color='red', dash='dash')
        ))

    # Customize the layout
    fig.update_layout(
        title='Passenger Forecast',
        xaxis=dict(title='Index'),
        yaxis=dict(title='Passenger Count'),
        showlegend=True,
        legend=dict(x=0, y=1, font=dict(size=16)),
        hovermode='x'  # Enable x-axis hover for better interactivity
    )

    # Return the figure
    return fig
|
packages.txt
ADDED
|
File without changes
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gluonts[torch,pro]
|
| 2 |
+
pandas
|
| 3 |
+
plotly
|