import io import pandas as pd import torch import plotly.graph_objects as go from PIL import Image import numpy as np import gradio as gr import os from plotly.subplots import make_subplots from tirex import load_model, ForecastModel # ---------------------------- # Helper functions (logic mostly unchanged) # ---------------------------- torch.manual_seed(42) model: ForecastModel = load_model("NX-AI/TiRex",device='cuda') def model_forecast(input_data, forecast_length=256, file_name=None): if os.path.basename(file_name) == "loop.csv" and forecast_length==256: _forecast_tensor = torch.load("data/loop_forecast_256.pt") return _forecast_tensor elif os.path.basename(file_name) == "ett2.csv" and forecast_length==256: _forecast_tensor = torch.load("data/ett2_forecast_256.pt") return _forecast_tensor elif os.path.basename(file_name) == "air_passengers.csv"and forecast_length==24: _forecast_tensor = torch.load("data/air_passengers_forecast_24.pt") return _forecast_tensor else: forecast = model.forecast(context=input_data, prediction_length=forecast_length) return forecast[0] def load_table(file_path): ext = file_path.split(".")[-1].lower() if ext == "csv": return pd.read_csv(file_path) elif ext in ("xls", "xlsx"): return pd.read_excel(file_path) elif ext == "parquet": return pd.read_parquet(file_path) else: raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.") def extract_names_and_update(file, preset_filename, transpose): try: # Determine which file to use and get default forecast length if file is not None: df = load_table(file.name) default_length = get_default_forecast_length(file.name) else: if not preset_filename or preset_filename == "-- No preset selected --": return gr.update(choices=[], value=[]), [], gr.update(value=256) df = load_table(preset_filename) default_length = get_default_forecast_length(preset_filename) # if user wants to transpose, do it here if transpose: df = df.T if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all(): names = df.iloc[:, 0].tolist() else: names = [f"Series {i}" for i in range(len(df))] return ( gr.update(choices=names, value=names), names, gr.update(value=default_length) ) except Exception: return gr.update(choices=[], value=[]), [], gr.update(value=256) def filter_names(search_term, all_names): if not all_names: return gr.update(choices=[], value=[]) if not search_term: return gr.update(choices=all_names, value=all_names) lower = search_term.lower() filtered = [n for n in all_names if lower in str(n).lower()] return gr.update(choices=filtered, value=filtered) def check_all(names_list): return gr.update(value=names_list) def uncheck_all(_): return gr.update(value=[]) def get_default_forecast_length(file_path): """Get default forecast length based on filename""" if file_path is None: return 64 filename = os.path.basename(file_path) if filename == "loop.csv" or filename == "ett2.csv": return 256 elif filename == "air_passengers.csv": return 24 else: return 64 def display_filtered_forecast(file, preset_filename, selected_names, forecast_length, transpose): try: # 1) If no file or preset selected, show an error if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"): return None, "No file selected." # 2) Load DataFrame and remember which filename to pass to model_forecast if file is not None: df = load_table(file.name) file_name = file.name else: df = load_table(preset_filename) file_name = preset_filename if transpose: df = df.T # 3) Determine whether first column is names or numeric if ( df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all() ): if df.shape[1]>2048 and file is not None: df = pd.concat([ df.iloc[:, [0]], df.iloc[:, -2048:] ], axis=1) gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5) all_names = df.iloc[:, 0].tolist() data_only_full = df.iloc[:, 1:].astype(float) else: if df.shape[1]>2048 and file is not None: df = df.iloc[:, -2048:] gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5) all_names = [f"Series {i}" for i in range(len(df))] data_only_full = df.astype(float) data_only = data_only_full # 4) Build mask from selected_names mask = [name in selected_names for name in all_names] if not any(mask): return None, "No timeseries chosen to plot." filtered_data = data_only.iloc[mask, :].values # shape = (n_selected, seq_len) filtered_data_only_full = data_only_full.iloc[mask, :].values # ** Added to show prediction accuracy filtered_names = [all_names[i] for i, m in enumerate(mask) if m] n_selected = filtered_data.shape[0] if n_selected>30: raise gr.Error("Maximum of 30 timeseries (rows) is possible to choose", duration=5) # 5) First call model_forecast on all series, then select only the masked rows full_data = data_only.values # shape = (n_all, seq_len) if file is not None: full_out = model_forecast(full_data, forecast_length=forecast_length, file_name=file_name) else: if preset_filename=='data/ett2.csv' or preset_filename=="data/loop.csv": full_out = model_forecast(full_data[:, :2048], forecast_length=forecast_length, file_name=file_name) elif preset_filename=="data/air_passengers.csv": full_out = model_forecast(full_data[:, :132], forecast_length=forecast_length, file_name=file_name) # Now pick only the rows we actually filtered out = full_out[mask, :, :] # shape = (n_selected, pred_len, n_q) inp = torch.tensor(filtered_data) inp_full = torch.tensor(filtered_data_only_full) # ** Added to show prediction accuracy # 6) Create one subplot per selected series, with vertical spacing subplot_height_px = 350 # px per subplot n_selected = len(filtered_names) fig = make_subplots( rows=n_selected, cols=1, shared_xaxes=False, subplot_titles=filtered_names, row_heights=[1] * n_selected, # all rows equal height ) fig.update_layout( height=subplot_height_px * n_selected, template="plotly_dark", margin=dict(t=50, b=50) ) for idx in range(n_selected): ts = inp[idx].numpy().tolist() ts_full = inp_full[idx].numpy().tolist() qp = out[idx].numpy() series_name = filtered_names[idx] pred_len = qp.shape[0] if file is not None: x_pred = list(range(len(ts), len(ts) + pred_len)) else: if preset_filename=='data/ett2.csv' or preset_filename=="data/loop.csv": x_pred = list(range(2048, 2048 + pred_len)) elif preset_filename=="data/air_passengers.csv": x_pred = list(range(132, 132 + pred_len)) # a) plot historical data (blue line) x_hist = list(range(len(ts_full))) if x_pred[-1]