import functools
import io
import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from PIL import Image
from plotly.subplots import make_subplots
from tirex import ForecastModel, load_model

# ----------------------------
# Configuration
# ----------------------------

# Bundled presets whose forecasts were computed ahead of time and are served
# from disk instead of running the model. Keyed by file basename.
_PRECOMPUTED_FORECASTS = {
    "loop.csv": "data/loop_forecast_512.pt",
    "ett2.csv": "data/ett2_forecast_512.pt",
    "air_passangers.csv": "data/air_passengers_forecast_512.pt",
}

# Default value for the forecast-length slider per preset basename; any other
# file gets _FALLBACK_FORECAST_LENGTH.
_DEFAULT_FORECAST_LENGTHS = {
    "loop.csv": 256,
    "ett2.csv": 256,
    "air_passangers.csv": 48,
}
_FALLBACK_FORECAST_LENGTH = 64

# Dropdown entry meaning "no preset chosen" (must match the UI's choices).
_NO_PRESET = "-- No preset selected --"

MAX_CONTEXT_LENGTH = 2048  # longest history (value columns) kept per series
MAX_SELECTED_SERIES = 30  # most series that may be plotted at once
SUBPLOT_HEIGHT_PX = 300  # figure height contributed by each subplot

# ----------------------------
# Model
# ----------------------------

torch.manual_seed(42)
model: ForecastModel = load_model("NX-AI/TiRex", device="cuda")


# ----------------------------
# Helper functions
# ----------------------------


