Spaces:
Running
on
T4
Running
on
T4
import io | |
import pandas as pd | |
import torch | |
import plotly.graph_objects as go | |
from PIL import Image | |
import numpy as np | |
import gradio as gr | |
import os | |
from tirex import load_model, ForecastModel | |
# ----------------------------
# Helper functions
# ----------------------------
# Seed torch's global RNG once at import time so repeated runs are reproducible.
_SEED = 42
torch.manual_seed(_SEED)
# Lazily-loaded TiRex model shared by all requests (see _get_tirex_model).
_TIREX_MODEL = None


def _get_tirex_model():
    """Return the TiRex model, loading it on first use only.

    The original code called ``load_model`` on every forecast request, which
    re-fetches/rebuilds the weights each time; keeping one instance avoids
    that. If two requests race on the very first call, each may load the
    model once and the last assignment wins, which only wastes work.
    """
    global _TIREX_MODEL
    if _TIREX_MODEL is None:
        _TIREX_MODEL = load_model("NX-AI/TiRex", device="cuda")
    return _TIREX_MODEL


def model_forecast(input_data, forecast_length=256, file_name=None):
    """Return quantile forecasts, from a precomputed file for presets or from TiRex.

    Args:
        input_data: historical values handed to the model, one row per
            series (only used when no precomputed forecast applies).
        forecast_length: number of future steps requested.
        file_name: path of the data file; only its base name is used to pick
            a precomputed forecast. ``None`` or ``""`` means "no preset".

    Returns:
        A tensor indexed ``[series, step, quantile]``. For the preset files it
        is the stored tensor truncated to ``forecast_length`` steps and has
        one row per series in the file (the caller masks it by row). For any
        other file it is ``forecast[0]`` of the TiRex output for ``input_data``
        — assumed to be the quantile tensor; TODO confirm against the TiRex API.
    """
    # Forecasts computed offline for the bundled presets (256 steps, per the
    # file names).
    precomputed = {
        "merged_ett2_loop.csv": "data/merged_ett2_loop_forecast_256.pt",
        "air_passangers.csv": "data/air_passengers_forecast_256.pt",
    }
    # Guard against file_name=None: os.path.basename(None) raises TypeError.
    cached_path = precomputed.get(os.path.basename(file_name)) if file_name else None
    if cached_path is not None:
        forecast_tensor = torch.load(cached_path)
        # Slicing cannot extend the stored horizon: requests longer than the
        # precomputed length are silently truncated to it.
        return forecast_tensor[:, :forecast_length, :]

    model: ForecastModel = _get_tirex_model()
    forecast = model.forecast(context=input_data, prediction_length=forecast_length)
    return forecast[0]
