Blago123 commited on
Commit
8691d5d
·
1 Parent(s): 0fb9496

preset forecasting data used only when default forecast length is set, tested, updated README.md

Browse files
Files changed (2) hide show
  1. README.md +60 -1
  2. app.py +4 -4
README.md CHANGED
@@ -7,4 +7,63 @@ sdk: docker
7
  app_port: 7860
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  app_port: 7860
8
  ---
9
 
10
+ # TiRex Zero‑Shot Time Series Forecasting App
11
+
12
+ A Gradio‑based interactive web app to perform zero‑shot time series forecasting using the TiRex model. Upload your own CSV/XLSX/Parquet files or choose from built‑in presets, filter series by name, and visualize quantile forecasts over your chosen horizon.
13
+
14
+ ---
15
+
16
+ ## 🔍 Features
17
+
18
+ - **Zero‑Shot Forecasting**: Powered by the [`NX-AI/TiRex`](https://huggingface.co/NX-AI/TiRex) model.
19
+ - **Custom Data Upload**: Accepts CSV, XLSX, and Parquet.
20
+ - **Preset Datasets**: Includes `loop.csv`, `air_passangers.csv`, and `ett2.csv` for quick demos.
21
+ - **Interactive Filtering**: Search, check/uncheck, and plot only the series you care about.
22
+ - **Quantile Forecasts**: Displays historical data, median forecast line, and 10–90% quantile shading.
23
+ - **Configurable Horizon**: Slider to set forecast length (1–512 steps).
24
+ - **Automatic Defaults**: Detects best forecast‐length defaults for presets.
25
+
26
+ ---
27
+
28
+ ## 📊 Data Format
29
+
30
+ ### With Named Series
31
+ ```csv
32
+ AAPL,120.5,121.0,119.8,122.1,123.5,...
33
+ AMZN,3300.0,3310.5,3295.2,3305.8,3315.1,...
34
+ GOOGL,2800.1,2795.3,2810.7,2805.2,2820.4,...
35
+ ```
36
+
37
+ ### Without Named Series
38
+ ```csv
39
+ 120.5,121.0,119.8,122.1,123.5,...
40
+ 3300.0,3310.5,3295.2,3305.8,3315.1,...
41
+ 2800.1,2795.3,2810.7,2805.2,2820.4,...
42
+ ```
43
+
44
+ ### Key Rules:
45
+ - **One row per time series**
46
+ - **Consistent naming**: Either all rows have names (first column) or none do
47
+ - **Numeric data**: All values after the optional name column must be numeric
48
+ - **Minimum length**: Time series must have at least `forecast_length + 10` data points
49
+ - **Maximum constraints**: Up to 30 time series and 2048 time steps per series
50
+
51
+ ## 🔧 Configuration
52
+
53
+ ### Forecast Length
54
+ - **Default**: 64 steps
55
+ - **Range**: 1-512 steps
56
+ - **Auto-adjustment**: Preset datasets have optimized forecast lengths:
57
+ - `loop.csv` and `ett2.csv`: 256 steps
58
+ - `air_passangers.csv`: 48 steps
59
+
60
+ ### Model Settings
61
+ - **Device**: CUDA (T4 GPU)
62
+ - **Quantiles**: 10%, 50% (median), 90% prediction intervals
63
+
64
+ ## 📈 Output Features
65
+
66
+ - **Historical data**: Blue line showing input time series
67
+ - **Median forecast**: Orange line for point predictions
68
+ - **Uncertainty bands**: Gray shaded area showing 10%-90%
69
+
app.py CHANGED
@@ -18,13 +18,13 @@ torch.manual_seed(42)
18
  model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
19
 
20
  def model_forecast(input_data, forecast_length=256, file_name=None):
21
- if os.path.basename(file_name) == "loop.csv":
22
  _forecast_tensor = torch.load("data/loop_forecast_256.pt")
23
  return _forecast_tensor[:,:forecast_length,:]
24
- elif os.path.basename(file_name) == "ett2.csv":
25
  _forecast_tensor = torch.load("data/ett2_forecast_256.pt")
26
  return _forecast_tensor[:,:forecast_length,:]
27
- elif os.path.basename(file_name) == "air_passangers.csv":
28
  _forecast_tensor = torch.load("data/air_passengers_forecast_48.pt")
29
  return _forecast_tensor[:,:forecast_length,:]
30
  else:
@@ -413,5 +413,5 @@ with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
413
  '''
414
  gradio app.py
415
  ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
416
- ssh -L 7860:localhost:7860 compute-permanent-node-368
417
  '''
 
18
  model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
19
 
20
  def model_forecast(input_data, forecast_length=256, file_name=None):
21
+ if os.path.basename(file_name) == "loop.csv" and forecast_length==256:
22
  _forecast_tensor = torch.load("data/loop_forecast_256.pt")
23
  return _forecast_tensor[:,:forecast_length,:]
24
+ elif os.path.basename(file_name) == "ett2.csv" and forecast_length==256:
25
  _forecast_tensor = torch.load("data/ett2_forecast_256.pt")
26
  return _forecast_tensor[:,:forecast_length,:]
27
+ elif os.path.basename(file_name) == "air_passangers.csv"and forecast_length==48:
28
  _forecast_tensor = torch.load("data/air_passengers_forecast_48.pt")
29
  return _forecast_tensor[:,:forecast_length,:]
30
  else:
 
413
  '''
414
  gradio app.py
415
  ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
416
+ ssh -L 7860:localhost:7860 compute-permanent-node-195
417
  '''