# TiRex-demo / app.py
# Author: Blago123
# Last commit: transpose switch, context cutoff as in paper, full context for custom data
# Commit hash: fce1568
import io
import pandas as pd
import torch
import plotly.graph_objects as go
from PIL import Image
import numpy as np
import gradio as gr
import os
from plotly.subplots import make_subplots
from tirex import load_model, ForecastModel
# ----------------------------
# Helper functions (logic mostly unchanged)
# ----------------------------
# Seed torch's global RNG with a fixed value before the model is created.
torch.manual_seed(42)
# Load the "NX-AI/TiRex" checkpoint once at import time; model_forecast()
# below reads this module-level object on every request.
# NOTE(review): the device is hard-coded to CUDA, so start-up presumably fails
# on a machine without a GPU — confirm against the deployment target.
model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
def model_forecast(input_data, forecast_length=256, file_name=None):
    """Return the forecast for ``input_data``, using a stored result if one exists.

    For the bundled example files a forecast saved under ``data/`` is loaded
    from disk instead of running the model, but only when the requested
    horizon equals the horizon of the stored forecast. Every other request is
    sent to the module-level ``model``.

    Args:
        input_data: Context handed unchanged to ``model.forecast``.
        forecast_length: Number of future steps to predict.
        file_name: Path of the file the data came from. Only its base name is
            inspected. ``None`` (the default) skips the stored-forecast lookup
            instead of failing inside ``os.path.basename``.

    Returns:
        The tensor loaded from disk when a stored forecast matches, otherwise
        the first element of what ``model.forecast`` returns.
    """
    # (base name of the data file, horizon) -> stored forecast tensor
    stored_forecasts = {
        ("loop.csv", 256): "data/loop_forecast_256.pt",
        ("ett2.csv", 256): "data/ett2_forecast_256.pt",
        ("air_passengers.csv", 24): "data/air_passengers_forecast_24.pt",
    }

    if file_name is not None:
        key = (os.path.basename(file_name), forecast_length)
        stored_path = stored_forecasts.get(key)
        if stored_path is not None:
            return torch.load(stored_path)

    forecast = model.forecast(context=input_data, prediction_length=forecast_length)
    return forecast[0]
def load_table(file_path):
    """Read a tabular file into a DataFrame, picking the reader by extension.

    The extension is whatever follows the last ``.`` in ``file_path``, compared
    case-insensitively. CSV, XLS, XLSX and Parquet are supported.

    Raises:
        ValueError: If the extension is none of the supported ones.
    """
    suffix = file_path.rsplit(".", 1)[-1].lower()
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)
def extract_names_and_update(file, preset_filename, transpose):
    """Refresh the series checklist and the forecast-length slider.

    An uploaded ``file`` takes precedence over ``preset_filename``. The table
    is loaded, optionally transposed, and the series names are read from the
    first column when that column holds non-numeric strings; otherwise the
    rows are named ``Series 0``, ``Series 1``, ...

    Returns:
        A 3-tuple: an update for the checkbox group (all names offered and
        ticked), the plain list of names, and an update setting the slider to
        the default horizon for the chosen file. When nothing is selected or
        anything fails, the checklist is emptied and the slider is set to 256.
    """

    def _reset():
        # Empty checklist, empty name list, slider back to 256.
        return gr.update(choices=[], value=[]), [], gr.update(value=256)

    try:
        if file is not None:
            source_path = file.name
        elif not preset_filename or preset_filename == "-- No preset selected --":
            return _reset()
        else:
            source_path = preset_filename

        table = load_table(source_path)
        default_length = get_default_forecast_length(source_path)

        # if user wants to transpose, do it here
        if transpose:
            table = table.T

        first_column_holds_names = (
            table.shape[1] > 0
            and table.iloc[:, 0].dtype == object
            and not table.iloc[:, 0].str.isnumeric().all()
        )
        if first_column_holds_names:
            names = table.iloc[:, 0].tolist()
        else:
            names = [f"Series {i}" for i in range(len(table))]

        checklist_update = gr.update(choices=names, value=names)
        slider_update = gr.update(value=default_length)
        return checklist_update, names, slider_update
    except Exception:
        return _reset()
def filter_names(search_term, all_names):
    """Narrow the checklist to the names containing ``search_term``.

    Matching is a case-insensitive substring test on ``str(name)``. An empty
    search term restores every name; an empty ``all_names`` clears the list.
    The names that remain are both offered and ticked.
    """
    if not all_names:
        shown = []
    elif not search_term:
        shown = all_names
    else:
        needle = search_term.lower()
        shown = [name for name in all_names if needle in str(name).lower()]
    return gr.update(choices=shown, value=shown)
