qywok commited on
Commit
4058d83
·
verified ·
1 Parent(s): e3a9758

Create onnx_utilities.py

Browse files
Files changed (1) hide show
  1. restful/onnx_utilities.py +66 -0
restful/onnx_utilities.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import pandas as pd
5
+ import onnxruntime as ort
6
+ from numpy import append, expand_dims
7
+ from pandas import read_csv, to_datetime, Timedelta
8
+
9
+ class Utilities:
10
+ def __init__(self) -> None:
11
+ self.model_path = './models'
12
+ self.posttrained_path = './indonesia_stocks/modeling_datas'
13
+ self.scaler_path = './indonesia_stocks/min_max'
14
+
15
+ def denormalization(self, data, min_value, max_value):
16
+ return (data * (max_value - min_value)) + min_value
17
+
18
+ async def cryptocurrency_prediction_utils(self,
19
+ days: int, sequence_length: int, model_name: str) -> tuple:
20
+
21
+ model_path = os.path.join(self.model_path, f'{model_name}.onnx')
22
+ session = ort.InferenceSession(model_path)
23
+ input_name = session.get_inputs()[0].name
24
+
25
+ dataframe_path = os.path.join(self.posttrained_path, f'{model_name}.csv')
26
+ dataframe = read_csv(dataframe_path, index_col='Date', parse_dates=True)
27
+
28
+ scaler_path = os.path.join(self.scaler_path, f'{model_name}.json')
29
+ with open(scaler_path, 'r') as f:
30
+ scalers = json.load(f)
31
+
32
+ min_close = scalers['min_value']['Close']
33
+ max_close = scalers['max_value']['Close']
34
+
35
+ lst_seq = dataframe[-sequence_length:].values
36
+ lst_seq = expand_dims(lst_seq, axis=0)
37
+
38
+ predicted_prices = {}
39
+ last_date = to_datetime(dataframe.index[-1])
40
+
41
+ for _ in range(days):
42
+ predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
43
+
44
+ denorm_price = self.denormalization(predicted[0][0], min_close, max_close)
45
+
46
+ last_date += Timedelta(days=1)
47
+ predicted_prices[last_date] = denorm_price.flatten()[0]
48
+
49
+ lst_seq = np.roll(lst_seq, shift=-1, axis=1)
50
+ lst_seq[:, -1, -1] = predicted[0][0][0]
51
+
52
+ predictions = [
53
+ {'date': date.strftime('%Y-%m-%d'), 'price': float(price)}
54
+ for date, price in predicted_prices.items()
55
+ ]
56
+
57
+ df_date = dataframe.index[-sequence_length:]
58
+ close_values = dataframe.iloc[-sequence_length:]['Close'].values
59
+ close_denorm = self.denormalization(close_values, min_close, max_close)
60
+
61
+ actuals = [
62
+ {'date': to_datetime(date).strftime('%Y-%m-%d'), 'price': float(price)}
63
+ for date, price in zip(df_date, close_denorm)
64
+ ]
65
+
66
+ return actuals, predictions