Spaces:
Running
on
T4
Running
on
T4
preset forecasting data used only when default forecast length is set, tested, updated README.md
Browse files
README.md
CHANGED
@@ -7,4 +7,63 @@ sdk: docker
|
|
7 |
app_port: 7860
|
8 |
---
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
app_port: 7860
|
8 |
---
|
9 |
|
10 |
+
# TiRex – Zero‑Shot Time Series Forecasting App
|
11 |
+
|
12 |
+
A Gradio‑based interactive web app to perform zero‑shot time series forecasting using the TiRex model. Upload your own CSV/XLSX/Parquet files or choose from built‑in presets, filter series by name, and visualize quantile forecasts over your chosen horizon.
|
13 |
+
|
14 |
+
---
|
15 |
+
|
16 |
+
## 🔍 Features
|
17 |
+
|
18 |
+
- **Zero‑Shot Forecasting**: Powered by the [`NX-AI/TiRex`](https://huggingface.co/NX-AI/TiRex) model.
|
19 |
+
- **Custom Data Upload**: Accepts CSV, XLSX, and Parquet.
|
20 |
+
- **Preset Datasets**: Includes `loop.csv`, `air_passangers.csv`, and `ett2.csv` for quick demos.
|
21 |
+
- **Interactive Filtering**: Search, check/uncheck, and plot only the series you care about.
|
22 |
+
- **Quantile Forecasts**: Displays historical data, median forecast line, and 10–90% quantile shading.
|
23 |
+
- **Configurable Horizon**: Slider to set forecast length (1–512 steps).
|
24 |
+
- **Automatic Defaults**: Detects best forecast‐length defaults for presets.
|
25 |
+
|
26 |
+
---
|
27 |
+
|
28 |
+
## 📊 Data Format
|
29 |
+
|
30 |
+
### With Named Series
|
31 |
+
```csv
|
32 |
+
AAPL,120.5,121.0,119.8,122.1,123.5,...
|
33 |
+
AMZN,3300.0,3310.5,3295.2,3305.8,3315.1,...
|
34 |
+
GOOGL,2800.1,2795.3,2810.7,2805.2,2820.4,...
|
35 |
+
```
|
36 |
+
|
37 |
+
### Without Named Series
|
38 |
+
```csv
|
39 |
+
120.5,121.0,119.8,122.1,123.5,...
|
40 |
+
3300.0,3310.5,3295.2,3305.8,3315.1,...
|
41 |
+
2800.1,2795.3,2810.7,2805.2,2820.4,...
|
42 |
+
```
|
43 |
+
|
44 |
+
### Key Rules:
|
45 |
+
- **One row per time series**
|
46 |
+
- **Consistent naming**: Either all rows have names (first column) or none do
|
47 |
+
- **Numeric data**: All values after the optional name column must be numeric
|
48 |
+
- **Minimum length**: Time series must have at least `forecast_length + 10` data points
|
49 |
+
- **Maximum constraints**: Up to 30 time series and 2048 time steps per series
|
50 |
+
|
51 |
+
## 🔧 Configuration
|
52 |
+
|
53 |
+
### Forecast Length
|
54 |
+
- **Default**: 64 steps
|
55 |
+
- **Range**: 1-512 steps
|
56 |
+
- **Auto-adjustment**: Preset datasets have optimized forecast lengths:
|
57 |
+
- `loop.csv` and `ett2.csv`: 256 steps
|
58 |
+
- `air_passangers.csv`: 48 steps
|
59 |
+
|
60 |
+
### Model Settings
|
61 |
+
- **Device**: CUDA (T4 GPU)
|
62 |
+
- **Quantiles**: 10%, 50% (median), 90% prediction intervals
|
63 |
+
|
64 |
+
## 📈 Output Features
|
65 |
+
|
66 |
+
- **Historical data**: Blue line showing input time series
|
67 |
+
- **Median forecast**: Orange line for point predictions
|
68 |
+
- **Uncertainty bands**: Gray shaded area showing 10%-90%
|
69 |
+
|
app.py
CHANGED
@@ -18,13 +18,13 @@ torch.manual_seed(42)
|
|
18 |
model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
|
19 |
|
20 |
def model_forecast(input_data, forecast_length=256, file_name=None):
|
21 |
-
if os.path.basename(file_name) == "loop.csv":
|
22 |
_forecast_tensor = torch.load("data/loop_forecast_256.pt")
|
23 |
return _forecast_tensor[:,:forecast_length,:]
|
24 |
-
elif os.path.basename(file_name) == "ett2.csv":
|
25 |
_forecast_tensor = torch.load("data/ett2_forecast_256.pt")
|
26 |
return _forecast_tensor[:,:forecast_length,:]
|
27 |
-
elif os.path.basename(file_name) == "air_passangers.csv":
|
28 |
_forecast_tensor = torch.load("data/air_passengers_forecast_48.pt")
|
29 |
return _forecast_tensor[:,:forecast_length,:]
|
30 |
else:
|
@@ -413,5 +413,5 @@ with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
|
|
413 |
'''
|
414 |
gradio app.py
|
415 |
ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
|
416 |
-
ssh -L 7860:localhost:7860 compute-permanent-node-
|
417 |
'''
|
|
|
18 |
model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
|
19 |
|
20 |
def model_forecast(input_data, forecast_length=256, file_name=None):
|
21 |
+
if os.path.basename(file_name) == "loop.csv" and forecast_length==256:
|
22 |
_forecast_tensor = torch.load("data/loop_forecast_256.pt")
|
23 |
return _forecast_tensor[:,:forecast_length,:]
|
24 |
+
elif os.path.basename(file_name) == "ett2.csv" and forecast_length==256:
|
25 |
_forecast_tensor = torch.load("data/ett2_forecast_256.pt")
|
26 |
return _forecast_tensor[:,:forecast_length,:]
|
27 |
+
elif os.path.basename(file_name) == "air_passangers.csv"and forecast_length==48:
|
28 |
_forecast_tensor = torch.load("data/air_passengers_forecast_48.pt")
|
29 |
return _forecast_tensor[:,:forecast_length,:]
|
30 |
else:
|
|
|
413 |
'''
|
414 |
gradio app.py
|
415 |
ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
|
416 |
+
ssh -L 7860:localhost:7860 compute-permanent-node-195
|
417 |
'''
|