def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
    """Build a figure with the history, a shaded quantile band and the median forecast.

    Args:
        timeseries: 1-D list/array of historical values.
        quantile_predictions: 2-D array of shape ``(pred_len, n_q)`` with the
            quantile levels sorted in ascending order left to right. Column 0
            is drawn as the lower bound, the last column as the upper bound
            and column ``n_q // 2`` as the median (a true median only when the
            quantile levels are symmetric around 0.5 — TODO confirm for the
            model's quantile set).
        timeseries_name: label used in the title and in every trace name.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # 1) Historical data at x = 0 .. n_hist - 1 (blue line, no markers).
    n_hist = len(timeseries)
    fig.add_trace(go.Scatter(
        x=list(range(n_hist)),
        y=timeseries,
        mode="lines",
        name=f"{timeseries_name} – Given Data",
        line=dict(color="blue", width=2),
    ))

    # 2) Forecast step k is the value at time n_hist + k: the first forecast
    #    follows the last observation instead of sharing its x position,
    #    which would shift the whole forecast one step into the past.
    pred_len = quantile_predictions.shape[0]
    x_pred = list(range(n_hist, n_hist + pred_len))

    # 3) Band edges and the middle quantile column.
    lower_q = quantile_predictions[:, 0]
    upper_q = quantile_predictions[:, -1]
    median_q = quantile_predictions[:, quantile_predictions.shape[1] // 2]

    # 4) Lower bound: transparent line. It is the trace the band fills down
    #    to, and its value still appears in the hover label.
    fig.add_trace(go.Scatter(
        x=x_pred,
        y=lower_q,
        mode="lines",
        line=dict(color="rgba(0, 0, 0, 0)", width=0),
        name=f"{timeseries_name} – Lower Bound",
        hovertemplate="Lower: %{y:.2f}<extra></extra>"
    ))

    # 5) Upper bound: fill="tonexty" shades down to the previously added
    #    trace, i.e. the lower bound above, so the order of 4) and 5) matters.
    fig.add_trace(go.Scatter(
        x=x_pred,
        y=upper_q,
        mode="lines",
        line=dict(color="rgba(0, 0, 0, 0)", width=0),
        fill="tonexty",
        fillcolor="rgba(128, 128, 128, 0.3)",
        name=f"{timeseries_name} – Upper Bound",
        hovertemplate="Upper: %{y:.2f}<extra></extra>"
    ))

    # 6) Median forecast (orange), drawn last so it sits on top of the band.
    fig.add_trace(go.Scatter(
        x=x_pred,
        y=median_q,
        mode="lines",
        name=f"{timeseries_name} – Median Forecast",
        line=dict(color="orange", width=2),
        hovertemplate="Median: %{y:.2f}<extra></extra>"
    ))

    # 7) Layout: dark theme, title anchored at the upper left (x=0.10,
    #    y=0.90), a range slider under the x-axis, and a larger top margin
    #    so the title fits.
    fig.update_layout(
        template="plotly_dark",
        title=dict(
            text=f"Timeseries: {timeseries_name}",
            x=0.10,
            xanchor="left",
            y=0.90,
            yanchor="bottom",
            font=dict(size=18, family="Arial", color="white")
        ),
        xaxis=dict(
            rangeslider=dict(visible=True),
            fixedrange=False
        ),
        xaxis_title="Time",
        yaxis_title="Value",
        hovermode="x unified",
        margin=dict(
            t=120,
            b=40,
            l=60,
            r=40
        ),
        autosize=True,
    )
    return fig
def load_table(file_path):
    """Read a tabular file into a DataFrame, choosing the pandas reader by extension.

    The extension is the text after the last ``.`` in ``file_path``, compared
    case-insensitively.

    Raises:
        ValueError: if the extension is not csv, xls, xlsx or parquet.
    """
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    extension = file_path.rsplit(".", 1)[-1].lower()
    reader = readers.get(extension)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)
def extract_names_and_update(file, preset_filename):
    """Refresh the series checklist and forecast-length slider for a new data source.

    Wired to both the upload widget and the preset dropdown. An uploaded
    file takes precedence over the preset.

    Args:
        file: uploaded file object from ``gr.File`` (its ``.name`` is read as
            the path), or None.
        preset_filename: path chosen in the preset dropdown; None/empty or the
            "-- No preset selected --" placeholder means "no preset".

    Returns:
        ``(checkbox_update, names, slider_update)``: every series name as
        both choice and selection, the full name list for the state, and the
        default forecast length for that source. With no source, or when
        loading fails, the checklist is cleared and the slider is reset.
    """
    # Use the same default as get_default_forecast_length for "no file"
    # (which also matches the slider's initial value) instead of a separate
    # hard-coded 256.
    fallback = (
        gr.update(choices=[], value=[]),
        [],
        gr.update(value=get_default_forecast_length(None)),
    )
    try:
        if file is not None:
            source = file.name
        elif preset_filename and preset_filename != "-- No preset selected --":
            source = preset_filename
        else:
            return fallback

        df = load_table(source)
        default_length = get_default_forecast_length(source)

        # A non-numeric first column holds the series names (same rule as
        # display_filtered_forecast); otherwise rows are named by index.
        if (
            df.shape[1] > 0
            and df.iloc[:, 0].dtype == object
            and not df.iloc[:, 0].str.isnumeric().all()
        ):
            names = df.iloc[:, 0].tolist()
        else:
            names = [f"Series {i}" for i in range(len(df))]
        return (
            gr.update(choices=names, value=names),
            names,
            gr.update(value=default_length),
        )
    except Exception:
        # Deliberate best effort: an unreadable file just yields an empty
        # checklist here. Plotting re-reads the file through load_table and
        # reports the error in the error box.
        return fallback
def filter_names(search_term, all_names):
    """Restrict the checklist to names containing ``search_term`` (case-insensitive).

    Matching names become both the visible choices and the selection. An
    empty search term restores (and selects) every name; with no names at
    all the checklist is emptied.
    """
    if not all_names:
        visible = []
    elif not search_term:
        visible = all_names
    else:
        needle = search_term.lower()
        visible = [name for name in all_names if needle in str(name).lower()]
    return gr.update(choices=visible, value=visible)
def check_all(names_list):
    """Select every name in ``names_list`` in the checklist."""
    selection = names_list
    return gr.update(value=selection)
def uncheck_all(_):
    """Clear the checklist selection; the incoming names state is ignored."""
    nothing_selected = []
    return gr.update(value=nothing_selected)
def get_default_forecast_length(file_path):
    """Return the default forecast horizon for a data file.

    The bundled presets are recognised by base name (merged_ett2_loop.csv ->
    256, air_passangers.csv -> 48). Any other file, or ``None``, gets 64.
    """
    fallback = 64
    preset_lengths = {
        "merged_ett2_loop.csv": 256,
        "air_passangers.csv": 48,
    }
    if file_path is None:
        return fallback
    return preset_lengths.get(os.path.basename(file_path), fallback)
def display_filtered_forecast(file, preset_filename, selected_names, forecast_length):
    """Forecast the ticked series and return a Plotly figure plus a status message.

    Args:
        file: uploaded file object from ``gr.File`` (its ``.name`` is used as
            the path), or None.
        preset_filename: path chosen in the preset dropdown; None or the
            "-- No preset selected --" placeholder means "no preset".
        selected_names: names ticked in the checkbox group (may be empty or None).
        forecast_length: number of future steps to forecast.

    Returns:
        ``(figure, message)``: a Plotly figure (or None) and the text for the
        error box ("" on success). Every exception is caught and reported
        through ``message``.
    """
    try:
        # No upload and no real preset: nothing to plot.
        if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
            return None, "No file selected."

        file_path = file.name if file is not None else preset_filename
        df = load_table(file_path)

        # A non-numeric first column holds the series names (same rule as
        # extract_names_and_update); otherwise rows are auto-named by index.
        if (
            df.shape[1] > 0
            and df.iloc[:, 0].dtype == object
            and not df.iloc[:, 0].str.isnumeric().all()
        ):
            all_names = df.iloc[:, 0].tolist()
            data_only = df.iloc[:, 1:].astype(float)
        else:
            all_names = [f"Series {i}" for i in range(len(df))]
            data_only = df.astype(float)

        # Treat a missing selection (None) like an empty one.
        chosen = selected_names or []
        mask = [name in chosen for name in all_names]
        if not any(mask):
            return None, "No timeseries chosen to plot."

        # Historical values of the selected rows: shape (n_selected, seq_len).
        filtered_data = data_only.iloc[mask, :].values
        filtered_names = [name for name, keep in zip(all_names, mask) if keep]

        forecast = model_forecast(filtered_data, forecast_length=forecast_length, file_name=file_path)
        # Precomputed preset forecasts have one row per series in the file,
        # but the live model is run on the selected rows only. Apply the mask
        # only in the first case: a full-length mask on the model output
        # fails whenever a strict subset of series is selected.
        if forecast.shape[0] != len(filtered_names):
            forecast = forecast[mask]

        # .cpu() is a no-op for CPU tensors; it is a guard in case the output
        # stays on the device the model was loaded on ('cuda').
        series_figures = [
            plot_forecast_plotly(values.tolist(), forecast[idx].cpu().numpy(), name)
            for idx, (values, name) in enumerate(zip(filtered_data, filtered_names))
        ]
        if len(series_figures) == 1:
            return series_figures[0], ""

        # Several series: merge every per-series trace into one figure. Each
        # small figure adds its traces in the order lower, upper, median, so
        # the "tonexty" band fill still pairs with its own lower bound.
        master_fig = go.Figure()
        for small_fig in series_figures:
            for trace in small_fig.data:
                master_fig.add_trace(trace)
        master_fig.update_layout(
            template="plotly_dark",
            title=dict(
                text="Forecasts for Selected Timeseries",
                x=0.5,
                font=dict(size=16, family="Arial", color="white")
            ),
            xaxis=dict(
                rangeslider=dict(visible=True),
                fixedrange=False
            ),
            xaxis_title="Time",
            yaxis_title="Value",
            hovermode="x unified",
            autosize=True,
        )
        return master_fig, ""
    except Exception as e:
        # The exception text is shown as-is. load_table's own error already
        # names the supported formats, so no format hint is appended to
        # unrelated failures.
        return None, f"Error: {e}"
# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------
with gr.Blocks(fill_width=True, theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")

    with gr.Row():
        # Left column: data source, forecast length, series selection, actions.
        with gr.Column(scale=1):
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"],
            )
            preset_choices = ["-- No preset selected --", "data/merged_ett2_loop.csv", "data/air_passangers.csv"]
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=preset_choices,
                value="-- No preset selected --",
            )

            gr.Markdown("## Forecast Length Setting")
            forecast_length_slider = gr.Slider(
                minimum=1,
                maximum=512,
                value=64,
                step=1,
                label="Forecast Length (Steps)",
                info="Choose how many future steps to forecast.",
            )

            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )
            with gr.Row():
                check_all_btn = gr.Button("✅ Check All")
                uncheck_all_btn = gr.Button("❎ Uncheck All")

            plot_button = gr.Button("▶️ Plot Forecasts")
            errbox = gr.Textbox(label="Error Message", interactive=False)

            with gr.Row():
                gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
                gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)

        # Right column: interactive plot, usage instructions and citation.
        with gr.Column(scale=5):
            gr.Markdown("## Forecast Plot")
            plot_output = gr.Plot()

            gr.Markdown("## Instructions")
            gr.Markdown(
                """
                **How to format your data:**
                - Your file must be a table (CSV, XLS, XLSX, or Parquet).
                - **One row per timeseries.** Each row is treated as a separate series.
                - For CSV and Excel files the **first row is read as the header** (column labels), not as a series.
                - If you want to **name** each series, put the name as the first value in **every** row. Example (CSV):
                  ```
                  AAPL, 120.5, 121.0, 119.8, ...
                  AMZN, 3300.0, 3310.5, 3295.2, ...
                  ```
                  In that case the first column is not numeric, so it is used as the series name.
                - If you do **not** want named series, leave out the first column entirely so that all values are numeric. Example:
                  ```
                  120.5, 121.0, 119.8, ...
                  3300.0, 3310.5, 3295.2, ...
                  ```
                  Every row is then auto-named “Series 0, Series 1, …” in order.
                - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
                - The rest of the columns (after the optional name) must be numeric data points for that series.
                - You can filter by typing in the search box. Then check or uncheck individual names before plotting.
                - Use “Check All” / “Uncheck All” to quickly select or deselect every series.
                - Choose the horizon with the **Forecast Length** slider; loading a file or preset resets it to that data's default.
                - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.
                """
            )

            gr.Markdown("## Citation")
            # Raw string: in a normal string the BibTeX accent escapes
            # {\"o} and {\"u} lose their backslash, corrupting the citation.
            gr.Markdown(
                r"""
                If you use TiRex in your research, please cite our work:
                ```
                @article{auerTiRexZeroShotForecasting2025,
                  title = {{{TiRex}}: {{Zero-Shot Forecasting Across Long}} and {{Short Horizons}} with {{Enhanced In-Context Learning}}},
                  author = {Auer, Andreas and Podest, Patrick and Klotz, Daniel and B{\"o}ck, Sebastian and Klambauer, G{\"u}nter and Hochreiter, Sepp},
                  journal = {ArXiv},
                  volume = {2505.23719},
                  year = {2025}
                }
                ```
                """
            )

    # Every series name of the current data source (used by search / Check All).
    names_state = gr.State([])

    # A new upload or preset refreshes the checklist, the name state and the
    # default forecast length.
    for source_component in (file_input, preset_dropdown):
        source_component.change(
            fn=extract_names_and_update,
            inputs=[file_input, preset_dropdown],
            outputs=[filter_checkbox, names_state, forecast_length_slider],
        )

    # Typing in the search box narrows the checklist.
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=[filter_checkbox],
    )

    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)

    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider],
        outputs=[plot_output, errbox],
    )

if __name__ == "__main__":
    demo.launch()