def check_all(names_list):
    """Build an update that sets the checklist value to ``names_list``."""
    ticked = names_list
    return gr.update(value=ticked)
def uncheck_all(_):
    """Build an update that clears the checklist; the argument is ignored."""
    ticked = []
    return gr.update(value=ticked)
def get_default_forecast_length(file_path):
    """Return the default forecast horizon for a data file.

    Only the base name of ``file_path`` matters: ``loop.csv`` and ``ett2.csv``
    map to 256, ``air_passengers.csv`` to 24. Any other name, and ``None``,
    yields 64.
    """
    fallback = 64
    if file_path is None:
        return fallback
    known_lengths = {
        "loop.csv": 256,
        "ett2.csv": 256,
        "air_passengers.csv": 24,
    }
    return known_lengths.get(os.path.basename(file_path), fallback)
# Dropdown entry that stands for "no preset chosen".
_NO_PRESET = "-- No preset selected --"

# Number of leading steps of each preset row that are given to the model as
# context; later columns of a preset row are only plotted next to the forecast.
_PRESET_CONTEXT_LENGTHS = {
    "data/ett2.csv": 2048,
    "data/loop.csv": 2048,
    "data/air_passengers.csv": 132,
}

# Limits enforced on user input.
_MAX_UPLOAD_STEPS = 2048
_MAX_PLOTTED_SERIES = 30


def _split_names_and_values(df, limit_steps):
    """Split a table into series names and a float-valued frame.

    The first column is treated as series names when it has object dtype and
    is not made up of numeric strings only; otherwise rows are auto-named
    ``Series <i>``. With ``limit_steps`` set, a table wider than 2048 columns
    keeps only its last 2048 columns (plus the name column, if any) and the
    user is informed through ``gr.Info``.

    Returns:
        ``(names, values)`` where ``values`` has one row per series.
    """
    has_names = (
        df.shape[1] > 0
        and df.iloc[:, 0].dtype == object
        and not df.iloc[:, 0].str.isnumeric().all()
    )

    if limit_steps and df.shape[1] > _MAX_UPLOAD_STEPS:
        tail = df.iloc[:, -_MAX_UPLOAD_STEPS:]
        df = pd.concat([df.iloc[:, [0]], tail], axis=1) if has_names else tail
        gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5)

    if has_names:
        return df.iloc[:, 0].tolist(), df.iloc[:, 1:].astype(float)
    return [f"Series {i}" for i in range(len(df))], df.astype(float)


