Blago123 committed
Commit 0fb9496 · 1 Parent(s): 5e90f01

predicting last forecast_length (not tested yet)
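
The change below stops forecasting past the end of the uploaded data: the last forecast_length steps of each series are now held out as ground truth, and only the preceding steps are passed to the model, so the plotted forecast overlaps values that are actually known. A minimal sketch of that split, assuming a pandas DataFrame with one row per series; the helper name split_context_and_truth is illustrative and not part of the commit:

import pandas as pd

def split_context_and_truth(df: pd.DataFrame, forecast_length: int):
    # Mirrors the split added in display_filtered_forecast: keep the last
    # `forecast_length` columns as ground truth and forecast from the rest.
    if df.shape[1] < forecast_length + 10:
        raise ValueError("Timeseries should have the minimum length of (forecast_length+10)!")
    y_true = df.iloc[:, -forecast_length:]      # held-out tail used to judge accuracy
    context = df.iloc[:, :-forecast_length]     # what the model actually sees
    return context, y_true
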
app.py CHANGED
@@ -19,116 +19,18 @@ model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
 
 def model_forecast(input_data, forecast_length=256, file_name=None):
     if os.path.basename(file_name) == "loop.csv":
-        _forecast_tensor = torch.load("data/loop_forecast_512.pt")
+        _forecast_tensor = torch.load("data/loop_forecast_256.pt")
         return _forecast_tensor[:,:forecast_length,:]
     elif os.path.basename(file_name) == "ett2.csv":
-        _forecast_tensor = torch.load("data/ett2_forecast_512.pt")
+        _forecast_tensor = torch.load("data/ett2_forecast_256.pt")
         return _forecast_tensor[:,:forecast_length,:]
     elif os.path.basename(file_name) == "air_passangers.csv":
-        _forecast_tensor = torch.load("data/air_passengers_forecast_512.pt")
+        _forecast_tensor = torch.load("data/air_passengers_forecast_48.pt")
         return _forecast_tensor[:,:forecast_length,:]
     else:
         forecast = model.forecast(context=input_data, prediction_length=forecast_length)
         return forecast[0]
 
-
-
-def plot_forecast_plotly(timeseries, quantile_predictions, timeseries_name):
-    """
-    - timeseries: 1D list/array of historical values.
-    - quantile_predictions: 2D array of shape (pred_len, n_q),
-      with quantiles sorted left→right.
-    - timeseries_name: string label.
-    """
-    fig = go.Figure()
-
-    # 1) Plot historical data (blue line, no markers)
-    x_hist = list(range(len(timeseries)))
-    fig.add_trace(go.Scatter(
-        x=x_hist,
-        y=timeseries,
-        mode="lines",  # no markers
-        name=f"{timeseries_name} – Given Data",
-        line=dict(color="blue", width=2),
-    ))
-
-    # 2) X-axis indices for forecasts
-    pred_len = quantile_predictions.shape[0]
-    x_pred = list(range(len(timeseries) - 1, len(timeseries) - 1 + pred_len))
-
-    # 3) Extract lower, upper, and median quantiles
-    lower_q = quantile_predictions[:, 0]
-    upper_q = quantile_predictions[:, -1]
-    n_q = quantile_predictions.shape[1]
-    median_idx = n_q // 2
-    median_q = quantile_predictions[:, median_idx]
-
-    # 4) Lower‐bound trace (invisible line, still shows on hover)
-    fig.add_trace(go.Scatter(
-        x=x_pred,
-        y=lower_q,
-        mode="lines",
-        line=dict(color="rgba(0, 0, 0, 0)", width=0),
-        name=f"{timeseries_name} – 10% Quantile",
-        hovertemplate="Lower: %{y:.2f}<extra></extra>"
-    ))
-
-    # 5) Upper‐bound trace (shaded down to lower_q)
-    fig.add_trace(go.Scatter(
-        x=x_pred,
-        y=upper_q,
-        mode="lines",
-        line=dict(color="rgba(0, 0, 0, 0)", width=0),
-        fill="tonexty",
-        fillcolor="rgba(128, 128, 128, 0.3)",
-        name=f"{timeseries_name} – 90% Quantile",
-        hovertemplate="Upper: %{y:.2f}<extra></extra>"
-    ))
-
-    # 6) Median trace (orange) on top
-    fig.add_trace(go.Scatter(
-        x=x_pred,
-        y=median_q,
-        mode="lines",
-        name=f"{timeseries_name} – Median Forecast",
-        line=dict(color="orange", width=2),
-        hovertemplate="Median: %{y:.2f}<extra></extra>"
-    ))
-
-    # 7) Layout: title on left (y=0.95), legend on right (y=0.95)
-    fig.update_layout(
-        template="plotly_dark",
-        title=dict(
-            text=f"Timeseries: {timeseries_name}",
-            x=0.10,  # left‐align
-            xanchor="left",
-            y=0.90,  # near top
-            yanchor="bottom",
-            font=dict(size=18, family="Arial", color="white")
-        ),
-        xaxis=dict(
-            rangeslider=dict(visible=True),  # <-- put rangeslider here
-            fixedrange=False
-        ),
-        xaxis_title="Time",
-        yaxis_title="Value",
-        hovermode="x unified",
-        margin=dict(
-            t=120,  # increase top margin to fit title+legend comfortably
-            b=40,
-            l=60,
-            r=40
-        ),
-        # height=plot_height,
-        # width=plot_width,
-        autosize=True,
-    )
-
-    return fig
-
-
-
-
 
 def load_table(file_path):
     ext = file_path.split(".")[-1].lower()
@@ -225,13 +127,19 @@ def display_filtered_forecast(file, preset_filename, selected_names, forecast_le
             df = pd.concat([ df.iloc[:, [0]], df.iloc[:, -2048:] ], axis=1)
             gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5)
         all_names = df.iloc[:, 0].tolist()
-        data_only = df.iloc[:, 1:].astype(float)
+        data_only_full = df.iloc[:, 1:].astype(float)
     else:
         if df.shape[1]>2048:
             df = df.iloc[:, -2048:]
             gr.Info("Maximum of 2048 steps per timeseries (row) is allowed, hence last 2048 kept. ℹ️", duration=5)
         all_names = [f"Series {i}" for i in range(len(df))]
-        data_only = df.astype(float)
+        data_only_full = df.astype(float)
+
+    # ** Cut timeseries into 2 series, context and prediction
+    if data_only_full.shape[1]<forecast_length+10:
+        raise gr.Error("Timeseries should have the minimum length of (forecast_length+10)!", duration=5)
+    y_true = data_only_full.iloc[:, -forecast_length:]
+    data_only = data_only_full.iloc[:, :-forecast_length]
 
     # 4) Build mask from selected_names
     mask = [name in selected_names for name in all_names]
@@ -239,6 +147,8 @@ def display_filtered_forecast(file, preset_filename, selected_names, forecast_le
         return None, "No timeseries chosen to plot."
 
     filtered_data = data_only.iloc[mask, :].values    # shape = (n_selected, seq_len)
+    filtered_data_only_full = data_only_full.iloc[mask, :].values    # ** Added to show prediction accuracy
+
     filtered_names = [all_names[i] for i, m in enumerate(mask) if m]
    n_selected = filtered_data.shape[0]
    if n_selected>30:
@@ -251,6 +161,7 @@ def display_filtered_forecast(file, preset_filename, selected_names, forecast_le
     # Now pick only the rows we actually filtered
     out = full_out[mask, :, :]    # shape = (n_selected, pred_len, n_q)
     inp = torch.tensor(filtered_data)
+    inp_full = torch.tensor(filtered_data_only_full)    # ** Added to show prediction accuracy
 
     # 6) Create one subplot per selected series, with vertical spacing
     subplot_height_px = 350    # px per subplot
@@ -270,15 +181,16 @@ def display_filtered_forecast(file, preset_filename, selected_names, forecast_le
 
     for idx in range(n_selected):
         ts = inp[idx].numpy().tolist()
+        ts_full = inp_full[idx].numpy().tolist()
         qp = out[idx].numpy()
         series_name = filtered_names[idx]
 
         # a) plot historical data (blue line)
-        x_hist = list(range(len(ts)))
+        x_hist = list(range(len(ts_full)))
         fig.add_trace(
             go.Scatter(
                 x=x_hist,
-                y=ts,
+                y=ts_full,
                 mode="lines",
                 name=f"{series_name} – Given Data",
                 line=dict(color="blue", width=2),
@@ -290,6 +202,8 @@ def display_filtered_forecast(file, preset_filename, selected_names, forecast_le
         # b) compute forecast indices
         pred_len = qp.shape[0]
         x_pred = list(range(len(ts) - 1, len(ts) - 1 + pred_len))
+        #x_pred = list(range(len(ts), len(ts) + pred_len))
+
 
         lower_q = qp[:, 0]
         upper_q = qp[:, -1]
@@ -340,6 +254,7 @@ def display_filtered_forecast(file, preset_filename, selected_names, forecast_le
             ),
             row=idx + 1, col=1
         )
+
 
         # f) label axes for each subplot
         fig.update_xaxes(title_text="Time", row=idx + 1, col=1)
@@ -498,5 +413,5 @@ with gr.Blocks(fill_width=True,theme=gr.themes.Ocean()) as demo:
 '''
 gradio app.py
 ssh -L 7860:localhost:7860 nikita_blago@oracle-gpu-controller -t \
-    ssh -L 7860:localhost:7860 compute-permanent-node-83
+    ssh -L 7860:localhost:7860 compute-permanent-node-368
 '''
data/air_passengers_forecast_48.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fd1d2ebfd39d5025f04b97c40469b4fff7a0ee5577a090cb854acaa69c36e9c
+size 3067

data/air_passengers_forecast_512.pt DELETED
Binary file (19.8 kB)

data/ett2_forecast_256.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:925cded1905836344ca55b2c2e1d7ac1ab06cd871dd39fe6228bb92652229d14
+size 19662

data/ett2_forecast_512.pt DELETED
Binary file (38.1 kB)

data/loop_forecast_256.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33aa603d5d34a87b8801242419c831980e95276e2226c00154c00ba09c2cb1d4
+size 19662

data/loop_forecast_512.pt DELETED
Binary file (38.1 kB)
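
The *_256.pt and *_48.pt files added above (and the *_512.pt files they replace) are Git LFS pointers to precomputed forecast tensors that model_forecast loads for the bundled demo CSVs. A sketch of how such a cache could be regenerated, under assumptions: the tirex import path and the cache_forecast helper are illustrative; only load_model, model.forecast, and the (n_series, pred_len, n_q) tensor layout come from app.py itself:

import torch
from tirex import load_model, ForecastModel  # assumed import path for the TiRex package

model: ForecastModel = load_model("NX-AI/TiRex", device="cuda")

def cache_forecast(context: torch.Tensor, prediction_length: int, out_path: str) -> None:
    # Element 0 of the value returned by model.forecast is the quantile tensor of
    # shape (n_series, prediction_length, n_quantiles), which model_forecast later
    # slices as [:, :forecast_length, :].
    quantiles = model.forecast(context=context, prediction_length=prediction_length)[0]
    torch.save(quantiles, out_path)

# e.g. cache_forecast(loop_context, 256, "data/loop_forecast_256.pt")  # loop_context is illustrative
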
 
test.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
test_app.py DELETED
@@ -1,39 +0,0 @@
-import gradio as gr
-import time  # Import time to make logs more distinct
-
-def greet(name):
-    """
-    This function takes a name as input and returns a personalized greeting string.
-    It now includes print statements for logging with flush=True to ensure
-    logs appear immediately in container environments like Hugging Face Spaces.
-    """
-    # Log the function entry
-    # The flush=True argument is crucial for logs to appear in real-time in Docker.
-    print(f"[{time.ctime()}] - Function 'greet' was called.", flush=True)
-
-    if name:
-        # Log the received input
-        print(f"[{time.ctime()}] - Received input name: '{name}'", flush=True)
-        return f"Hello, {name}! Welcome to your first Gradio app."
-    else:
-        # Log that the input was empty
-        print(f"[{time.ctime()}] - No input name received.", flush=True)
-        return "Hello! Please enter your name."
-
-# Create the Gradio interface
-app = gr.Interface(
-    fn=greet,
-    inputs=gr.Textbox(
-        lines=1,
-        placeholder="Please enter your name here...",
-        label="Your Name"
-    ),
-    outputs=gr.Text(label="Greeting"),
-    title="Simple Greeting App with Logging",
-    description="Enter your name to receive a greeting. Check the Hugging Face logs to see the output from the print() statements."
-)
-
-# Launch the application
-if __name__ == "__main__":
-    print(f"[{time.ctime()}] - Starting Gradio server...", flush=True)
-    app.launch(server_name="0.0.0.0", server_port=7860)