Nikita commited on
Commit
7fb6e58
·
1 Parent(s): bf87f19

second round of edits from günther and andreas

Browse files
app.py CHANGED
@@ -6,7 +6,9 @@ from PIL import Image
6
  import numpy as np
7
  import gradio as gr
8
  import os
9
- from tirex import load_model, ForecastModel
 
 
10
 
11
  # ----------------------------
12
  # Helper functions (logic mostly unchanged)
@@ -15,21 +17,23 @@ from tirex import load_model, ForecastModel
15
  torch.manual_seed(42)
16
 
17
  def model_forecast(input_data, forecast_length=256, file_name=None):
18
- if os.path.basename(file_name) == "merged_ett2_loop.csv":
19
- _forecast_tensor = torch.load("data/merged_ett2_loop_forecast_256.pt")
 
 
 
20
  return _forecast_tensor[:,:forecast_length,:]
21
  elif os.path.basename(file_name) == "air_passangers.csv":
22
- _forecast_tensor = torch.load("data/air_passengers_forecast_256.pt")
23
  return _forecast_tensor[:,:forecast_length,:]
24
  else:
25
- model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
26
- forecast = model.forecast(context=input_data, prediction_length=forecast_length)
27
- return forecast[0]
 
28
 
29
 
30
 
31
-
32
-
33
  def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