def _to_numpy(values):
    """Return *values* as a NumPy array; torch tensors are detached and moved to CPU first."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


@functools.lru_cache(maxsize=None)
def _load_precomputed_forecast(path):
    """Load the forecast tensor stored at *path*, reading each file only once.

    The key space is the few paths in _PRECOMPUTED_FORECASTS, so the
    unbounded cache cannot grow without limit.
    """
    # map_location="cpu": the result is converted to NumPy for plotting, which
    # needs CPU memory regardless of the device the tensor was saved from.
    return torch.load(path, map_location="cpu")


def model_forecast(input_data, forecast_length=256, file_name=None, rows=None):
    """Return quantile forecasts for the rows of *input_data*.

    For the bundled preset files (matched by basename) a precomputed forecast
    is loaded from disk and truncated to *forecast_length*; for anything else
    TiRex is run on *input_data*.

    Args:
        input_data: 2D array-like, one series per row.
        forecast_length: number of future steps to return. The precomputed
            files are named ``*_forecast_512.pt``; asking for more steps than a
            file stores returns only the stored steps.
        file_name: path of the source file, or None for ad-hoc data (then the
            model is always used).
        rows: optional boolean mask over the rows of *input_data*; when given,
            only the selected rows are forecast / returned.

    Returns:
        Array-like indexed as (series, step, quantile). For the model path this
        is ``model.forecast(...)[0]`` — NOTE(review): assumed to share the
        precomputed tensors' layout; confirm against the TiRex API.
    """
    cached_path = None
    if file_name is not None:
        cached_path = _PRECOMPUTED_FORECASTS.get(os.path.basename(file_name))

    if cached_path is not None:
        forecast = _load_precomputed_forecast(cached_path)[:, :forecast_length, :]
        if rows is not None:
            forecast = forecast[torch.as_tensor(np.asarray(rows, dtype=bool))]
        return forecast

    if rows is not None:
        # Forecast only the requested series rather than the whole file.
        input_data = np.asarray(input_data)[np.asarray(rows, dtype=bool)]
    forecast = model.forecast(context=input_data, prediction_length=forecast_length)
    return forecast[0]


def _forecast_x(history_len, pred_len):
    """X positions for *pred_len* forecast steps after a history of *history_len* points.

    The first forecast point is placed at index ``history_len - 1``, i.e. on the
    same x as the last observation.
    NOTE(review): the first predicted value belongs to step ``history_len``;
    this one-step overlap may be intentional (visually joins the lines) — confirm.
    """
    start = history_len - 1
    return list(range(start, start + pred_len))


def _quantile_bands(quantile_predictions):
    """Split a (pred_len, n_q) quantile array into (lower, median, upper) columns.

    Columns are expected sorted ascending: the first is the lower band, the
    last the upper band, and column ``n_q // 2`` is used as the median.
    """
    median_idx = quantile_predictions.shape[1] // 2
    return (
        quantile_predictions[:, 0],
        quantile_predictions[:, median_idx],
        quantile_predictions[:, -1],
    )


def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
    """Plot one series' history and quantile forecast as a standalone figure.

    Args:
        timeseries: 1D list/array of historical values.
        quantile_predictions: 2D array or tensor of shape (pred_len, n_q), with
            quantiles sorted left→right.
        timeseries_name: label used in the title and the trace names.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    quantile_predictions = _to_numpy(quantile_predictions)
    lower_q, median_q, upper_q = _quantile_bands(quantile_predictions)
    x_hist = list(range(len(timeseries)))
    x_pred = _forecast_x(len(timeseries), quantile_predictions.shape[0])

    fig = go.Figure()
    # History: blue line, no markers.
    fig.add_trace(go.Scatter(
        x=x_hist,
        y=timeseries,
        mode="lines",
        name=f"{timeseries_name} – Given Data",
        line=dict(color="blue", width=2),
    ))
    # Lower bound: invisible line that still shows on hover. It must be added
    # right before the upper bound, which fills "tonexty" down to it.
    fig.add_trace(go.Scatter(
        x=x_pred,
        y=lower_q,
        mode="lines",
        line=dict(color="rgba(0, 0, 0, 0)", width=0),
        name=f"{timeseries_name} – 10% Quantile",
        hovertemplate="Lower: %{y:.2f}",
    ))
    # Upper bound: invisible line, shaded down to the lower bound.
    fig.add_trace(go.Scatter(
        x=x_pred,
        y=upper_q,
        mode="lines",
        line=dict(color="rgba(0, 0, 0, 0)", width=0),
        fill="tonexty",
        fillcolor="rgba(128, 128, 128, 0.3)",
        name=f"{timeseries_name} – 90% Quantile",
        hovertemplate="Upper: %{y:.2f}",
    ))
    # Median forecast (orange) drawn on top of the band.
    fig.add_trace(go.Scatter(
        x=x_pred,
        y=median_q,
        mode="lines",
        name=f"{timeseries_name} – Median Forecast",
        line=dict(color="orange", width=2),
        hovertemplate="Median: %{y:.2f}",
    ))

    # Layout: left-aligned title near the top, range slider under the x axis.
    fig.update_layout(
        template="plotly_dark",
        title=dict(
            text=f"Timeseries: {timeseries_name}",
            x=0.10,
            xanchor="left",
            y=0.90,
            yanchor="bottom",
            font=dict(size=18, family="Arial", color="white"),
        ),
        xaxis=dict(rangeslider=dict(visible=True), fixedrange=False),
        xaxis_title="Time",
        yaxis_title="Value",
        hovermode="x unified",
        # Generous top margin so the title has room.
        margin=dict(t=120, b=40, l=60, r=40),
        autosize=True,
    )
    return fig


def load_table(file_path):
    """Read *file_path* into a DataFrame, choosing the reader by file extension.

    Supported extensions (case-insensitive): csv, xls, xlsx, parquet.

    Raises:
        ValueError: for any other extension.
    """
    ext = file_path.split(".")[-1].lower()
    # NOTE(review): pd.read_csv / pd.read_excel treat the first row as a header
    # by default. The UI instructions describe header-less rows, in which case
    # the first series would be consumed as column names — confirm the format.
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    reader = readers.get(ext)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)