def _add_forecast_traces(fig, row, series_name, observed, quantiles, context_len):
    """Draw one series into subplot ``row``: data, quantile band and median.

    Args:
        fig: Figure created by ``make_subplots``.
        row: 1-based subplot row.
        series_name: Label used in the trace names.
        observed: List of all available values of the series.
        quantiles: 2-D array indexed as ``[step, quantile]``.
        context_len: Index of the first forecast step on the x axis.
    """
    pred_len = quantiles.shape[0]
    horizon_end = context_len + pred_len
    x_pred = list(range(context_len, horizon_end))

    # a) plot the given data (blue line). Values past the end of the forecast
    # are dropped; the slice keeps the value at the last forecast step so the
    # blue line ends where the forecast ends (the previous code stopped one
    # step short of it).
    observed = observed[:horizon_end]
    fig.add_trace(
        go.Scatter(
            x=list(range(len(observed))),
            y=observed,
            mode="lines",
            name=f"{series_name} – Given Data",
            line=dict(color="blue", width=2),
            showlegend=False,
        ),
        row=row, col=1,
    )

    # b) pick the band edges (first / last quantile column) and the middle
    # column, which is plotted as the median
    lower_q = quantiles[:, 0]
    upper_q = quantiles[:, -1]
    median_q = quantiles[:, quantiles.shape[1] // 2]

    # c) lower bound (invisible; only serves as the base of the fill)
    fig.add_trace(
        go.Scatter(
            x=x_pred,
            y=lower_q,
            mode="lines",
            line=dict(color="rgba(0,0,0,0)", width=0),
            name=f"{series_name} – 10% Quantile",
            hovertemplate="10% Quantile: %{y:.2f}<extra></extra>",
            showlegend=False,
        ),
        row=row, col=1,
    )

    # d) upper bound, filled down to the previous trace (shaded area)
    fig.add_trace(
        go.Scatter(
            x=x_pred,
            y=upper_q,
            mode="lines",
            line=dict(color="rgba(0,0,0,0)", width=0),
            fill="tonexty",
            fillcolor="rgba(128,128,128,0.3)",
            name=f"{series_name} – 90% Quantile",
            hovertemplate="90% Quantile: %{y:.2f}<extra></extra>",
            showlegend=False,
        ),
        row=row, col=1,
    )

    # e) median forecast (orange line)
    fig.add_trace(
        go.Scatter(
            x=x_pred,
            y=median_q,
            mode="lines",
            name=f"{series_name} – Median Forecast",
            line=dict(color="orange", width=2),
            hovertemplate="Median: %{y:.2f}<extra></extra>",
            showlegend=False,
        ),
        row=row, col=1,
    )

    # f) label axes for this subplot
    fig.update_xaxes(title_text="Time", row=row, col=1)
    fig.update_yaxes(title_text="Value", row=row, col=1)


def display_filtered_forecast(file, preset_filename, selected_names, forecast_length, transpose):
    """Forecast the selected series and plot one subplot per series.

    An uploaded ``file`` takes precedence over ``preset_filename``. Uploaded
    data is used as context in full (after being cut to its last 2048 steps);
    a preset uses only its configured number of leading steps as context, and
    an unrecognised preset falls back to the full row.

    Args:
        file: Uploaded file object exposing ``.name``, or ``None``.
        preset_filename: Path of the chosen preset, or the "no preset" entry.
        selected_names: Names of the series to plot.
        forecast_length: Number of future steps to forecast.
        transpose: Transpose the table first (series stored as columns).

    Returns:
        ``(figure, "")`` on success, ``(None, message)`` when there is nothing
        to plot or an unexpected error occurred.

    Raises:
        gr.Error: If more than 30 series are selected.
    """
    try:
        # 1) Nothing to work with
        if file is None and (preset_filename is None or preset_filename == _NO_PRESET):
            return None, "No file selected."

        # 2) Load the table; the same path is later passed to model_forecast
        uploaded = file is not None
        file_name = file.name if uploaded else preset_filename
        df = load_table(file_name)
        if transpose:
            df = df.T

        # 3) Separate names from values (only uploads are length-limited)
        all_names, values = _split_names_and_values(df, limit_steps=uploaded)

        # 4) Build the row mask from the ticked names
        chosen = selected_names or []
        mask = [name in chosen for name in all_names]
        if not any(mask):
            return None, "No timeseries chosen to plot."
        filtered_names = [name for name, keep in zip(all_names, mask) if keep]
        n_selected = len(filtered_names)
        if n_selected > _MAX_PLOTTED_SERIES:
            raise gr.Error("Maximum of 30 timeseries (rows) is possible to choose", duration=5)

        # 5) Forecast all rows, then keep the selected ones. All rows are sent
        # because model_forecast may return a stored forecast for the file.
        full_data = values.values  # rows are series, columns are steps
        n_steps = full_data.shape[1]
        if uploaded:
            context_len = n_steps
        else:
            context_len = min(_PRESET_CONTEXT_LENGTHS.get(preset_filename, n_steps), n_steps)
        full_out = model_forecast(
            full_data[:, :context_len],
            forecast_length=int(forecast_length),
            file_name=file_name,
        )
        out = full_out[mask, :, :]  # indexed below as [series][step, quantile]
        observed_rows = values.iloc[mask, :].values

        # 6) One subplot per selected series
        fig = make_subplots(
            rows=n_selected,
            cols=1,
            shared_xaxes=False,
            subplot_titles=filtered_names,
            row_heights=[1] * n_selected,  # all rows equal height
        )
        for idx, series_name in enumerate(filtered_names):
            _add_forecast_traces(
                fig,
                idx + 1,
                series_name,
                observed_rows[idx].tolist(),
                out[idx].numpy(),
                context_len,
            )

        # 7) Global layout (a single call; an earlier duplicate call used to
        # set a height and margin that were overwritten here anyway)
        fig.update_layout(
            template="plotly_dark",
            height=300 * n_selected,  # 300px per subplot
            title=dict(
                text="Forecasts for Selected Timeseries",
                x=0.5,
                font=dict(size=20, family="Arial", color="white"),
            ),
            hovermode="x unified",
            margin=dict(t=120, b=40, l=60, r=40),
            showlegend=False,
        )
        return fig, ""
    except gr.Error:
        # Re-raise unchanged instead of wrapping the error object in a new one.
        raise
    except Exception as e:
        return None, f"Error: {str(e)}"
# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------
# NOTE(review): the source reached review without indentation, so the nesting
# of the layout below is reconstructed — confirm it against the running app.
with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")
    with gr.Row():
        # Left column: controls
        with gr.Column(scale=1):
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"]
            )
            preset_choices = ["-- No preset selected --", "data/loop.csv", "data/air_passengers.csv", 'data/ett2.csv']
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=preset_choices,
                value="-- No preset selected --"
            )
            gr.Markdown("## Forecast Length Setting")
            forecast_length_slider = gr.Slider(
                minimum=1,
                maximum=512,
                value=64,
                step=1,
                label="Forecast Length (Steps)",
                info="Choose how many future steps to forecast."
            )
            gr.Markdown("## Transpose data")
            transpose_switch = gr.Checkbox(
                label="Transpose data (Click if your columns are timeseries)",
                value=False
            )
            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )
            with gr.Row():
                check_all_btn = gr.Button("✅ Check All")
                uncheck_all_btn = gr.Button("❎ Uncheck All")
            plot_button = gr.Button("▶️ Plot Forecasts")
            errbox = gr.Textbox(label="Error Message", interactive=False)
            with gr.Row():
                gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
                gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
        # Right column: plot, usage instructions and citation
        with gr.Column(scale=5):
            gr.Markdown("## Forecast Plot")
            plot_output = gr.Plot()
            # Instruction text below plot
            gr.Markdown("## Instructions")
            gr.Markdown(
                """
**How to format your data:**
- Your file must be a table (CSV, XLS, XLSX, or Parquet).
- **One row per timeseries.** Each row is treated as a separate series.
- If you want to **name** each series, put the name as the first value in **every** row:
- Example (CSV):
`AAPL, 120.5, 121.0, 119.8, ...`
`AMZN, 3300.0, 3310.5, 3295.2, ...`
- In that case, the first column is not numeric, so it will be used as the series name.
- If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
- Example:
`120.5, 121.0, 119.8, ...`
`3300.0, 3310.5, 3295.2, ...`
- Then every row will be auto-named “Series 0, Series 1, …” in order.
- **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
- The rest of the columns (after the optional name) must be numeric data points for that series.
- You can filter by typing in the search box. Then check or uncheck individual names before plotting.
- Use “Check All” / “Uncheck All” to quickly select or deselect every series.
- Finally, click **Plot Forecasts** to view the quantile forecast for each selected series (for the number of steps set with the **Forecast Length** slider).
"""
            )
            gr.Markdown("## Citation")
            # make citation as code block
            # The umlaut backslashes are doubled: in a normal (non-raw) string
            # literal a single backslash before a double quote is consumed as
            # an escape, which would drop it from the displayed BibTeX entry.
            gr.Markdown(
                """
If you use TiRex in your research, please cite our work:
```
@article{auerTiRexZeroShotForecasting2025,
title = {{{TiRex}}: {{Zero-Shot Forecasting Across Long}} and {{Short Horizons}} with {{Enhanced In-Context Learning}}},
author = {Auer, Andreas and Podest, Patrick and Klotz, Daniel and B{\\"o}ck, Sebastian and Klambauer, G{\\"u}nter and Hochreiter, Sepp},
journal = {ArXiv},
volume = {2505.23719},
year = {2025}
}
```
"""
            )
    # Holds the unfiltered list of series names between events
    names_state = gr.State([])
    # A new upload, a new preset or a toggled transpose switch all rebuild the
    # checklist and reset the slider through the same callback
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown, transpose_switch],
        outputs=[filter_checkbox, names_state, forecast_length_slider]
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown, transpose_switch],
        outputs=[filter_checkbox, names_state, forecast_length_slider]
    )
    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=[filter_checkbox]
    )
    transpose_switch.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown, transpose_switch],
        outputs=[filter_checkbox, names_state, forecast_length_slider]
    )
    # Check All / Uncheck All
    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)
    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider, transpose_switch],
        outputs=[plot_output, errbox]
    )
# Serve on all interfaces, port 7860
demo.launch(server_name="0.0.0.0", server_port=7860)
# Developer notes kept as a bare string literal that is never used: the
# `gradio app.py` command and an SSH port-forward chain for port 7860.
'''
gradio app.py
ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
ssh -L 7860:localhost:7860 compute-permanent-node-303
'''