Spaces:
Running
on
T4
Running
on
T4
Nikita
commited on
Commit
·
398f918
1
Parent(s):
8623916
tirex forecasting in model_forecast()
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
|
|
6 |
import numpy as np
|
7 |
import gradio as gr
|
8 |
import os
|
|
|
9 |
|
10 |
# ----------------------------
|
11 |
# Helper functions (logic mostly unchanged)
|
@@ -21,10 +22,9 @@ def model_forecast(input_data, forecast_length=256, file_name=None):
|
|
21 |
_forecast_tensor = torch.load("data/air_passengers_forecast_256.pt")
|
22 |
return _forecast_tensor[:,:forecast_length,:]
|
23 |
else:
|
24 |
-
''
|
25 |
-
|
26 |
-
|
27 |
-
pass
|
28 |
|
29 |
|
30 |
|
|
|
6 |
import numpy as np
|
7 |
import gradio as gr
|
8 |
import os
|
9 |
+
from tirex import load_model, ForecastModel
|
10 |
|
11 |
# ----------------------------
|
12 |
# Helper functions (logic mostly unchanged)
|
|
|
22 |
_forecast_tensor = torch.load("data/air_passengers_forecast_256.pt")
|
23 |
return _forecast_tensor[:,:forecast_length,:]
|
24 |
else:
|
25 |
+
model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
|
26 |
+
forecast = model.forecast(context=input_data, prediction_length=forecast_length)
|
27 |
+
return forecast[0]
|
|
|
28 |
|
29 |
|
30 |
|