def _split_name_column(df):
    """Split *df* into (series names, value columns).

    The first column is treated as names when it has object dtype and is not
    made up entirely of numeric strings; otherwise rows are auto-named
    "Series 0", "Series 1", ... and every column is a value column.
    """
    if (
        df.shape[1] > 0
        and df.iloc[:, 0].dtype == object
        and not df.iloc[:, 0].str.isnumeric().all()
    ):
        return df.iloc[:, 0].tolist(), df.iloc[:, 1:]
    return [f"Series {i}" for i in range(len(df))], df


def extract_names_and_update(file, preset_filename):
    """Populate the series checkboxes and forecast length for the chosen data source.

    An uploaded *file* takes precedence over *preset_filename*.

    Returns:
        (checkbox update, list of all names, slider update). On no selection or
        any read error the checkboxes are cleared and the slider is set to 256.
    """
    empty = (gr.update(choices=[], value=[]), [], gr.update(value=256))
    try:
        if file is not None:
            path = file.name
        elif not preset_filename or preset_filename == _NO_PRESET:
            return empty
        else:
            path = preset_filename
        df = load_table(path)
        default_length = get_default_forecast_length(path)
        names, _ = _split_name_column(df)
        return (
            gr.update(choices=names, value=names),
            names,
            gr.update(value=default_length),
        )
    except Exception as e:
        # Best effort: reset the controls, but tell the user why the list is empty.
        gr.Warning(f"Could not read file: {e}")
        return empty


def filter_names(search_term, all_names):
    """Restrict the checkbox choices to names containing *search_term* (case-insensitive).

    All visible choices are also checked; an empty search term shows every name.
    """
    if not all_names:
        return gr.update(choices=[], value=[])
    if not search_term:
        return gr.update(choices=all_names, value=all_names)
    needle = search_term.lower()
    matches = [name for name in all_names if needle in str(name).lower()]
    return gr.update(choices=matches, value=matches)


def check_all(names_list):
    """Check every name in *names_list*."""
    return gr.update(value=names_list)


def uncheck_all(_):
    """Clear all checked names (the input is ignored)."""
    return gr.update(value=[])


def get_default_forecast_length(file_path):
    """Default forecast-length slider value for *file_path*, looked up by basename."""
    if file_path is None:
        return _FALLBACK_FORECAST_LENGTH
    return _DEFAULT_FORECAST_LENGTHS.get(
        os.path.basename(file_path), _FALLBACK_FORECAST_LENGTH
    )


def _add_series_traces(fig, row, series_name, history, quantile_predictions):
    """Add history, 10–90% band and median traces for one series to subplot *row*."""
    lower_q, median_q, upper_q = _quantile_bands(quantile_predictions)
    x_hist = list(range(len(history)))
    x_pred = _forecast_x(len(history), quantile_predictions.shape[0])
    traces = [
        # History (blue line).
        go.Scatter(
            x=x_hist,
            y=history,
            mode="lines",
            name=f"{series_name} – Given Data",
            line=dict(color="blue", width=2),
            showlegend=False,
        ),
        # Lower bound (invisible); must precede the upper bound's "tonexty" fill.
        go.Scatter(
            x=x_pred,
            y=lower_q,
            mode="lines",
            line=dict(color="rgba(0,0,0,0)", width=0),
            name=f"{series_name} – 10% Quantile",
            hovertemplate="10% Quantile: %{y:.2f}",
            showlegend=False,
        ),
        # Upper bound, shaded down to the lower bound.
        go.Scatter(
            x=x_pred,
            y=upper_q,
            mode="lines",
            line=dict(color="rgba(0,0,0,0)", width=0),
            fill="tonexty",
            fillcolor="rgba(128,128,128,0.3)",
            name=f"{series_name} – 90% Quantile",
            hovertemplate="90% Quantile: %{y:.2f}",
            showlegend=False,
        ),
        # Median forecast (orange line).
        go.Scatter(
            x=x_pred,
            y=median_q,
            mode="lines",
            name=f"{series_name} – Median Forecast",
            line=dict(color="orange", width=2),
            hovertemplate="Median: %{y:.2f}",
            showlegend=False,
        ),
    ]
    for trace in traces:
        fig.add_trace(trace, row=row, col=1)


