Nikita commited on
Commit
398f918
·
1 Parent(s): 8623916

tirex forecasting in model_forecast()

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  import numpy as np
7
  import gradio as gr
8
  import os
 
9
 
10
  # ----------------------------
11
  # Helper functions (logic mostly unchanged)
@@ -21,10 +22,9 @@ def model_forecast(input_data, forecast_length=256, file_name=None):
21
  _forecast_tensor = torch.load("data/air_passengers_forecast_256.pt")
22
  return _forecast_tensor[:,:forecast_length,:]
23
  else:
24
- '''
25
- TODO: implement the model forecast for custom data
26
- '''
27
- pass
28
 
29
 
30
 
 
6
  import numpy as np
7
  import gradio as gr
8
  import os
9
+ from tirex import load_model, ForecastModel
10
 
11
  # ----------------------------
12
  # Helper functions (logic mostly unchanged)
 
22
  _forecast_tensor = torch.load("data/air_passengers_forecast_256.pt")
23
  return _forecast_tensor[:,:forecast_length,:]
24
  else:
25
+ model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
26
+ forecast = model.forecast(context=input_data, prediction_length=forecast_length)
27
+ return forecast[0]
 
28
 
29
 
30