Spaces:
Running
Running
Update restful/onnx_utilities.py
Browse files- restful/onnx_utilities.py +30 -7
restful/onnx_utilities.py
CHANGED
@@ -30,7 +30,12 @@ class Utilities:
|
|
30 |
days: int, sequence_length: int, model_name: str) -> tuple:
|
31 |
|
32 |
model_path = os.path.join(self.model_path, f'{model_name}.onnx')
|
33 |
-
session = ort.InferenceSession(model_path)
|
|
|
|
|
|
|
|
|
|
|
34 |
input_name = session.get_inputs()[0].name
|
35 |
|
36 |
dataframe_path = os.path.join(self.posttrained_path, f'{model_name}.csv')
|
@@ -59,20 +64,38 @@ class Utilities:
|
|
59 |
|
60 |
# lst_seq = np.roll(lst_seq, shift=-1, axis=1)
|
61 |
# lst_seq[:, -1, -1] = predicted[0][0][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
for _ in range(days):
|
63 |
predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
|
64 |
-
|
65 |
value = np.array(predicted).flatten()[0]
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
69 |
last_date = pd.to_datetime(last_date) + pd.Timedelta(days=1)
|
70 |
-
# predicted_prices[last_date.strftime('%Y-%m-%d')] = float(denorm_price)
|
71 |
predicted_prices[last_date] = self.truncate_2_decimal(denorm_price)
|
72 |
-
|
73 |
lst_seq = np.roll(lst_seq, shift=-1, axis=1)
|
74 |
lst_seq[:, -1, -1] = value
|
75 |
|
|
|
|
|
76 |
# predictions = [
|
77 |
# {'date': date.strftime('%Y-%m-%d'), 'price': float(price)}
|
78 |
# for date, price in predicted_prices.items()
|
|
|
30 |
days: int, sequence_length: int, model_name: str) -> tuple:
|
31 |
|
32 |
model_path = os.path.join(self.model_path, f'{model_name}.onnx')
|
33 |
+
# session = ort.InferenceSession(model_path)
|
34 |
+
try:
|
35 |
+
session = ort.InferenceSession(model_path)
|
36 |
+
except Exception as e:
|
37 |
+
print("ONNX model load error:", e)
|
38 |
+
return [], []
|
39 |
input_name = session.get_inputs()[0].name
|
40 |
|
41 |
dataframe_path = os.path.join(self.posttrained_path, f'{model_name}.csv')
|
|
|
64 |
|
65 |
# lst_seq = np.roll(lst_seq, shift=-1, axis=1)
|
66 |
# lst_seq[:, -1, -1] = predicted[0][0][0]
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
# for _ in range(days):
|
71 |
+
# predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
|
72 |
+
|
73 |
+
# value = np.array(predicted).flatten()[0]
|
74 |
+
# denorm_price = (value * (max_close - min_close)) + min_close
|
75 |
+
|
76 |
+
# # last_date += pd.Timedelta(days=1)
|
77 |
+
# last_date = pd.to_datetime(last_date) + pd.Timedelta(days=1)
|
78 |
+
# # predicted_prices[last_date.strftime('%Y-%m-%d')] = float(denorm_price)
|
79 |
+
# predicted_prices[last_date] = self.truncate_2_decimal(denorm_price)
|
80 |
+
|
81 |
+
# lst_seq = np.roll(lst_seq, shift=-1, axis=1)
|
82 |
+
# lst_seq[:, -1, -1] = value
|
83 |
+
|
84 |
for _ in range(days):
|
85 |
predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
|
|
|
86 |
value = np.array(predicted).flatten()[0]
|
87 |
+
if np.isnan(value):
|
88 |
+
continue
|
89 |
+
denorm_price = self.denormalization(value, min_close, max_close)
|
90 |
+
if np.isnan(denorm_price):
|
91 |
+
continue
|
92 |
last_date = pd.to_datetime(last_date) + pd.Timedelta(days=1)
|
|
|
93 |
predicted_prices[last_date] = self.truncate_2_decimal(denorm_price)
|
|
|
94 |
lst_seq = np.roll(lst_seq, shift=-1, axis=1)
|
95 |
lst_seq[:, -1, -1] = value
|
96 |
|
97 |
+
|
98 |
+
|
99 |
# predictions = [
|
100 |
# {'date': date.strftime('%Y-%m-%d'), 'price': float(price)}
|
101 |
# for date, price in predicted_prices.items()
|