Nikita commited on
Commit
fc8f543
·
1 Parent(s): de2de2d

trying real app.py

Browse files
Files changed (3) hide show
  1. app.py +493 -32
  2. orig_app.py +0 -500
  3. test_app.py +39 -0
app.py CHANGED
@@ -1,39 +1,500 @@
 
 
 
 
 
 
1
  import gradio as gr
2
- import time # Import time to make logs more distinct
 
3
 
4
- def greet(name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
- This function takes a name as input and returns a personalized greeting string.
7
- It now includes print statements for logging with flush=True to ensure
8
- logs appear immediately in container environments like Hugging Face Spaces.
 
9
  """
10
- # Log the function entry
11
- # The flush=True argument is crucial for logs to appear in real-time in Docker.
12
- print(f"[{time.ctime()}] - Function 'greet' was called.", flush=True)
13
-
14
- if name:
15
- # Log the received input
16
- print(f"[{time.ctime()}] - Received input name: '{name}'", flush=True)
17
- return f"Hello, {name}! Welcome to your first Gradio app."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  else:
19
- # Log that the input was empty
20
- print(f"[{time.ctime()}] - No input name received.", flush=True)
21
- return "Hello! Please enter your name."
22
-
23
- # Create the Gradio interface
24
- app = gr.Interface(
25
- fn=greet,
26
- inputs=gr.Textbox(
27
- lines=1,
28
- placeholder="Please enter your name here...",
29
- label="Your Name"
30
- ),
31
- outputs=gr.Text(label="Greeting"),
32
- title="Simple Greeting App with Logging",
33
- description="Enter your name to receive a greeting. Check the Hugging Face logs to see the output from the print() statements."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  )
 
 
 
 
 
 
 
 
35
 
36
- # Launch the application
37
- if __name__ == "__main__":
38
- print(f"[{time.ctime()}] - Starting Gradio server...", flush=True)
39
- app.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import io
2
+ import pandas as pd
3
+ import torch
4
+ import plotly.graph_objects as go
5
+ from PIL import Image
6
+ import numpy as np
7
  import gradio as gr
8
+ import os
9
+ from plotly.subplots import make_subplots
10
 
11
+ from tirex import load_model, ForecastModel
12
+
13
+ # ----------------------------
14
+ # Helper functions (logic mostly unchanged)
15
+ # ----------------------------
16
+
17
+ torch.manual_seed(42)
18
+ model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
19
+
20
+ def model_forecast(input_data, forecast_length=256, file_name=None):
21
+ if os.path.basename(file_name) == "loop.csv":
22
+ _forecast_tensor = torch.load("data/loop_forecast_512.pt")
23
+ return _forecast_tensor[:,:forecast_length,:]
24
+ elif os.path.basename(file_name) == "ett2.csv":
25
+ _forecast_tensor = torch.load("data/ett2_forecast_512.pt")
26
+ return _forecast_tensor[:,:forecast_length,:]
27
+ elif os.path.basename(file_name) == "air_passangers.csv":
28
+ _forecast_tensor = torch.load("data/air_passengers_forecast_512.pt")
29
+ return _forecast_tensor[:,:forecast_length,:]
30
+ else:
31
+ forecast = model.forecast(context=input_data, prediction_length=forecast_length)
32
+ return forecast[0]
33
+
34
+
35
+
36
+ def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
37
  """
38
+ - timeseries: 1D list/array of historical values.
39
+ - quantile_predictions: 2D array of shape (pred_len, n_q),
40
+ with quantiles sorted left→right.
41
+ - timeseries_name: string label.
42
  """
43
+ fig = go.Figure()
44
+
45
+ # 1) Plot historical data (blue line, no markers)
46
+ x_hist = list(range(len(timeseries)))
47
+ fig.add_trace(go.Scatter(
48
+ x=x_hist,
49
+ y=timeseries,
50
+ mode="lines", # no markers
51
+ name=f"{timeseries_name} – Given Data",
52
+ line=dict(color="blue", width=2),
53
+ ))
54
+
55
+ # 2) X-axis indices for forecasts
56
+ pred_len = quantile_predictions.shape[0]
57
+ x_pred = list(range(len(timeseries) - 1, len(timeseries) - 1 + pred_len))
58
+
59
+ # 3) Extract lower, upper, and median quantiles
60
+ lower_q = quantile_predictions[:, 0]
61
+ upper_q = quantile_predictions[:, -1]
62
+ n_q = quantile_predictions.shape[1]
63
+ median_idx = n_q // 2
64
+ median_q = quantile_predictions[:, median_idx]
65
+
66
+ # 4) Lower‐bound trace (invisible line, still shows on hover)
67
+ fig.add_trace(go.Scatter(
68
+ x=x_pred,
69
+ y=lower_q,
70
+ mode="lines",
71
+ line=dict(color="rgba(0, 0, 0, 0)", width=0),
72
+ name=f"{timeseries_name} – 10% Quantile",
73
+ hovertemplate="Lower: %{y:.2f}<extra></extra>"
74
+ ))
75
+
76
+ # 5) Upper‐bound trace (shaded down to lower_q)
77
+ fig.add_trace(go.Scatter(
78
+ x=x_pred,
79
+ y=upper_q,
80
+ mode="lines",
81
+ line=dict(color="rgba(0, 0, 0, 0)", width=0),
82
+ fill="tonexty",
83
+ fillcolor="rgba(128, 128, 128, 0.3)",
84
+ name=f"{timeseries_name} – 90% Quantile",
85
+ hovertemplate="Upper: %{y:.2f}<extra></extra>"
86
+ ))
87
+
88
+ # 6) Median trace (orange) on top
89
+ fig.add_trace(go.Scatter(
90
+ x=x_pred,
91
+ y=median_q,
92
+ mode="lines",
93
+ name=f"{timeseries_name} – Median Forecast",
94
+ line=dict(color="orange", width=2),
95
+ hovertemplate="Median: %{y:.2f}<extra></extra>"
96
+ ))
97
+
98
+ # 7) Layout: title on left (y=0.95), legend on right (y=0.95)
99
+ fig.update_layout(
100
+ template="plotly_dark",
101
+ title=dict(
102
+ text=f"Timeseries: {timeseries_name}",
103
+ x=0.10, # left‐align
104
+ xanchor="left",
105
+ y=0.90, # near top
106
+ yanchor="bottom",
107
+ font=dict(size=18, family="Arial", color="white")
108
+ ),
109
+ xaxis=dict(
110
+ rangeslider=dict(visible=True), # <-- put rangeslider here
111
+ fixedrange=False
112
+ ),
113
+ xaxis_title="Time",
114
+ yaxis_title="Value",
115
+ hovermode="x unified",
116
+ margin=dict(
117
+ t=120, # increase top margin to fit title+legend comfortably
118
+ b=40,
119
+ l=60,
120
+ r=40
121
+ ),
122
+ # height=plot_height,
123
+ # width=plot_width,
124
+ autosize=True,
125
+ )
126
+
127
+ return fig
128
+
129
+
130
+
131
+
132
+
133
+ def load_table(file_path):
134
+ ext = file_path.split(".")[-1].lower()
135
+ if ext == "csv":
136
+ return pd.read_csv(file_path)
137
+ elif ext in ("xls", "xlsx"):
138
+ return pd.read_excel(file_path)
139
+ elif ext == "parquet":
140
+ return pd.read_parquet(file_path)
141
  else:
142
+ raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
143
+
144
+
145
+ def extract_names_and_update(file, preset_filename):
146
+ try:
147
+ # Determine which file to use and get default forecast length
148
+ if file is not None:
149
+ df = load_table(file.name)
150
+ default_length = get_default_forecast_length(file.name)
151
+ else:
152
+ if not preset_filename or preset_filename == "-- No preset selected --":
153
+ return gr.update(choices=[], value=[]), [], gr.update(value=256)
154
+ df = load_table(preset_filename)
155
+ default_length = get_default_forecast_length(preset_filename)
156
+
157
+ if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
158
+ names = df.iloc[:, 0].tolist()
159
+ else:
160
+ names = [f"Series {i}" for i in range(len(df))]
161
+
162
+ return (
163
+ gr.update(choices=names, value=names),
164
+ names,
165
+ gr.update(value=default_length)
166
+ )
167
+ except Exception:
168
+ return gr.update(choices=[], value=[]), [], gr.update(value=256)
169
+
170
+
171
+ def filter_names(search_term, all_names):
172
+ if not all_names:
173
+ return gr.update(choices=[], value=[])
174
+ if not search_term:
175
+ return gr.update(choices=all_names, value=all_names)
176
+ lower = search_term.lower()
177
+ filtered = [n for n in all_names if lower in str(n).lower()]
178
+ return gr.update(choices=filtered, value=filtered)
179
+
180
+
181
+ def check_all(names_list):
182
+ return gr.update(value=names_list)
183
+
184
+
185
+ def uncheck_all(_):
186
+ return gr.update(value=[])
187
+
188
+ def get_default_forecast_length(file_path):
189
+ """Get default forecast length based on filename"""
190
+ if file_path is None:
191
+ return 64
192
+
193
+ filename = os.path.basename(file_path)
194
+ if filename == "loop.csv" or filename == "ett2.csv":
195
+ return 256
196
+ elif filename == "air_passangers.csv":
197
+ return 48
198
+ else:
199
+ return 64
200
+
201
+
202
+ def display_filtered_forecast(file, preset_filename, selected_names, forecast_length):
203
+ try:
204
+ # 1) If no file or preset selected, show an error
205
+ if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
206
+ return None, "No file selected."
207
+
208
+ # 2) Load DataFrame and remember which filename to pass to model_forecast
209
+ if file is not None:
210
+ df = load_table(file.name)
211
+ file_name = file.name
212
+ else:
213
+ df = load_table(preset_filename)
214
+ file_name = preset_filename
215
+
216
+ if df.shape[1]>2048:
217
+ df = df.iloc[:,-2048:]
218
+ gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5)
219
+
220
+
221
+ # 3) Determine whether first column is names or numeric
222
+ if (
223
+ df.shape[1] > 0
224
+ and df.iloc[:, 0].dtype == object
225
+ and not df.iloc[:, 0].str.isnumeric().all()
226
+ ):
227
+ all_names = df.iloc[:, 0].tolist()
228
+ data_only = df.iloc[:, 1:].astype(float)
229
+ else:
230
+ all_names = [f"Series {i}" for i in range(len(df))]
231
+ data_only = df.astype(float)
232
+
233
+ # 4) Build mask from selected_names
234
+ mask = [name in selected_names for name in all_names]
235
+ if not any(mask):
236
+ return None, "No timeseries chosen to plot."
237
+
238
+ filtered_data = data_only.iloc[mask, :].values # shape = (n_selected, seq_len)
239
+ filtered_names = [all_names[i] for i, m in enumerate(mask) if m]
240
+ n_selected = filtered_data.shape[0]
241
+ if n_selected>30:
242
+ raise gr.Error("Maximum of 30 timeseries (rows) is possible to choose", duration=5)
243
+
244
+ # 5) First call model_forecast on all series, then select only the masked rows
245
+ full_data = data_only.values # shape = (n_all, seq_len)
246
+ full_out = model_forecast(full_data, forecast_length=forecast_length, file_name=file_name)
247
+
248
+ # Now pick only the rows we actually filtered
249
+ out = full_out[mask, :, :] # shape = (n_selected, pred_len, n_q)
250
+ inp = torch.tensor(filtered_data)
251
+
252
+ # 6) Create one subplot per selected series, with vertical spacing
253
+ subplot_height_px = 350 # px per subplot
254
+ n_selected = len(filtered_names)
255
+ fig = make_subplots(
256
+ rows=n_selected,
257
+ cols=1,
258
+ shared_xaxes=False,
259
+ subplot_titles=filtered_names,
260
+ row_heights=[1] * n_selected, # all rows equal height
261
+ )
262
+ fig.update_layout(
263
+ height=subplot_height_px * n_selected,
264
+ template="plotly_dark",
265
+ margin=dict(t=50, b=50)
266
+ )
267
+
268
+ for idx in range(n_selected):
269
+ ts = inp[idx].numpy().tolist()
270
+ qp = out[idx].numpy()
271
+ series_name = filtered_names[idx]
272
+
273
+ # a) plot historical data (blue line)
274
+ x_hist = list(range(len(ts)))
275
+ fig.add_trace(
276
+ go.Scatter(
277
+ x=x_hist,
278
+ y=ts,
279
+ mode="lines",
280
+ name=f"{series_name} – Given Data",
281
+ line=dict(color="blue", width=2),
282
+ showlegend=False
283
+ ),
284
+ row=idx + 1, col=1
285
+ )
286
+
287
+ # b) compute forecast indices
288
+ pred_len = qp.shape[0]
289
+ x_pred = list(range(len(ts) - 1, len(ts) - 1 + pred_len))
290
+
291
+ lower_q = qp[:, 0]
292
+ upper_q = qp[:, -1]
293
+ n_q = qp.shape[1]
294
+ median_idx = n_q // 2
295
+ median_q = qp[:, median_idx]
296
+
297
+ # c) lower‐bound (invisible)
298
+ fig.add_trace(
299
+ go.Scatter(
300
+ x=x_pred,
301
+ y=lower_q,
302
+ mode="lines",
303
+ line=dict(color="rgba(0,0,0,0)", width=0),
304
+ name=f"{series_name} – 10% Quantile",
305
+ hovertemplate="10% Quantile: %{y:.2f}<extra></extra>",
306
+ showlegend=False
307
+ ),
308
+ row=idx + 1, col=1
309
+ )
310
+
311
+ # d) upper‐bound (shaded area)
312
+ fig.add_trace(
313
+ go.Scatter(
314
+ x=x_pred,
315
+ y=upper_q,
316
+ mode="lines",
317
+ line=dict(color="rgba(0,0,0,0)", width=0),
318
+ fill="tonexty",
319
+ fillcolor="rgba(128,128,128,0.3)",
320
+ name=f"{series_name} – 90% Quantile",
321
+ hovertemplate="90% Quantile: %{y:.2f}<extra></extra>",
322
+ showlegend=False
323
+ ),
324
+ row=idx + 1, col=1
325
+ )
326
+
327
+ # e) median forecast (orange line)
328
+ fig.add_trace(
329
+ go.Scatter(
330
+ x=x_pred,
331
+ y=median_q,
332
+ mode="lines",
333
+ name=f"{series_name} – Median Forecast",
334
+ line=dict(color="orange", width=2),
335
+ hovertemplate="Median: %{y:.2f}<extra></extra>",
336
+ showlegend=False
337
+ ),
338
+ row=idx + 1, col=1
339
+ )
340
+
341
+ # f) label axes for each subplot
342
+ fig.update_xaxes(title_text="Time", row=idx + 1, col=1)
343
+ fig.update_yaxes(title_text="Value", row=idx + 1, col=1)
344
+
345
+ # 7) Global layout tweaks
346
+ fig.update_layout(
347
+ template="plotly_dark",
348
+ height=300 * n_selected, # 300px per subplot
349
+ title=dict(
350
+ text="Forecasts for Selected Timeseries",
351
+ x=0.5,
352
+ font=dict(size=20, family="Arial", color="white")
353
+ ),
354
+ hovermode="x unified",
355
+ margin=dict(t=120, b=40, l=60, r=40),
356
+ showlegend=False
357
+ )
358
+
359
+ return fig, ""
360
+ except gr.Error as e:
361
+ raise gr.Error(e, duration=5)
362
+
363
+ except Exception as e:
364
+ return None, f"Error: {str(e)}"
365
+
366
+
367
+
368
+ # ----------------------------
369
+ # Gradio layout: two columns + instructions
370
+ # ----------------------------
371
+
372
+ with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
373
+ gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
374
+ gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")
375
+
376
+ with gr.Row():
377
+ # Left column: controls
378
+ with gr.Column(scale=1):
379
+ gr.Markdown("## Data Selection")
380
+ file_input = gr.File(
381
+ label="Upload CSV / XLSX / PARQUET",
382
+ file_types=[".csv", ".xls", ".xlsx", ".parquet"]
383
+ )
384
+ preset_choices = ["-- No preset selected --", "data/loop.csv", "data/air_passangers.csv", 'data/ett2.csv']
385
+
386
+ preset_dropdown = gr.Dropdown(
387
+ label="Or choose a preset:",
388
+ choices=preset_choices,
389
+ value="-- No preset selected --"
390
+ )
391
+
392
+ gr.Markdown("## Forecast Length Setting")
393
+ forecast_length_slider = gr.Slider(
394
+ minimum=1,
395
+ maximum=512,
396
+ value=64,
397
+ step=1,
398
+ label="Forecast Length (Steps)",
399
+ info="Choose how many future steps to forecast."
400
+ )
401
+
402
+ gr.Markdown("## Search / Filter")
403
+ search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
404
+ filter_checkbox = gr.CheckboxGroup(
405
+ choices=[], value=[], label="Select which timeseries to show"
406
+ )
407
+
408
+ with gr.Row():
409
+ check_all_btn = gr.Button("✅ Check All")
410
+ uncheck_all_btn = gr.Button("❎ Uncheck All")
411
+
412
+ plot_button = gr.Button("▶️ Plot Forecasts")
413
+ errbox = gr.Textbox(label="Error Message", interactive=False)
414
+ with gr.Row():
415
+ gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
416
+ gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
417
+
418
+ with gr.Column(scale=5):
419
+ gr.Markdown("## Forecast Plot")
420
+ plot_output = gr.Plot()
421
+
422
+ # Instruction text below plot
423
+ gr.Markdown("## Instructions")
424
+ gr.Markdown(
425
+ """
426
+ **How to format your data:**
427
+ - Your file must be a table (CSV, XLS, XLSX, or Parquet).
428
+ - **One row per timeseries.** Each row is treated as a separate series.
429
+ - If you want to **name** each series, put the name as the first value in **every** row:
430
+ - Example (CSV):
431
+ `AAPL, 120.5, 121.0, 119.8, ...`
432
+ `AMZN, 3300.0, 3310.5, 3295.2, ...`
433
+ - In that case, the first column is not numeric, so it will be used as the series name.
434
+ - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
435
+ - Example:
436
+ `120.5, 121.0, 119.8, ...`
437
+ `3300.0, 3310.5, 3295.2, ...`
438
+ - Then every row will be auto-named “Series 0, Series 1, …” in order.
439
+ - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
440
+ - The rest of the columns (after the optional name) must be numeric data points for that series.
441
+ - You can filter by typing in the search box. Then check or uncheck individual names before plotting.
442
+ - Use “Check All” / “Uncheck All” to quickly select or deselect every series.
443
+ - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series (for 64 steps ahead).
444
+ """
445
+ )
446
+ gr.Markdown("## Citation")
447
+ # make citation as code block
448
+ gr.Markdown(
449
+ """
450
+ If you use TiRex in your research, please cite our work:
451
+ ```
452
+ @article{auerTiRexZeroShotForecasting2025,
453
+ title = {{{TiRex}}: {{Zero-Shot Forecasting Across Long}} and {{Short Horizons}} with {{Enhanced In-Context Learning}}},
454
+ author = {Auer, Andreas and Podest, Patrick and Klotz, Daniel and B{\"o}ck, Sebastian and Klambauer, G{\"u}nter and Hochreiter, Sepp},
455
+ journal = {ArXiv},
456
+ volume = {2505.23719},
457
+ year = {2025}
458
+ }
459
+ ```
460
+ """
461
+ )
462
+
463
+ names_state = gr.State([])
464
+ file_input.change(
465
+ fn=extract_names_and_update,
466
+ inputs=[file_input, preset_dropdown],
467
+ outputs=[filter_checkbox, names_state, forecast_length_slider]
468
+ )
469
+ preset_dropdown.change(
470
+ fn=extract_names_and_update,
471
+ inputs=[file_input, preset_dropdown],
472
+ outputs=[filter_checkbox, names_state, forecast_length_slider]
473
+ )
474
+
475
+ # When search term changes, filter names
476
+ search_box.change(
477
+ fn=filter_names,
478
+ inputs=[search_box, names_state],
479
+ outputs=[filter_checkbox]
480
+ )
481
+
482
+ # Check All / Uncheck All
483
+ check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
484
+ uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)
485
+
486
+ # Plot button
487
+ plot_button.click(
488
+ fn=display_filtered_forecast,
489
+ inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider],
490
+ outputs=[plot_output, errbox]
491
  )
492
+ demo.launch()
493
+
494
+
495
+ '''
496
+ gradio app.py
497
+ ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
498
+ ssh -L 7860:localhost:7860 compute-permanent-node-83
499
+ '''
500
 
 
 
 
 
orig_app.py DELETED
@@ -1,500 +0,0 @@
1
- import io
2
- import pandas as pd
3
- import torch
4
- import plotly.graph_objects as go
5
- from PIL import Image
6
- import numpy as np
7
- import gradio as gr
8
- import os
9
- from plotly.subplots import make_subplots
10
-
11
- from tirex import load_model, ForecastModel
12
-
13
- # ----------------------------
14
- # Helper functions (logic mostly unchanged)
15
- # ----------------------------
16
-
17
- torch.manual_seed(42)
18
- model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
19
-
20
- def model_forecast(input_data, forecast_length=256, file_name=None):
21
- if os.path.basename(file_name) == "loop.csv":
22
- _forecast_tensor = torch.load("data/loop_forecast_512.pt")
23
- return _forecast_tensor[:,:forecast_length,:]
24
- elif os.path.basename(file_name) == "ett2.csv":
25
- _forecast_tensor = torch.load("data/ett2_forecast_512.pt")
26
- return _forecast_tensor[:,:forecast_length,:]
27
- elif os.path.basename(file_name) == "air_passangers.csv":
28
- _forecast_tensor = torch.load("data/air_passengers_forecast_512.pt")
29
- return _forecast_tensor[:,:forecast_length,:]
30
- else:
31
- forecast = model.forecast(context=input_data, prediction_length=forecast_length)
32
- return forecast[0]
33
-
34
-
35
-
36
- def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
37
- """
38
- - timeseries: 1D list/array of historical values.
39
- - quantile_predictions: 2D array of shape (pred_len, n_q),
40
- with quantiles sorted left→right.
41
- - timeseries_name: string label.
42
- """
43
- fig = go.Figure()
44
-
45
- # 1) Plot historical data (blue line, no markers)
46
- x_hist = list(range(len(timeseries)))
47
- fig.add_trace(go.Scatter(
48
- x=x_hist,
49
- y=timeseries,
50
- mode="lines", # no markers
51
- name=f"{timeseries_name} – Given Data",
52
- line=dict(color="blue", width=2),
53
- ))
54
-
55
- # 2) X-axis indices for forecasts
56
- pred_len = quantile_predictions.shape[0]
57
- x_pred = list(range(len(timeseries) - 1, len(timeseries) - 1 + pred_len))
58
-
59
- # 3) Extract lower, upper, and median quantiles
60
- lower_q = quantile_predictions[:, 0]
61
- upper_q = quantile_predictions[:, -1]
62
- n_q = quantile_predictions.shape[1]
63
- median_idx = n_q // 2
64
- median_q = quantile_predictions[:, median_idx]
65
-
66
- # 4) Lower‐bound trace (invisible line, still shows on hover)
67
- fig.add_trace(go.Scatter(
68
- x=x_pred,
69
- y=lower_q,
70
- mode="lines",
71
- line=dict(color="rgba(0, 0, 0, 0)", width=0),
72
- name=f"{timeseries_name} – 10% Quantile",
73
- hovertemplate="Lower: %{y:.2f}<extra></extra>"
74
- ))
75
-
76
- # 5) Upper‐bound trace (shaded down to lower_q)
77
- fig.add_trace(go.Scatter(
78
- x=x_pred,
79
- y=upper_q,
80
- mode="lines",
81
- line=dict(color="rgba(0, 0, 0, 0)", width=0),
82
- fill="tonexty",
83
- fillcolor="rgba(128, 128, 128, 0.3)",
84
- name=f"{timeseries_name} – 90% Quantile",
85
- hovertemplate="Upper: %{y:.2f}<extra></extra>"
86
- ))
87
-
88
- # 6) Median trace (orange) on top
89
- fig.add_trace(go.Scatter(
90
- x=x_pred,
91
- y=median_q,
92
- mode="lines",
93
- name=f"{timeseries_name} – Median Forecast",
94
- line=dict(color="orange", width=2),
95
- hovertemplate="Median: %{y:.2f}<extra></extra>"
96
- ))
97
-
98
- # 7) Layout: title on left (y=0.95), legend on right (y=0.95)
99
- fig.update_layout(
100
- template="plotly_dark",
101
- title=dict(
102
- text=f"Timeseries: {timeseries_name}",
103
- x=0.10, # left‐align
104
- xanchor="left",
105
- y=0.90, # near top
106
- yanchor="bottom",
107
- font=dict(size=18, family="Arial", color="white")
108
- ),
109
- xaxis=dict(
110
- rangeslider=dict(visible=True), # <-- put rangeslider here
111
- fixedrange=False
112
- ),
113
- xaxis_title="Time",
114
- yaxis_title="Value",
115
- hovermode="x unified",
116
- margin=dict(
117
- t=120, # increase top margin to fit title+legend comfortably
118
- b=40,
119
- l=60,
120
- r=40
121
- ),
122
- # height=plot_height,
123
- # width=plot_width,
124
- autosize=True,
125
- )
126
-
127
- return fig
128
-
129
-
130
-
131
-
132
-
133
- def load_table(file_path):
134
- ext = file_path.split(".")[-1].lower()
135
- if ext == "csv":
136
- return pd.read_csv(file_path)
137
- elif ext in ("xls", "xlsx"):
138
- return pd.read_excel(file_path)
139
- elif ext == "parquet":
140
- return pd.read_parquet(file_path)
141
- else:
142
- raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
143
-
144
-
145
- def extract_names_and_update(file, preset_filename):
146
- try:
147
- # Determine which file to use and get default forecast length
148
- if file is not None:
149
- df = load_table(file.name)
150
- default_length = get_default_forecast_length(file.name)
151
- else:
152
- if not preset_filename or preset_filename == "-- No preset selected --":
153
- return gr.update(choices=[], value=[]), [], gr.update(value=256)
154
- df = load_table(preset_filename)
155
- default_length = get_default_forecast_length(preset_filename)
156
-
157
- if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
158
- names = df.iloc[:, 0].tolist()
159
- else:
160
- names = [f"Series {i}" for i in range(len(df))]
161
-
162
- return (
163
- gr.update(choices=names, value=names),
164
- names,
165
- gr.update(value=default_length)
166
- )
167
- except Exception:
168
- return gr.update(choices=[], value=[]), [], gr.update(value=256)
169
-
170
-
171
- def filter_names(search_term, all_names):
172
- if not all_names:
173
- return gr.update(choices=[], value=[])
174
- if not search_term:
175
- return gr.update(choices=all_names, value=all_names)
176
- lower = search_term.lower()
177
- filtered = [n for n in all_names if lower in str(n).lower()]
178
- return gr.update(choices=filtered, value=filtered)
179
-
180
-
181
- def check_all(names_list):
182
- return gr.update(value=names_list)
183
-
184
-
185
- def uncheck_all(_):
186
- return gr.update(value=[])
187
-
188
- def get_default_forecast_length(file_path):
189
- """Get default forecast length based on filename"""
190
- if file_path is None:
191
- return 64
192
-
193
- filename = os.path.basename(file_path)
194
- if filename == "loop.csv" or filename == "ett2.csv":
195
- return 256
196
- elif filename == "air_passangers.csv":
197
- return 48
198
- else:
199
- return 64
200
-
201
-
202
- def display_filtered_forecast(file, preset_filename, selected_names, forecast_length):
203
- try:
204
- # 1) If no file or preset selected, show an error
205
- if file is None and (preset_filename is None or preset_filename == "-- No preset selected --"):
206
- return None, "No file selected."
207
-
208
- # 2) Load DataFrame and remember which filename to pass to model_forecast
209
- if file is not None:
210
- df = load_table(file.name)
211
- file_name = file.name
212
- else:
213
- df = load_table(preset_filename)
214
- file_name = preset_filename
215
-
216
- if df.shape[1]>2048:
217
- df = df.iloc[:,-2048:]
218
- gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5)
219
-
220
-
221
- # 3) Determine whether first column is names or numeric
222
- if (
223
- df.shape[1] > 0
224
- and df.iloc[:, 0].dtype == object
225
- and not df.iloc[:, 0].str.isnumeric().all()
226
- ):
227
- all_names = df.iloc[:, 0].tolist()
228
- data_only = df.iloc[:, 1:].astype(float)
229
- else:
230
- all_names = [f"Series {i}" for i in range(len(df))]
231
- data_only = df.astype(float)
232
-
233
- # 4) Build mask from selected_names
234
- mask = [name in selected_names for name in all_names]
235
- if not any(mask):
236
- return None, "No timeseries chosen to plot."
237
-
238
- filtered_data = data_only.iloc[mask, :].values # shape = (n_selected, seq_len)
239
- filtered_names = [all_names[i] for i, m in enumerate(mask) if m]
240
- n_selected = filtered_data.shape[0]
241
- if n_selected>30:
242
- raise gr.Error("Maximum of 30 timeseries (rows) is possible to choose", duration=5)
243
-
244
- # 5) First call model_forecast on all series, then select only the masked rows
245
- full_data = data_only.values # shape = (n_all, seq_len)
246
- full_out = model_forecast(full_data, forecast_length=forecast_length, file_name=file_name)
247
-
248
- # Now pick only the rows we actually filtered
249
- out = full_out[mask, :, :] # shape = (n_selected, pred_len, n_q)
250
- inp = torch.tensor(filtered_data)
251
-
252
- # 6) Create one subplot per selected series, with vertical spacing
253
- subplot_height_px = 350 # px per subplot
254
- n_selected = len(filtered_names)
255
- fig = make_subplots(
256
- rows=n_selected,
257
- cols=1,
258
- shared_xaxes=False,
259
- subplot_titles=filtered_names,
260
- row_heights=[1] * n_selected, # all rows equal height
261
- )
262
- fig.update_layout(
263
- height=subplot_height_px * n_selected,
264
- template="plotly_dark",
265
- margin=dict(t=50, b=50)
266
- )
267
-
268
- for idx in range(n_selected):
269
- ts = inp[idx].numpy().tolist()
270
- qp = out[idx].numpy()
271
- series_name = filtered_names[idx]
272
-
273
- # a) plot historical data (blue line)
274
- x_hist = list(range(len(ts)))
275
- fig.add_trace(
276
- go.Scatter(
277
- x=x_hist,
278
- y=ts,
279
- mode="lines",
280
- name=f"{series_name} – Given Data",
281
- line=dict(color="blue", width=2),
282
- showlegend=False
283
- ),
284
- row=idx + 1, col=1
285
- )
286
-
287
- # b) compute forecast indices
288
- pred_len = qp.shape[0]
289
- x_pred = list(range(len(ts) - 1, len(ts) - 1 + pred_len))
290
-
291
- lower_q = qp[:, 0]
292
- upper_q = qp[:, -1]
293
- n_q = qp.shape[1]
294
- median_idx = n_q // 2
295
- median_q = qp[:, median_idx]
296
-
297
- # c) lower‐bound (invisible)
298
- fig.add_trace(
299
- go.Scatter(
300
- x=x_pred,
301
- y=lower_q,
302
- mode="lines",
303
- line=dict(color="rgba(0,0,0,0)", width=0),
304
- name=f"{series_name} – 10% Quantile",
305
- hovertemplate="10% Quantile: %{y:.2f}<extra></extra>",
306
- showlegend=False
307
- ),
308
- row=idx + 1, col=1
309
- )
310
-
311
- # d) upper‐bound (shaded area)
312
- fig.add_trace(
313
- go.Scatter(
314
- x=x_pred,
315
- y=upper_q,
316
- mode="lines",
317
- line=dict(color="rgba(0,0,0,0)", width=0),
318
- fill="tonexty",
319
- fillcolor="rgba(128,128,128,0.3)",
320
- name=f"{series_name} – 90% Quantile",
321
- hovertemplate="90% Quantile: %{y:.2f}<extra></extra>",
322
- showlegend=False
323
- ),
324
- row=idx + 1, col=1
325
- )
326
-
327
- # e) median forecast (orange line)
328
- fig.add_trace(
329
- go.Scatter(
330
- x=x_pred,
331
- y=median_q,
332
- mode="lines",
333
- name=f"{series_name} – Median Forecast",
334
- line=dict(color="orange", width=2),
335
- hovertemplate="Median: %{y:.2f}<extra></extra>",
336
- showlegend=False
337
- ),
338
- row=idx + 1, col=1
339
- )
340
-
341
- # f) label axes for each subplot
342
- fig.update_xaxes(title_text="Time", row=idx + 1, col=1)
343
- fig.update_yaxes(title_text="Value", row=idx + 1, col=1)
344
-
345
- # 7) Global layout tweaks
346
- fig.update_layout(
347
- template="plotly_dark",
348
- height=300 * n_selected, # 300px per subplot
349
- title=dict(
350
- text="Forecasts for Selected Timeseries",
351
- x=0.5,
352
- font=dict(size=20, family="Arial", color="white")
353
- ),
354
- hovermode="x unified",
355
- margin=dict(t=120, b=40, l=60, r=40),
356
- showlegend=False
357
- )
358
-
359
- return fig, ""
360
- except gr.Error as e:
361
- raise gr.Error(e, duration=5)
362
-
363
- except Exception as e:
364
- return None, f"Error: {str(e)}"
365
-
366
-
367
-
368
- # ----------------------------
369
- # Gradio layout: two columns + instructions
370
- # ----------------------------
371
-
372
- with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
373
- gr.Markdown("# 📈 TiRex - timeseries forecasting 📊")
374
- gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")
375
-
376
- with gr.Row():
377
- # Left column: controls
378
- with gr.Column(scale=1):
379
- gr.Markdown("## Data Selection")
380
- file_input = gr.File(
381
- label="Upload CSV / XLSX / PARQUET",
382
- file_types=[".csv", ".xls", ".xlsx", ".parquet"]
383
- )
384
- preset_choices = ["-- No preset selected --", "data/loop.csv", "data/air_passangers.csv", 'data/ett2.csv']
385
-
386
- preset_dropdown = gr.Dropdown(
387
- label="Or choose a preset:",
388
- choices=preset_choices,
389
- value="-- No preset selected --"
390
- )
391
-
392
- gr.Markdown("## Forecast Length Setting")
393
- forecast_length_slider = gr.Slider(
394
- minimum=1,
395
- maximum=512,
396
- value=64,
397
- step=1,
398
- label="Forecast Length (Steps)",
399
- info="Choose how many future steps to forecast."
400
- )
401
-
402
- gr.Markdown("## Search / Filter")
403
- search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
404
- filter_checkbox = gr.CheckboxGroup(
405
- choices=[], value=[], label="Select which timeseries to show"
406
- )
407
-
408
- with gr.Row():
409
- check_all_btn = gr.Button("✅ Check All")
410
- uncheck_all_btn = gr.Button("❎ Uncheck All")
411
-
412
- plot_button = gr.Button("▶️ Plot Forecasts")
413
- errbox = gr.Textbox(label="Error Message", interactive=False)
414
- with gr.Row():
415
- gr.Image("static/nxai_logo.png", width=150, show_label=False, container=False)
416
- gr.Image("static/tirex.jpeg", width=150, show_label=False, container=False)
417
-
418
- with gr.Column(scale=5):
419
- gr.Markdown("## Forecast Plot")
420
- plot_output = gr.Plot()
421
-
422
- # Instruction text below plot
423
- gr.Markdown("## Instructions")
424
- gr.Markdown(
425
- """
426
- **How to format your data:**
427
- - Your file must be a table (CSV, XLS, XLSX, or Parquet).
428
- - **One row per timeseries.** Each row is treated as a separate series.
429
- - If you want to **name** each series, put the name as the first value in **every** row:
430
- - Example (CSV):
431
- `AAPL, 120.5, 121.0, 119.8, ...`
432
- `AMZN, 3300.0, 3310.5, 3295.2, ...`
433
- - In that case, the first column is not numeric, so it will be used as the series name.
434
- - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
435
- - Example:
436
- `120.5, 121.0, 119.8, ...`
437
- `3300.0, 3310.5, 3295.2, ...`
438
- - Then every row will be auto-named “Series 0, Series 1, …” in order.
439
- - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
440
- - The rest of the columns (after the optional name) must be numeric data points for that series.
441
- - You can filter by typing in the search box. Then check or uncheck individual names before plotting.
442
- - Use “Check All” / “Uncheck All” to quickly select or deselect every series.
443
- - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series (for 64 steps ahead).
444
- """
445
- )
446
- gr.Markdown("## Citation")
447
- # make citation as code block
448
- gr.Markdown(
449
- """
450
- If you use TiRex in your research, please cite our work:
451
- ```
452
- @article{auerTiRexZeroShotForecasting2025,
453
- title = {{{TiRex}}: {{Zero-Shot Forecasting Across Long}} and {{Short Horizons}} with {{Enhanced In-Context Learning}}},
454
- author = {Auer, Andreas and Podest, Patrick and Klotz, Daniel and B{\"o}ck, Sebastian and Klambauer, G{\"u}nter and Hochreiter, Sepp},
455
- journal = {ArXiv},
456
- volume = {2505.23719},
457
- year = {2025}
458
- }
459
- ```
460
- """
461
- )
462
-
463
- names_state = gr.State([])
464
- file_input.change(
465
- fn=extract_names_and_update,
466
- inputs=[file_input, preset_dropdown],
467
- outputs=[filter_checkbox, names_state, forecast_length_slider]
468
- )
469
- preset_dropdown.change(
470
- fn=extract_names_and_update,
471
- inputs=[file_input, preset_dropdown],
472
- outputs=[filter_checkbox, names_state, forecast_length_slider]
473
- )
474
-
475
- # When search term changes, filter names
476
- search_box.change(
477
- fn=filter_names,
478
- inputs=[search_box, names_state],
479
- outputs=[filter_checkbox]
480
- )
481
-
482
- # Check All / Uncheck All
483
- check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
484
- uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)
485
-
486
- # Plot button
487
- plot_button.click(
488
- fn=display_filtered_forecast,
489
- inputs=[file_input, preset_dropdown, filter_checkbox, forecast_length_slider],
490
- outputs=[plot_output, errbox]
491
- )
492
- demo.launch()
493
-
494
-
495
- '''
496
- gradio app.py
497
- ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
498
- ssh -L 7860:localhost:7860 compute-permanent-node-83
499
- '''
500
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time # Import time to make logs more distinct
3
+
4
+ def greet(name):
5
+ """
6
+ This function takes a name as input and returns a personalized greeting string.
7
+ It now includes print statements for logging with flush=True to ensure
8
+ logs appear immediately in container environments like Hugging Face Spaces.
9
+ """
10
+ # Log the function entry
11
+ # The flush=True argument is crucial for logs to appear in real-time in Docker.
12
+ print(f"[{time.ctime()}] - Function 'greet' was called.", flush=True)
13
+
14
+ if name:
15
+ # Log the received input
16
+ print(f"[{time.ctime()}] - Received input name: '{name}'", flush=True)
17
+ return f"Hello, {name}! Welcome to your first Gradio app."
18
+ else:
19
+ # Log that the input was empty
20
+ print(f"[{time.ctime()}] - No input name received.", flush=True)
21
+ return "Hello! Please enter your name."
22
+
23
+ # Create the Gradio interface
24
+ app = gr.Interface(
25
+ fn=greet,
26
+ inputs=gr.Textbox(
27
+ lines=1,
28
+ placeholder="Please enter your name here...",
29
+ label="Your Name"
30
+ ),
31
+ outputs=gr.Text(label="Greeting"),
32
+ title="Simple Greeting App with Logging",
33
+ description="Enter your name to receive a greeting. Check the Hugging Face logs to see the output from the print() statements."
34
+ )
35
+
36
+ # Launch the application
37
+ if __name__ == "__main__":
38
+ print(f"[{time.ctime()}] - Starting Gradio server...", flush=True)
39
+ app.launch(server_name="0.0.0.0", server_port=7860)