34
  """
35
  - timeseries: 1D list/array of historical values.
@@ -66,7 +70,7 @@ def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
66
  y=lower_q,
67
  mode="lines",
68
  line=dict(color="rgba(0, 0, 0, 0)", width=0),
69
- name=f"{timeseries_name} – Lower Bound",
70
  hovertemplate="Lower: %{y:.2f}<extra></extra>"
71
  ))
72
 
@@ -78,7 +82,7 @@ def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
78
  line=dict(color="rgba(0, 0, 0, 0)", width=0),
79
  fill="tonexty",
80
  fillcolor="rgba(128, 128, 128, 0.3)",
81
- name=f"{timeseries_name} – Upper Bound",
82
  hovertemplate="Upper: %{y:.2f}<extra></extra>"
83
  ))
84
 
@@ -139,24 +143,6 @@ def load_table(file_path):
139
  raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
140
 
141
 
142
- # def extract_names_and_update(file, preset_filename):
143
- # try:
144
- # if file is not None:
145
- # df = load_table(file.name)
146
- # else:
147
- # if not preset_filename:
148
- # return gr.update(choices=[], value=[]), []
149
- # df = load_table(preset_filename)
150
-
151
- # if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
152
- # names = df.iloc[:, 0].tolist()
153
- # else:
154
- # names = [f"Series {i}" for i in range(len(df))]
155
- # return gr.update(choices=names, value=names), names
156
- # except Exception:
157
- # return gr.update(choices=[], value=[]), []
158
-
159
-
160
  def extract_names_and_update(file, preset_filename):
161
  try:
162
  # Determine which file to use and get default forecast length
@@ -206,7 +192,7 @@ def get_default_forecast_length(file_path):
206
  return 64
207
 
208
  filename = os.path.basename(file_path)
209
- if filename == "merged_ett2_loop.csv":
210
  return 256
211
  elif filename == "air_passangers.csv":
212
  return 48
@@ -216,17 +202,19 @@ def get_default_forecast_length(file_path):
216
 
217
  def display_filtered_forecast(file, preset_filename, selected_names, forecast_length):
218
  try:
219
- # If no file uploaded and no valid preset chosen, return early
220
  if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
221
  return None, "No file selected."
222
 
223
- # Load data
224
  if file is not None:
225
  df = load_table(file.name)
 
226
  else:
227
  df = load_table(preset_filename)
 
228
 
229
- # Determine names vs numeric data
230
  if (
231
  df.shape[1] > 0
232
  and df.iloc[:, 0].dtype == object
@@ -238,63 +226,128 @@ def display_filtered_forecast(file, preset_filename, selected_names, forecast_le
238
  all_names = [f"Series {i}" for i in range(len(df))]
239
  data_only = df.astype(float)
240
 
241
- # Build a boolean mask for selected series
242
  mask = [name in selected_names for name in all_names]
243
  if not any(mask):
244
  return None, "No timeseries chosen to plot."
245
 
246
- # Extract the filtered historical data (numpy array of shape (n_selected, seq_len))
247
- filtered_data = data_only.iloc[mask, :].values
248
  filtered_names = [all_names[i] for i, m in enumerate(mask) if m]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
- # Get the full forecast tensor (n_series, pred_len, n_q)
251
- file_path = file.name if file is not None else preset_filename
252
- out = model_forecast(filtered_data, forecast_length=forecast_length, file_name=file_path)
253
- inp = torch.tensor(filtered_data) # shape = (n_selected, seq_len)
254
-
255
- # If only one series is selected, we can just call plot_forecast_plotly directly:
256
- if inp.shape[0] == 1:
257
- ts = inp[0].numpy().tolist()
258
- qp = out[0].numpy()
259
- fig = plot_forecast_plotly(ts, qp, filtered_names[0])
260
- return fig, ""
261
-
262
- # If multiple series are selected, build a master figure by concatenating traces
263
- master_fig = go.Figure()
264
- for idx in range(inp.shape[0]):
265
  ts = inp[idx].numpy().tolist()
266
  qp = out[idx].numpy()
267
  series_name = filtered_names[idx]
268
 
269
- # Get a “per‐series” figure
270
- small_fig = plot_forecast_plotly(ts, qp, series_name)
271
- # Append each trace from small_fig into master_fig
272
- for trace in small_fig.data:
273
- master_fig.add_trace(trace)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
- # Finally, update the layout of master_fig (you can tweak the title if you want)
276
- master_fig.update_layout(
277
  template="plotly_dark",
 
278
  title=dict(
279
  text="Forecasts for Selected Timeseries",
280
  x=0.5,
281
- font=dict(size=16, family="Arial", color="white")
282
  ),
283
- xaxis=dict(
284
- rangeslider=dict(visible=True), # <-- put rangeslider here
285
- fixedrange=False
286
- ),
287
- xaxis_title="Time",
288
- yaxis_title="Value",
289
  hovermode="x unified",
290
- # height=plot_height,
291
- # width=plot_width # ← add these here
292
- autosize=True,
293
  )
294
- return master_fig, ""
 
295
 
296
  except Exception as e:
297
- return None, f"Error: {e}. Use CSV, XLS, XLSX, or PARQUET."
 
298
 
299
 
300
  # ----------------------------
@@ -313,7 +366,7 @@ with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
313
  label="Upload CSV / XLSX / PARQUET",
314
  file_types=[".csv", ".xls", ".xlsx", ".parquet"]
315
  )
316
- preset_choices = ["-- No preset selected --", "data/merged_ett2_loop.csv", "data/air_passangers.csv"]
317
 
318
  preset_dropdown = gr.Dropdown(
319
  label="Or choose a preset:",
@@ -347,7 +400,6 @@ with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
347
  gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
348
  gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
349
 
350
- # Right column: interactive plot + instructions
351
  with gr.Column(scale=5):
352
  gr.Markdown("## Forecast Plot")
353
  plot_output = gr.Plot()
@@ -418,10 +470,10 @@ with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
418
 
419
  # Plot button
420
  plot_button.click(
421
- fn=display_filtered_forecast,
422
- inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider], # <-- add slider here
423
- outputs=[plot_output, errbox]
424
- )
425
  demo.launch()
426
 
427
 
 
6
  import numpy as np
7
  import gradio as gr
8
  import os
9
+ from plotly.subplots import make_subplots
10
+
11
+ # from tirex import load_model, ForecastModel
12
 
13
  # ----------------------------
14
  # Helper functions (logic mostly unchanged)
 
17
  torch.manual_seed(42)
18
 
19
  def model_forecast(input_data, forecast_length=256, file_name=None):
20
+ if os.path.basename(file_name) == "loop.csv":
21
+ _forecast_tensor = torch.load("data/loop_forecast_512.pt")
22
+ return _forecast_tensor[:,:forecast_length,:]
23
+ elif os.path.basename(file_name) == "ett2.csv":
24
+ _forecast_tensor = torch.load("data/ett2_forecast_512.pt")
25
  return _forecast_tensor[:,:forecast_length,:]
26
  elif os.path.basename(file_name) == "air_passangers.csv":
27
+ _forecast_tensor = torch.load("data/air_passengers_forecast_512.pt")
28
  return _forecast_tensor[:,:forecast_length,:]
29
  else:
30
+ # model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
31
+ # forecast = model.forecast(context=input_data, prediction_length=forecast_length)
32
+ # return forecast[0]
33
+ pass
34
 
35
 
36
 
 
 
37
  def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
38
  """
