Spaces:
Running
on
T4
Running
on
T4
import io | |
import pandas as pd | |
import torch | |
import plotly.graph_objects as go | |
from PIL import Image | |
import numpy as np | |
import gradio as gr | |
import os | |
from plotly.subplots import make_subplots | |
from tirex import load_model, ForecastModel | |
# ---------------------------- | |
# Helper functions (logic mostly unchanged) | |
# ---------------------------- | |
# Seed PyTorch's random number generator with a fixed value.
torch.manual_seed(42)
# Load the "NX-AI/TiRex" model once, when this module is executed, requesting
# the CUDA device. model_forecast() reads this module-level instance.
model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
def model_forecast(input_data, forecast_length=256, file_name=None):
    """Return a forecast for ``input_data``.

    For the bundled example files a forecast stored on disk is returned when
    the requested horizon matches the stored one; in every other case the
    module-level ``model`` is run.

    Args:
        input_data: Context handed to ``model.forecast`` as ``context``.
        forecast_length: Number of future steps to predict.
        file_name: Path of the file the data came from, or ``None`` when the
            data has no associated file.

    Returns:
        The object loaded from the stored ``.pt`` file, or the first element
        of the value returned by ``model.forecast``.
    """
    # (file basename, horizon) -> forecast stored on disk.
    # NOTE(review): the lookup uses only the basename, so an uploaded file that
    # happens to be called e.g. "loop.csv" also gets the stored forecast --
    # confirm whether that is intended.
    stored_forecasts = {
        ("loop.csv", 256): "data/loop_forecast_256.pt",
        ("ett2.csv", 256): "data/ett2_forecast_256.pt",
        ("air_passengers.csv", 24): "data/air_passengers_forecast_24.pt",
    }

    # ``file_name`` defaults to None and os.path.basename(None) raises
    # TypeError, so only derive a basename when a path was actually given.
    base_name = os.path.basename(file_name) if file_name else None
    stored_path = stored_forecasts.get((base_name, forecast_length))
    if stored_path is not None:
        return torch.load(stored_path)

    forecast = model.forecast(context=input_data, prediction_length=forecast_length)
    return forecast[0]
