# TiRex-demo / app.py
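"""Gradio demo for TiRex time-series forecasting.

Upload a CSV / XLS / XLSX / Parquet table (one row per timeseries, with an
optional name in the first column) or pick a bundled preset, select which
series to show and how many steps to forecast, and the app plots quantile
forecasts as Plotly subplots.
"""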
import io
import pandas as pd
import torch
import plotly.graph_objects as go
from PIL import Image
import numpy as np
import gradio as gr
import os
from plotly.subplots import make_subplots
from tirex import load_model, ForecastModel
# ----------------------------
# Helper functions
# ----------------------------
torch.manual_seed(42)
model: ForecastModel = load_model("NX-AI/TiRex", device="cuda")
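# For the bundled preset files, 512-step forecasts were precomputed and stored
# as .pt tensors (assumed shape: (n_series, 512, n_quantiles)); these are loaded
# and sliced instead of rerunning the model. Any other input is forecast on the fly.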
def model_forecast(input_data, forecast_length=256, file_name=None):
if os.path.basename(file_name) == "loop.csv":
_forecast_tensor = torch.load("data/loop_forecast_512.pt")
return _forecast_tensor[:,:forecast_length,:]
elif os.path.basename(file_name) == "ett2.csv":
_forecast_tensor = torch.load("data/ett2_forecast_512.pt")
return _forecast_tensor[:,:forecast_length,:]
elif os.path.basename(file_name) == "air_passangers.csv":
_forecast_tensor = torch.load("data/air_passengers_forecast_512.pt")
return _forecast_tensor[:,:forecast_length,:]
else:
forecast = model.forecast(context=input_data, prediction_length=forecast_length)
return forecast[0]
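# Standalone helper for plotting a single series with its quantile band.
# Note: the Gradio callback below (display_filtered_forecast) builds its own
# subplot figure, so this function is not currently wired into the UI.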
def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
"""
- timeseries: 1D list/array of historical values.
- quantile_predictions: 2D array of shape (pred_len, n_q),
with quantiles sorted left→right.
- timeseries_name: string label.
"""
fig = go.Figure()
# 1) Plot historical data (blue line, no markers)
x_hist = list(range(len(timeseries)))
fig.add_trace(go.Scatter(
x=x_hist,
y=timeseries,
mode="lines", # no markers
name=f"{timeseries_name} – Given Data",
line=dict(color="blue", width=2),
))
# 2) X-axis indices for forecasts
pred_len = quantile_predictions.shape[0]
x_pred = list(range(len(timeseries) - 1, len(timeseries) - 1 + pred_len))
# 3) Extract lower, upper, and median quantiles
lower_q = quantile_predictions[:, 0]
upper_q = quantile_predictions[:, -1]
n_q = quantile_predictions.shape[1]
median_idx = n_q // 2
median_q = quantile_predictions[:, median_idx]
# 4) Lower‐bound trace (invisible line, still shows on hover)
fig.add_trace(go.Scatter(
x=x_pred,
y=lower_q,
mode="lines",
line=dict(color="rgba(0, 0, 0, 0)", width=0),
name=f"{timeseries_name} – 10% Quantile",
hovertemplate="Lower: %{y:.2f}<extra></extra>"
))
# 5) Upper‐bound trace (shaded down to lower_q)
fig.add_trace(go.Scatter(
x=x_pred,
y=upper_q,
mode="lines",
line=dict(color="rgba(0, 0, 0, 0)", width=0),
fill="tonexty",
fillcolor="rgba(128, 128, 128, 0.3)",
name=f"{timeseries_name} – 90% Quantile",
hovertemplate="Upper: %{y:.2f}<extra></extra>"
))
# 6) Median trace (orange) on top
fig.add_trace(go.Scatter(
x=x_pred,
y=median_q,
mode="lines",
name=f"{timeseries_name} – Median Forecast",
line=dict(color="orange", width=2),
hovertemplate="Median: %{y:.2f}<extra></extra>"
))
# 7) Layout: title on left (y=0.95), legend on right (y=0.95)
fig.update_layout(
template="plotly_dark",
title=dict(
text=f"Timeseries: {timeseries_name}",
x=0.10, # left‐align
xanchor="left",
y=0.90, # near top
yanchor="bottom",
font=dict(size=18, family="Arial", color="white")
),
xaxis=dict(
rangeslider=dict(visible=True), # <-- put rangeslider here
fixedrange=False
),
xaxis_title="Time",
yaxis_title="Value",
hovermode="x unified",
margin=dict(
t=120, # increase top margin to fit title+legend comfortably
b=40,
l=60,
r=40
),
# height=plot_height,
# width=plot_width,
autosize=True,
)
return fig
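# Read a table from disk, dispatching on the file extension.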
def load_table(file_path):
ext = file_path.split(".")[-1].lower()
if ext == "csv":
return pd.read_csv(file_path)
elif ext in ("xls", "xlsx"):
return pd.read_excel(file_path)
elif ext == "parquet":
return pd.read_parquet(file_path)
else:
raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
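# Callback for file upload / preset selection: loads the table, derives series
# names (first column if it holds non-numeric text, otherwise "Series 0",
# "Series 1", ...), and resets the forecast-length slider to the file's default.
# Returns updates for the checkbox group, the names state, and the slider.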
def extract_names_and_update(file, preset_filename):
try:
# Determine which file to use and get default forecast length
if file is not None:
df = load_table(file.name)
default_length = get_default_forecast_length(file.name)
else:
if not preset_filename or preset_filename == "-- No preset selected --":
return gr.update(choices=[], value=[]), [], gr.update(value=256)
df = load_table(preset_filename)
default_length = get_default_forecast_length(preset_filename)
if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
names = df.iloc[:, 0].tolist()
else:
names = [f"Series {i}" for i in range(len(df))]
return (
gr.update(choices=names, value=names),
names,
gr.update(value=default_length)
)
except Exception:
return gr.update(choices=[], value=[]), [], gr.update(value=256)
def filter_names(search_term, all_names):
if not all_names:
return gr.update(choices=[], value=[])
if not search_term:
return gr.update(choices=all_names, value=all_names)
lower = search_term.lower()
filtered = [n for n in all_names if lower in str(n).lower()]
return gr.update(choices=filtered, value=filtered)
def check_all(names_list):
return gr.update(value=names_list)
def uncheck_all(_):
return gr.update(value=[])
def get_default_forecast_length(file_path):
"""Get default forecast length based on filename"""
if file_path is None:
return 64
filename = os.path.basename(file_path)
if filename == "loop.csv" or filename == "ett2.csv":
return 256
elif filename == "air_passangers.csv":
return 48
else:
return 64
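# Main plotting callback: loads the chosen table, keeps at most the last 2048
# steps per row, masks the rows selected in the checkbox group, forecasts all
# rows (or loads the cached forecast) and slices out the selected ones, then
# draws one subplot per series with its history, the outer quantile band
# (labelled 10%/90%), and the median forecast.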
def display_filtered_forecast(file, preset_filename, selected_names, forecast_length):
try:
# 1) If no file or preset selected, show an error
if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
return None, "No file selected."
# 2) Load DataFrame and remember which filename to pass to model_forecast
if file is not None:
df = load_table(file.name)
file_name = file.name
else:
df = load_table(preset_filename)
file_name = preset_filename
if df.shape[1] > 2048:
df = df.iloc[:, -2048:]
gr.Info("A maximum of 2048 steps per timeseries (row) is allowed, so only the last 2048 steps were kept. ℹ️", duration=5)
# 3) Determine whether first column is names or numeric
if (
df.shape[1] > 0
and df.iloc[:, 0].dtype == object
and not df.iloc[:, 0].str.isnumeric().all()
):
all_names = df.iloc[:, 0].tolist()
data_only = df.iloc[:, 1:].astype(float)
else:
all_names = [f"Series {i}" for i in range(len(df))]
data_only = df.astype(float)
# 4) Build mask from selected_names
mask = [name in selected_names for name in all_names]
if not any(mask):
return None, "No timeseries chosen to plot."
filtered_data = data_only.iloc[mask, :].values # shape = (n_selected, seq_len)
filtered_names = [all_names[i] for i, m in enumerate(mask) if m]
n_selected = filtered_data.shape[0]
if n_selected > 30:
raise gr.Error("A maximum of 30 timeseries (rows) can be selected.", duration=5)
# 5) First call model_forecast on all series, then select only the masked rows
full_data = data_only.values # shape = (n_all, seq_len)
full_out = model_forecast(full_data, forecast_length=forecast_length, file_name=file_name)
# Now pick only the rows we actually filtered
out = full_out[mask, :, :] # shape = (n_selected, pred_len, n_q)
inp = torch.tensor(filtered_data)
# 6) Create one subplot per selected series, with vertical spacing
subplot_height_px = 350 # px per subplot
n_selected = len(filtered_names)
fig = make_subplots(
rows=n_selected,
cols=1,
shared_xaxes=False,
subplot_titles=filtered_names,
row_heights=[1] * n_selected, # all rows equal height
)
fig.update_layout(
height=subplot_height_px * n_selected,
template="plotly_dark",
margin=dict(t=50, b=50)
)
for idx in range(n_selected):
ts = inp[idx].numpy().tolist()
qp = out[idx].numpy()
series_name = filtered_names[idx]
# a) plot historical data (blue line)
x_hist = list(range(len(ts)))
fig.add_trace(
go.Scatter(
x=x_hist,
y=ts,
mode="lines",
name=f"{series_name} – Given Data",
line=dict(color="blue", width=2),
showlegend=False
),
row=idx + 1, col=1
)
# b) compute forecast indices
pred_len = qp.shape[0]
x_pred = list(range(len(ts) - 1, len(ts) - 1 + pred_len))
lower_q = qp[:, 0]
upper_q = qp[:, -1]
n_q = qp.shape[1]
median_idx = n_q // 2
median_q = qp[:, median_idx]
# c) lower‐bound (invisible)
fig.add_trace(
go.Scatter(
x=x_pred,
y=lower_q,
mode="lines",
line=dict(color="rgba(0,0,0,0)", width=0),
name=f"{series_name} – 10% Quantile",
hovertemplate="10% Quantile: %{y:.2f}<extra></extra>",
showlegend=False
),
row=idx + 1, col=1
)
# d) upper‐bound (shaded area)
fig.add_trace(
go.Scatter(
x=x_pred,
y=upper_q,
mode="lines",
line=dict(color="rgba(0,0,0,0)", width=0),
fill="tonexty",
fillcolor="rgba(128,128,128,0.3)",
name=f"{series_name} – 90% Quantile",
hovertemplate="90% Quantile: %{y:.2f}<extra></extra>",
showlegend=False
),
row=idx + 1, col=1
)
# e) median forecast (orange line)
fig.add_trace(
go.Scatter(
x=x_pred,
y=median_q,
mode="lines",
name=f"{series_name} – Median Forecast",
line=dict(color="orange", width=2),
hovertemplate="Median: %{y:.2f}<extra></extra>",
showlegend=False
),
row=idx + 1, col=1
)
# f) label axes for each subplot
fig.update_xaxes(title_text="Time", row=idx + 1, col=1)
fig.update_yaxes(title_text="Value", row=idx + 1, col=1)
# 7) Global layout tweaks
fig.update_layout(
template="plotly_dark",
height=300 * n_selected, # 300px per subplot
title=dict(
text="Forecasts for Selected Timeseries",
x=0.5,
font=dict(size=20, family="Arial", color="white")
),
hovermode="x unified",
margin=dict(t=120, b=40, l=60, r=40),
showlegend=False
)
return fig, ""
except gr.Error:
raise  # re-raise Gradio errors unchanged; their message and duration are already set
except Exception as e:
return None, f"Error: {str(e)}"
# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------
with gr.Blocks(fill_width=True, theme=gr.themes.Ocean()) as demo:
gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")
with gr.Row():
# Left column: controls
with gr.Column(scale=1):
gr.Markdown("## Data Selection")
file_input = gr.File(
label="Upload CSV / XLSX / PARQUET",
file_types=[".csv", ".xls", ".xlsx", ".parquet"]
)
preset_choices = ["-- No preset selected --", "data/loop.csv", "data/air_passangers.csv", 'data/ett2.csv']
preset_dropdown = gr.Dropdown(
label="Or choose a preset:",
choices=preset_choices,
value="-- No preset selected --"
)
gr.Markdown("## Forecast Length Setting")
forecast_length_slider = gr.Slider(
minimum=1,
maximum=512,
value=64,
step=1,
label="Forecast Length (Steps)",
info="Choose how many future steps to forecast."
)
gr.Markdown("## Search / Filter")
search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
filter_checkbox = gr.CheckboxGroup(
choices=[], value=[], label="Select which timeseries to show"
)
with gr.Row():
check_all_btn = gr.Button("✅ Check All")
uncheck_all_btn = gr.Button("❎ Uncheck All")
plot_button = gr.Button("▶️ Plot Forecasts")
errbox = gr.Textbox(label="Error Message", interactive=False)
with gr.Row():
gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
with gr.Column(scale=5):
gr.Markdown("## Forecast Plot")
plot_output = gr.Plot()
# Instruction text below plot
gr.Markdown("## Instructions")
gr.Markdown(
"""
**How to format your data:**
- Your file must be a table (CSV, XLS, XLSX, or Parquet).
- **One row per timeseries.** Each row is treated as a separate series.
- If you want to **name** each series, put the name as the first value in **every** row:
- Example (CSV):
`AAPL, 120.5, 121.0, 119.8, ...`
`AMZN, 3300.0, 3310.5, 3295.2, ...`
- In that case, the first column is not numeric, so it will be used as the series name.
- If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
- Example:
`120.5, 121.0, 119.8, ...`
`3300.0, 3310.5, 3295.2, ...`
- Then every row will be auto-named “Series 0, Series 1, …” in order.
- **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
- The rest of the columns (after the optional name) must be numeric data points for that series.
- You can filter by typing in the search box. Then check or uncheck individual names before plotting.
- Use “Check All” / “Uncheck All” to quickly select or deselect every series.
- Finally, click **Plot Forecasts** to view the quantile forecast for each selected series (for the number of steps set on the Forecast Length slider).
"""
)
gr.Markdown("## Citation")
# Citation rendered as a code block
gr.Markdown(
"""
If you use TiRex in your research, please cite our work:
```
@article{auerTiRexZeroShotForecasting2025,
title = {{{TiRex}}: {{Zero-Shot Forecasting Across Long}} and {{Short Horizons}} with {{Enhanced In-Context Learning}}},
author = {Auer, Andreas and Podest, Patrick and Klotz, Daniel and B{\"o}ck, Sebastian and Klambauer, G{\"u}nter and Hochreiter, Sepp},
journal = {ArXiv},
volume = {2505.23719},
year = {2025}
}
```
"""
)
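# Holds the full list of series names so the search box can filter the
# checkbox group without re-reading the uploaded file.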
names_state = gr.State([])
file_input.change(
fn=extract_names_and_update,
inputs=[file_input, preset_dropdown],
outputs=[filter_checkbox, names_state, forecast_length_slider]
)
preset_dropdown.change(
fn=extract_names_and_update,
inputs=[file_input, preset_dropdown],
outputs=[filter_checkbox, names_state, forecast_length_slider]
)
# When search term changes, filter names
search_box.change(
fn=filter_names,
inputs=[search_box, names_state],
outputs=[filter_checkbox]
)
# Check All / Uncheck All
check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)
# Plot button
plot_button.click(
fn=display_filtered_forecast,
inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider],
outputs=[plot_output, errbox]
)
demo.launch()
# Launch notes:
#   gradio app.py
#   ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
#       ssh -L 7860:localhost:7860 compute-permanent-node-83