39
  - timeseries: 1D list/array of historical values.
 
70
  y=lower_q,
71
  mode="lines",
72
  line=dict(color="rgba(0, 0, 0, 0)", width=0),
73
+ name=f"{timeseries_name} – 10% Quantile",
74
  hovertemplate="Lower: %{y:.2f}<extra></extra>"
75
  ))
76
 
 
82
  line=dict(color="rgba(0, 0, 0, 0)", width=0),
83
  fill="tonexty",
84
  fillcolor="rgba(128, 128, 128, 0.3)",
85
+ name=f"{timeseries_name} – 90% Quantile",
86
  hovertemplate="Upper: %{y:.2f}<extra></extra>"
87
  ))
88
 
 
143
  raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
144
 
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def extract_names_and_update(file, preset_filename):
147
  try:
148
  # Determine which file to use and get default forecast length
 
192
  return 64
193
 
194
  filename = os.path.basename(file_path)
195
+ if filename == "loop.csv" or filename == "ett2.csv":
196
  return 256
197
  elif filename == "air_passangers.csv":
198
  return 48
 
202
 
203
  def display_filtered_forecast(file, preset_filename, selected_names, forecast_length):
204
  try:
205
+ # 1) If no file or preset selected, show an error
206
  if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
207
  return None, "No file selected."
208
 
209
+ # 2) Load DataFrame and remember which filename to pass to model_forecast
210
  if file is not None:
211
  df = load_table(file.name)
212
+ file_name = file.name
213
  else:
214
  df = load_table(preset_filename)
215
+ file_name = preset_filename
216
 
