qywok commited on
Commit
1e9352c
·
verified ·
1 Parent(s): f26f72c

Update restful/onnx_utilities.py

Browse files
Files changed (1) hide show
  1. restful/onnx_utilities.py +17 -6
restful/onnx_utilities.py CHANGED
@@ -38,16 +38,27 @@ class Utilities:
38
  predicted_prices = {}
39
  last_date = to_datetime(dataframe.index[-1])
40
 
41
- for _ in range(days):
42
- predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
43
 
44
- denorm_price = self.denormalization(predicted[0][0], min_close, max_close)
45
 
46
- last_date += Timedelta(days=1)
47
- predicted_prices[last_date] = denorm_price.flatten()[0]
48
 
 
 
 
 
 
 
 
 
 
 
 
49
  lst_seq = np.roll(lst_seq, shift=-1, axis=1)
50
- lst_seq[:, -1, -1] = predicted[0][0][0]
51
 
52
  predictions = [
53
  {'date': date.strftime('%Y-%m-%d'), 'price': float(price)}
 
38
  predicted_prices = {}
39
  last_date = to_datetime(dataframe.index[-1])
40
 
41
+ # for _ in range(days):
42
+ # predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
43
 
44
+ # denorm_price = self.denormalization(predicted[0][0], min_close, max_close)
45
 
46
+ # last_date += Timedelta(days=1)
47
+ # predicted_prices[last_date] = denorm_price.flatten()[0]
48
 
49
+ # lst_seq = np.roll(lst_seq, shift=-1, axis=1)
50
+ # lst_seq[:, -1, -1] = predicted[0][0][0]
51
+ for _ in range(days):
52
+ predicted = session.run(None, {input_name: lst_seq.astype(np.float32)})[0]
53
+
54
+ value = np.array(predicted).flatten()[0]
55
+ denorm_price = (value * (max_close - min_close)) + min_close
56
+
57
+ last_date += pd.Timedelta(days=1)
58
+ predicted_prices[last_date.strftime('%Y-%m-%d')] = float(denorm_price)
59
+
60
  lst_seq = np.roll(lst_seq, shift=-1, axis=1)
61
+ lst_seq[:, -1, -1] = value
62
 
63
  predictions = [
64
  {'date': date.strftime('%Y-%m-%d'), 'price': float(price)}