qywok commited on
Commit
d92ee7d
·
verified ·
1 Parent(s): 0e963ff

Update restful/onnx_utilities.py

Browse files
Files changed (1) hide show
  1. restful/onnx_utilities.py +30 -7
restful/onnx_utilities.py CHANGED
@@ -30,7 +30,12 @@ class Utilities:
30
  days: int, sequence_length: int, model_name: str) -> tuple:
31
 
32
  model_path = os.path.join(self.model_path, f'{model_name}.onnx')
33
- session = ort.InferenceSession(model_path)
 
 
 
 
 
34
  input_name = session.get_inputs()[0].name
35
 
36
  dataframe_path = os.path.join(self.posttrained_path, f'{model_name}.csv')
@@ -59,20 +64,38 @@ class Utilities:
59
 
60
  # lst_seq = np.roll(lst_seq, shift=-1, axis=1)
61
  # lst_seq[:, -1, -1] = predicted[0][0][0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  for _ in range(days):
63
  predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
64
-
65
  value = np.array(predicted).flatten()[0]
66
- denorm_price = (value * (max_close - min_close)) + min_close
67
-
68
- # last_date += pd.Timedelta(days=1)
 
 
69
  last_date = pd.to_datetime(last_date) + pd.Timedelta(days=1)
70
- # predicted_prices[last_date.strftime('%Y-%m-%d')] = float(denorm_price)
71
  predicted_prices[last_date] = self.truncate_2_decimal(denorm_price)
72
-
73
  lst_seq = np.roll(lst_seq, shift=-1, axis=1)
74
  lst_seq[:, -1, -1] = value
75
 
 
 
76
  # predictions = [
77
  # {'date': date.strftime('%Y-%m-%d'), 'price': float(price)}
78
  # for date, price in predicted_prices.items()
 
30
  days: int, sequence_length: int, model_name: str) -> tuple:
31
 
32
  model_path = os.path.join(self.model_path, f'{model_name}.onnx')
33
+ # session = ort.InferenceSession(model_path)
34
+ try:
35
+ session = ort.InferenceSession(model_path)
36
+ except Exception as e:
37
+ print("ONNX model load error:", e)
38
+ return [], []
39
  input_name = session.get_inputs()[0].name
40
 
41
  dataframe_path = os.path.join(self.posttrained_path, f'{model_name}.csv')
 
64
 
65
  # lst_seq = np.roll(lst_seq, shift=-1, axis=1)
66
  # lst_seq[:, -1, -1] = predicted[0][0][0]
67
+
68
+
69
+
70
+ # for _ in range(days):
71
+ # predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
72
+
73
+ # value = np.array(predicted).flatten()[0]
74
+ # denorm_price = (value * (max_close - min_close)) + min_close
75
+
76
+ # # last_date += pd.Timedelta(days=1)
77
+ # last_date = pd.to_datetime(last_date) + pd.Timedelta(days=1)
78
+ # # predicted_prices[last_date.strftime('%Y-%m-%d')] = float(denorm_price)
79
+ # predicted_prices[last_date] = self.truncate_2_decimal(denorm_price)
80
+
81
+ # lst_seq = np.roll(lst_seq, shift=-1, axis=1)
82
+ # lst_seq[:, -1, -1] = value
83
+
84
  for _ in range(days):
85
  predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
 
86
  value = np.array(predicted).flatten()[0]
87
+ if np.isnan(value):
88
+ continue
89
+ denorm_price = self.denormalization(value, min_close, max_close)
90
+ if np.isnan(denorm_price):
91
+ continue
92
  last_date = pd.to_datetime(last_date) + pd.Timedelta(days=1)
 
93
  predicted_prices[last_date] = self.truncate_2_decimal(denorm_price)
 
94
  lst_seq = np.roll(lst_seq, shift=-1, axis=1)
95
  lst_seq[:, -1, -1] = value
96
 
97
+
98
+
99
  # predictions = [
100
  # {'date': date.strftime('%Y-%m-%d'), 'price': float(price)}
101
  # for date, price in predicted_prices.items()