def load_table(file_path):
    """Read a table from ``file_path`` with the pandas reader for its extension.

    The extension is the text after the last ``.`` in the path, compared
    case-insensitively.

    Args:
        file_path: Path string of a CSV, XLS, XLSX or Parquet file.

    Returns:
        The DataFrame produced by the matching pandas reader.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    suffix = file_path.rsplit(".", 1)[-1].lower()
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)
def extract_names_and_update(file, preset_filename, transpose):
    """Work out the series names of the chosen table and the default horizon.

    The uploaded ``file`` takes precedence over ``preset_filename``.  When the
    first column holds non-numeric strings its values become the names;
    otherwise rows are named ``Series 0``, ``Series 1``, ...

    Args:
        file: Uploaded file object exposing a ``.name`` path, or ``None``.
        preset_filename: Preset path, empty/``None`` or the
            ``"-- No preset selected --"`` placeholder.
        transpose: When true, the table is transposed before names are read.

    Returns:
        A 3-tuple: an update for the checkbox group (all names offered and
        selected), the plain list of names, and an update carrying the default
        forecast length.  If nothing is selected or any exception occurs, the
        tuple holds empty choices, an empty list and a length of 256.
    """

    def _nothing_loaded():
        # Same fallback for "no source chosen" and "loading failed".
        return gr.update(choices=[], value=[]), [], gr.update(value=256)

    try:
        if file is not None:
            source_path = file.name
        elif not preset_filename or preset_filename == "-- No preset selected --":
            return _nothing_loaded()
        else:
            source_path = preset_filename

        df = load_table(source_path)
        default_length = get_default_forecast_length(source_path)

        # Transpose first so that name detection sees the final orientation.
        if transpose:
            df = df.T

        first_column_holds_names = (
            df.shape[1] > 0
            and df.iloc[:, 0].dtype == object
            and not df.iloc[:, 0].str.isnumeric().all()
        )
        names = (
            df.iloc[:, 0].tolist()
            if first_column_holds_names
            else [f"Series {i}" for i in range(len(df))]
        )

        checkbox_update = gr.update(choices=names, value=names)
        slider_update = gr.update(value=default_length)
        return checkbox_update, names, slider_update
    except Exception:
        return _nothing_loaded()
def filter_names(search_term, all_names):
    """Narrow the checkbox choices to names containing ``search_term``.

    Matching is a case-insensitive substring test on ``str(name)``.  An empty
    search term keeps every name; an empty ``all_names`` yields no choices.

    Args:
        search_term: Text typed into the search box.
        all_names: Every known series name.

    Returns:
        A Gradio update whose choices and value are both the matching names.
    """
    if not all_names:
        matches = []
    elif not search_term:
        matches = all_names
    else:
        needle = search_term.lower()
        matches = [name for name in all_names if needle in str(name).lower()]
    return gr.update(choices=matches, value=matches)
def check_all(names_list):
    """Return a Gradio update that sets the value to all of ``names_list``."""
    everything_selected = gr.update(value=names_list)
    return everything_selected
def uncheck_all(_):
    """Return a Gradio update that sets the value to an empty list.

    The single argument is accepted and ignored.
    """
    nothing_selected = []
    return gr.update(value=nothing_selected)
def get_default_forecast_length(file_path):
    """Return the default forecast length for a file, chosen by its basename.

    ``loop.csv`` and ``ett2.csv`` map to 256, ``air_passengers.csv`` maps to
    24; any other name, and ``None``, map to 64.
    """
    fallback_length = 64
    if file_path is None:
        return fallback_length

    lengths_by_basename = {
        "loop.csv": 256,
        "ett2.csv": 256,
        "air_passengers.csv": 24,
    }
    return lengths_by_basename.get(os.path.basename(file_path), fallback_length)
def display_filtered_forecast(file, preset_filename, selected_names, forecast_length, transpose):
    """Forecast the selected time series and plot each one in its own subplot.

    The table is read from the uploaded ``file`` when one is given, otherwise
    from ``preset_filename``.  Every row is one series.  If the first column
    holds non-numeric strings it supplies the series names; otherwise rows are
    named ``Series 0``, ``Series 1``, ...

    Args:
        file: Uploaded file object exposing a ``.name`` path, or ``None``.
        preset_filename: Preset path, ``None`` or the
            ``"-- No preset selected --"`` placeholder.
        selected_names: Names of the series that should be plotted.
        forecast_length: Number of future steps to forecast.
        transpose: When true, the table is transposed before use.

    Returns:
        ``(figure, "")`` on success, or ``(None, message)`` when there is
        nothing to plot or an unexpected exception occurs.

    Raises:
        gr.Error: If more than 30 series are selected.
    """
    # Number of leading time steps handed to the model for each bundled
    # preset; later columns of the file are only drawn as history.
    preset_context_lengths = {
        "data/ett2.csv": 2048,
        "data/loop.csv": 2048,
        "data/air_passengers.csv": 132,
    }
    subplot_height_px = 300

    try:
        # 1) Nothing to do without an upload or a preset.
        if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
            return None, "No file selected."

        # 2) Load the table and remember which path to report to model_forecast.
        file_name = file.name if file is not None else preset_filename
        df = load_table(file_name)
        if transpose:
            df = df.T

        # 3) Split off the optional name column; uploads are cut to the last
        #    2048 steps per series.
        has_name_column = (
            df.shape[1] > 0
            and df.iloc[:, 0].dtype == object
            and not df.iloc[:, 0].str.isnumeric().all()
        )
        if has_name_column:
            if df.shape[1] > 2048 and file is not None:
                df = pd.concat([df.iloc[:, [0]], df.iloc[:, -2048:]], axis=1)
                gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5)
            all_names = df.iloc[:, 0].tolist()
            data_only = df.iloc[:, 1:].astype(float)
        else:
            if df.shape[1] > 2048 and file is not None:
                df = df.iloc[:, -2048:]
                gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5)
            all_names = [f"Series {i}" for i in range(len(df))]
            data_only = df.astype(float)

        # 4) Row mask of the series the user ticked.
        mask = [name in selected_names for name in all_names]
        if not any(mask):
            return None, "No timeseries chosen to plot."
        histories = data_only.iloc[mask, :].values  # one row per selected series
        filtered_names = [name for name, keep in zip(all_names, mask) if keep]
        n_selected = histories.shape[0]
        if n_selected > 30:
            raise gr.Error("Maximum of 30 timeseries (rows) is possible to choose", duration=5)

        # 5) Forecast every row, then keep the selected ones.  model_forecast
        #    may hand back a stored forecast for a preset (presumably covering
        #    all of its rows), so the mask is applied to its output.
        full_data = data_only.values
        if file is None and preset_filename in preset_context_lengths:
            context_len = preset_context_lengths[preset_filename]
        else:
            # Uploads -- and any preset without an entry above, which used to
            # fail on an unassigned variable -- use the whole series.
            context_len = full_data.shape[1]
        full_out = model_forecast(
            full_data[:, :context_len],
            forecast_length=forecast_length,
            file_name=file_name,
        )
        out = full_out[mask, :, :]  # indexed below as out[series][step, quantile]

        # 6) One subplot per selected series.
        fig = make_subplots(
            rows=n_selected,
            cols=1,
            shared_xaxes=False,
            subplot_titles=filtered_names,
            row_heights=[1] * n_selected,  # all rows equal height
        )

        for idx, series_name in enumerate(filtered_names):
            row = idx + 1
            history = histories[idx].tolist()
            # detach()/cpu() do nothing for a plain CPU tensor and keep
            # .numpy() valid should the forecast live on the GPU.
            qp = out[idx].detach().cpu().numpy()
            pred_len = qp.shape[0]
            x_pred = list(range(context_len, context_len + pred_len))

            # a) history (blue line); cut where it runs past the forecast.
            x_hist = list(range(len(history)))
            if x_pred[-1] < x_hist[-1]:
                keep = x_pred[-1]
                x_hist = x_hist[:keep]
                history = history[:keep]
            fig.add_trace(
                go.Scatter(
                    x=x_hist,
                    y=history,
                    mode="lines",
                    name=f"{series_name} – Given Data",
                    line=dict(color="blue", width=2),
                    showlegend=False,
                ),
                row=row, col=1,
            )

            # b) first, last and middle quantile columns of the forecast.
            lower_q = qp[:, 0]
            upper_q = qp[:, -1]
            median_q = qp[:, qp.shape[1] // 2]

            # c) lower bound (invisible line; anchors the shaded band).
            fig.add_trace(
                go.Scatter(
                    x=x_pred,
                    y=lower_q,
                    mode="lines",
                    line=dict(color="rgba(0,0,0,0)", width=0),
                    name=f"{series_name} – 10% Quantile",
                    hovertemplate="10% Quantile: %{y:.2f}<extra></extra>",
                    showlegend=False,
                ),
                row=row, col=1,
            )

            # d) upper bound, filled down to the previous trace.
            fig.add_trace(
                go.Scatter(
                    x=x_pred,
                    y=upper_q,
                    mode="lines",
                    line=dict(color="rgba(0,0,0,0)", width=0),
                    fill="tonexty",
                    fillcolor="rgba(128,128,128,0.3)",
                    name=f"{series_name} – 90% Quantile",
                    hovertemplate="90% Quantile: %{y:.2f}<extra></extra>",
                    showlegend=False,
                ),
                row=row, col=1,
            )

            # e) median forecast (orange line).
            fig.add_trace(
                go.Scatter(
                    x=x_pred,
                    y=median_q,
                    mode="lines",
                    name=f"{series_name} – Median Forecast",
                    line=dict(color="orange", width=2),
                    hovertemplate="Median: %{y:.2f}<extra></extra>",
                    showlegend=False,
                ),
                row=row, col=1,
            )

            # f) axis labels for this subplot.
            fig.update_xaxes(title_text="Time", row=row, col=1)
            fig.update_yaxes(title_text="Value", row=row, col=1)

        # 7) Figure-wide layout, configured in one place.
        fig.update_layout(
            template="plotly_dark",
            height=subplot_height_px * n_selected,
            title=dict(
                text="Forecasts for Selected Timeseries",
                x=0.5,
                font=dict(size=20, family="Arial", color="white"),
            ),
            hovermode="x unified",
            margin=dict(t=120, b=40, l=60, r=40),
            showlegend=False,
        )
        return fig, ""
    except gr.Error:
        # Let Gradio show its own error pop-up unchanged.
        raise
    except Exception as e:
        return None, f"Error: {str(e)}"
# ---------------------------- | |
# Gradio layout: two columns + instructions | |
# ---------------------------- | |
# Help text rendered below the plot.
_INSTRUCTIONS_MD = """
**How to format your data:**
- Your file must be a table (CSV, XLS, XLSX, or Parquet).
- **One row per timeseries.** Each row is treated as a separate series.
- If you want to **name** each series, put the name as the first value in **every** row:
- Example (CSV):
`AAPL, 120.5, 121.0, 119.8, ...`
`AMZN, 3300.0, 3310.5, 3295.2, ...`
- In that case, the first column is not numeric, so it will be used as the series name.
- If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
- Example:
`120.5, 121.0, 119.8, ...`
`3300.0, 3310.5, 3295.2, ...`
- Then every row will be auto-named “Series 0, Series 1, …” in order.
- **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
- The rest of the columns (after the optional name) must be numeric data points for that series.
- You can filter by typing in the search box. Then check or uncheck individual names before plotting.
- Use “Check All” / “Uncheck All” to quickly select or deselect every series.
- Finally, click **Plot Forecasts** to view the quantile forecast for each selected series (for the number of steps chosen with the Forecast Length slider).
"""

# Raw string: the BibTeX umlaut escapes (\"o, \"u) must keep their backslash;
# in a normal string literal \" collapses to a bare quote.
_CITATION_MD = r"""
If you use TiRex in your research, please cite our work:
```
@article{auerTiRexZeroShotForecasting2025,
title = {{{TiRex}}: {{Zero-Shot Forecasting Across Long}} and {{Short Horizons}} with {{Enhanced In-Context Learning}}},
author = {Auer, Andreas and Podest, Patrick and Klotz, Daniel and B{\"o}ck, Sebastian and Klambauer, G{\"u}nter and Hochreiter, Sepp},
journal = {ArXiv},
volume = {2505.23719},
year = {2025}
}
```
"""

with gr.Blocks(fill_width=True, theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")

    with gr.Row():
        # Left column: controls
        with gr.Column(scale=1):
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"],
            )
            preset_choices = ["-- No preset selected --", "data/loop.csv", "data/air_passengers.csv", 'data/ett2.csv']
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=preset_choices,
                value="-- No preset selected --",
            )

            gr.Markdown("## Forecast Length Setting")
            forecast_length_slider = gr.Slider(
                minimum=1,
                maximum=512,
                value=64,
                step=1,
                label="Forecast Length (Steps)",
                info="Choose how many future steps to forecast.",
            )

            gr.Markdown("## Transpose data")
            transpose_switch = gr.Checkbox(
                label="Transpose data (Click if your columns are timeseries)",
                value=False,
            )

            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )
            with gr.Row():
                check_all_btn = gr.Button("✅ Check All")
                uncheck_all_btn = gr.Button("❎ Uncheck All")

            plot_button = gr.Button("▶️ Plot Forecasts")
            errbox = gr.Textbox(label="Error Message", interactive=False)

            with gr.Row():
                gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
                gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)

        # Right column: plot, then instructions and citation
        with gr.Column(scale=5):
            gr.Markdown("## Forecast Plot")
            plot_output = gr.Plot()

            # Instruction text below plot
            gr.Markdown("## Instructions")
            gr.Markdown(_INSTRUCTIONS_MD)

            gr.Markdown("## Citation")
            # The citation is shown as a fenced code block.
            gr.Markdown(_CITATION_MD)

    # Full list of series names of the currently loaded table.
    names_state = gr.State([])

    # A new upload, a new preset or a flipped transpose switch all re-read the
    # table and refresh the name list and the default forecast length.
    for trigger in (file_input, preset_dropdown, transpose_switch):
        trigger.change(
            fn=extract_names_and_update,
            inputs=[file_input, preset_dropdown, transpose_switch],
            outputs=[filter_checkbox, names_state, forecast_length_slider],
        )

    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=[filter_checkbox],
    )

    # Check All / Uncheck All
    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)

    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider, transpose_switch],
        outputs=[plot_output, errbox],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)
''' | |
gradio app.py | |
ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \ | |
ssh -L 7860:localhost:7860 compute-permanent-node-303 | |
''' |