217
+ # 3) Determine whether first column is names or numeric
218
  if (
219
  df.shape[1] > 0
220
  and df.iloc[:, 0].dtype == object
 
226
  all_names = [f"Series {i}" for i in range(len(df))]
227
  data_only = df.astype(float)
228
 
229
+ # 4) Build mask from selected_names
230
  mask = [name in selected_names for name in all_names]
231
  if not any(mask):
232
  return None, "No timeseries chosen to plot."
233
 
234
+ filtered_data = data_only.iloc[mask, :].values # shape = (n_selected, seq_len)
 
235
  filtered_names = [all_names[i] for i, m in enumerate(mask) if m]
236
+ n_selected = filtered_data.shape[0]
237
+
238
+ # 5) First call model_forecast on all series, then select only the masked rows
239
+ full_data = data_only.values # shape = (n_all, seq_len)
240
+ full_out = model_forecast(full_data, forecast_length=forecast_length, file_name=file_name)
241
+
242
+ # Now pick only the rows we actually filtered
243
+ out = full_out[mask, :, :] # shape = (n_selected, pred_len, n_q)
244
+ inp = torch.tensor(filtered_data)
245
+
246
+ # 6) Create one subplot per selected series, with vertical spacing
247
+ fig = make_subplots(
248
+ rows=n_selected,
249
+ cols=1,
250
+ shared_xaxes=False,
251
+ vertical_spacing=0.3, # more space between subplots
252
+ subplot_titles=filtered_names
253
+ )
254
 
255
+ for idx in range(n_selected):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  ts = inp[idx].numpy().tolist()
257
  qp = out[idx].numpy()
258
  series_name = filtered_names[idx]
259
 
260
+ # a) plot historical data (blue line)
261
+ x_hist = list(range(len(ts)))
262
+ fig.add_trace(
263
+ go.Scatter(
264
+ x=x_hist,
265
+ y=ts,
266
+ mode="lines",
267
+ name=f"{series_name} – Given Data",
268
+ line=dict(color="blue", width=2),
269
+ showlegend=False
270
+ ),
271
+ row=idx + 1, col=1
272
+ )
273
+
274
+ # b) compute forecast indices
275
+ pred_len = qp.shape[0]
276
+ x_pred = list(range(len(ts) - 1, len(ts) - 1 + pred_len))
277
+
278
+ lower_q = qp[:, 0]
279
+ upper_q = qp[:, -1]
280
+ n_q = qp.shape[1]
281
+ median_idx = n_q // 2
282
+ median_q = qp[:, median_idx]
283
+
284
+ # c) lower‐bound (invisible)
285
+ fig.add_trace(
286
+ go.Scatter(
287
+ x=x_pred,
288
+ y=lower_q,
289
+ mode="lines",
290
+ line=dict(color="rgba(0,0,0,0)", width=0),
291
+ name=f"{series_name} – 10% Quantile",
292
+ hovertemplate="10% Quantile: %{y:.2f}<extra></extra>",
293
+ showlegend=False
294
+ ),
295
+ row=idx + 1, col=1
296
+ )
297
+
298
+ # d) upper‐bound (shaded area)
299
+ fig.add_trace(
300
+ go.Scatter(
301
+ x=x_pred,
302
+ y=upper_q,
303
+ mode="lines",
304
+ line=dict(color="rgba(0,0,0,0)", width=0),
305
+ fill="tonexty",
306
+ fillcolor="rgba(128,128,128,0.3)",
307
+ name=f"{series_name} – 90% Quantile",
308
+ hovertemplate="90% Quantile: %{y:.2f}<extra></extra>",
309
+ showlegend=False
310
+ ),
311
+ row=idx + 1, col=1
312
+ )
313
+
314
+ # e) median forecast (orange line)
315
+ fig.add_trace(
316
+ go.Scatter(
317
+ x=x_pred,
318
+ y=median_q,
319
+ mode="lines",
320
+ name=f"{series_name} – Median Forecast",
321
+ line=dict(color="orange", width=2),
322
+ hovertemplate="Median: %{y:.2f}<extra></extra>",
323
+ showlegend=False
324
+ ),
325
+ row=idx + 1, col=1
326
+ )
327
+
328
+ # f) label axes for each subplot
329
+ fig.update_xaxes(title_text="Time", row=idx + 1, col=1)
330
+ fig.update_yaxes(title_text="Value", row=idx + 1, col=1)
331
 
332
+ # 7) Global layout tweaks
333
+ fig.update_layout(
334
  template="plotly_dark",
335
+ height=300 * n_selected, # 300px per subplot
336
  title=dict(
337
  text="Forecasts for Selected Timeseries",
338
  x=0.5,
339
+ font=dict(size=20, family="Arial", color="white")
340
  ),
 
 
 
 
 
 
341
  hovermode="x unified",
342
+ margin=dict(t=120, b=40, l=60, r=40),
343
+ showlegend=False
 
344
  )
345
+
346
+ return fig, ""
347
 
348
  except Exception as e:
349
+ return None, f"Error: {str(e)}"
350
+
351
 
352
 
353
  # ----------------------------
 
366
  label="Upload CSV / XLSX / PARQUET",
367
  file_types=[".csv", ".xls", ".xlsx", ".parquet"]
368
  )
369
+ preset_choices = ["-- No preset selected --", "data/loop.csv", "data/air_passangers.csv", 'data/ett2.csv']
370
 
371
  preset_dropdown = gr.Dropdown(
372
  label="Or choose a preset:",
 
400
  gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
401
  gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
402
 
 
403
  with gr.Column(scale=5):
404
  gr.Markdown("## Forecast Plot")
405
  plot_output = gr.Plot()
 
470
 
471
  # Plot button
472
  plot_button.click(
473
+ fn=display_filtered_forecast,
474
+ inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider],
475
+ outputs=[plot_output, errbox]
476
+ )
477
  demo.launch()
478
 
479
 
data/.DS_Store CHANGED
Binary files a/data/.DS_Store and b/data/.DS_Store differ
 
data/{air_passengers_forecast_256.pt → air_passengers_forecast_512.pt} RENAMED
Binary files a/data/air_passengers_forecast_256.pt and b/data/air_passengers_forecast_512.pt differ
 
data/ett2.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/ett2_forecast_512.pt ADDED
Binary file (38.1 kB). View file
 
data/loop.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/loop_forecast_512.pt ADDED
Binary file (38.1 kB). View file
 
data/merged_ett2_loop.csv DELETED
The diff for this file is too large to render. See raw diff
 
data/merged_ett2_loop_forecast_256.pt DELETED
Binary file (38.2 kB)