def display_filtered_forecast(file, preset_filename, selected_names, forecast_length):
    """Forecast the selected series and return one stacked subplot per series.

    Args:
        file: uploaded file object (its ``.name`` is the path) or None.
        preset_filename: preset path from the dropdown, used when no file is uploaded.
        selected_names: names ticked in the checkbox group.
        forecast_length: number of future steps to forecast.

    Returns:
        (figure, message): the Plotly figure and "" on success, or None and a
        user-facing message when nothing can be plotted or an error occurs.

    Raises:
        gr.Error: when more than MAX_SELECTED_SERIES series are selected.
    """
    try:
        # 1) Resolve the source file; its path is also passed to model_forecast
        #    so presets can use their precomputed forecasts.
        if file is not None:
            file_name = file.name
        elif preset_filename is None or preset_filename == _NO_PRESET:
            return None, "No file selected."
        else:
            file_name = preset_filename
        df = load_table(file_name)

        # 2) Separate the optional name column from the values *before*
        #    truncating, so truncation can never cut the name column away.
        all_names, value_columns = _split_name_column(df)
        data_only = value_columns.astype(float)
        if data_only.shape[1] > MAX_CONTEXT_LENGTH:
            data_only = data_only.iloc[:, -MAX_CONTEXT_LENGTH:]
            gr.Info(
                f"Maximum of {MAX_CONTEXT_LENGTH} steps per timeseries (row) is allowed, "
                f"hence last {MAX_CONTEXT_LENGTH} kept. ℹ️",
                duration=5,
            )

        # 3) Boolean mask over all rows for the selected names.
        wanted = set(selected_names or [])
        mask = np.array([name in wanted for name in all_names], dtype=bool)
        if not mask.any():
            return None, "No timeseries chosen to plot."
        n_selected = int(mask.sum())
        if n_selected > MAX_SELECTED_SERIES:
            raise gr.Error(
                f"Maximum of {MAX_SELECTED_SERIES} timeseries (rows) is possible to choose",
                duration=5,
            )
        filtered_names = [str(name) for name, keep in zip(all_names, mask) if keep]
        filtered_data = data_only.values[mask]  # (n_selected, seq_len)

        # 4) Forecast only the selected rows; presets index their precomputed
        #    forecast with the same mask. Result: (n_selected, pred_len, n_q).
        out = _to_numpy(
            model_forecast(
                data_only.values,
                forecast_length=int(forecast_length),
                file_name=file_name,
                rows=mask,
            )
        )

        # 5) One subplot per selected series, stacked vertically.
        fig = make_subplots(
            rows=n_selected,
            cols=1,
            shared_xaxes=False,
            subplot_titles=filtered_names,
            row_heights=[1] * n_selected,  # all rows equal height
        )
        for row, (series_name, history, quantiles) in enumerate(
            zip(filtered_names, filtered_data, out), start=1
        ):
            _add_series_traces(fig, row, series_name, history.tolist(), quantiles)
            fig.update_xaxes(title_text="Time", row=row, col=1)
            fig.update_yaxes(title_text="Value", row=row, col=1)

        # 6) Global layout, set once.
        fig.update_layout(
            template="plotly_dark",
            height=SUBPLOT_HEIGHT_PX * n_selected,
            title=dict(
                text="Forecasts for Selected Timeseries",
                x=0.5,
                font=dict(size=20, family="Arial", color="white"),
            ),
            hovermode="x unified",
            margin=dict(t=120, b=40, l=60, r=40),
            showlegend=False,
        )
        return fig, ""
    except gr.Error:
        # Re-raise unchanged so Gradio shows the original message and duration.
        raise
    except Exception as e:
        return None, f"Error: {str(e)}"
# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------
with gr.Blocks(fill_width=True, theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")

    with gr.Row():
        # Left column: controls
        with gr.Column(scale=1):
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"],
            )
            preset_choices = [
                "-- No preset selected --",
                "data/loop.csv",
                "data/air_passangers.csv",
                "data/ett2.csv",
            ]
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=preset_choices,
                value="-- No preset selected --",
            )

            gr.Markdown("## Forecast Length Setting")
            forecast_length_slider = gr.Slider(
                minimum=1,
                maximum=512,
                value=64,
                step=1,
                label="Forecast Length (Steps)",
                info="Choose how many future steps to forecast.",
            )

            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )

            with gr.Row():
                check_all_btn = gr.Button("✅ Check All")
                uncheck_all_btn = gr.Button("❎ Uncheck All")

            plot_button = gr.Button("▶️ Plot Forecasts")
            errbox = gr.Textbox(label="Error Message", interactive=False)

            with gr.Row():
                gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
                gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)

        # Right column: plot, instructions and citation
        with gr.Column(scale=5):
            gr.Markdown("## Forecast Plot")
            plot_output = gr.Plot()

            # Instruction text below plot
            gr.Markdown("## Instructions")
            gr.Markdown(
                """
**How to format your data:**
- Your file must be a table (CSV, XLS, XLSX, or Parquet).
- **One row per timeseries.** Each row is treated as a separate series.
- If you want to **name** each series, put the name as the first value in **every** row:
  - Example (CSV):
    `AAPL, 120.5, 121.0, 119.8, ...`
    `AMZN, 3300.0, 3310.5, 3295.2, ...`
  - In that case, the first column is not numeric, so it will be used as the series name.
- If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
  - Example:
    `120.5, 121.0, 119.8, ...`
    `3300.0, 3310.5, 3295.2, ...`
  - Then every row will be auto-named “Series 0, Series 1, …” in order.
- **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
- The rest of the columns (after the optional name) must be numeric data points for that series.
- Limits: at most 30 series can be plotted at once, and only the last 2048 values of each row are used.
- You can filter by typing in the search box. Then check or uncheck individual names before plotting.
- Use “Check All” / “Uncheck All” to quickly select or deselect every series shown by the current filter.
- Finally, click **Plot Forecasts** to view the quantile forecast for each selected series, for the number of steps set with the **Forecast Length** slider.
"""
            )

            gr.Markdown("## Citation")
            # Raw string so the BibTeX accent escapes (\"o, \"u) reach the page intact.
            gr.Markdown(
                r"""
If you use TiRex in your research, please cite our work:
```
@article{auerTiRexZeroShotForecasting2025,
  title = {{{TiRex}}: {{Zero-Shot Forecasting Across Long}} and {{Short Horizons}} with {{Enhanced In-Context Learning}}},
  author = {Auer, Andreas and Podest, Patrick and Klotz, Daniel and B{\"o}ck, Sebastian and Klambauer, G{\"u}nter and Hochreiter, Sepp},
  journal = {ArXiv},
  volume = {2505.23719},
  year = {2025}
}
```
"""
            )

    # Full list of series names for the current file; the checkbox choices may
    # be a filtered subset of it.
    names_state = gr.State([])

    # Reload names (and the default forecast length) whenever the source changes.
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state, forecast_length_slider],
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state, forecast_length_slider],
    )

    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=[filter_checkbox],
    )

    # "Check All" re-applies the current search filter, which checks every
    # visible choice. Checking the full name list instead would put values
    # hidden by the filter outside the CheckboxGroup's current choices.
    check_all_btn.click(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=filter_checkbox,
    )
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)

    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider],
        outputs=[plot_output, errbox],
    )

demo.launch()

# Dev notes (previously an unused string literal).
# NOTE(review): contains internal user/host names — consider removing before publishing.
#   gradio app.py
#   ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
#     ssh -L 7860:localhost:7860 compute-permanent-node-83