ttn0011 commited on
Commit
634b732
·
1 Parent(s): cf73df8

fix position_value to scalar

Browse files
Files changed (1) hide show
  1. rl_agent/env.py +3 -0
rl_agent/env.py CHANGED
@@ -1,5 +1,6 @@
1
  import numpy as np
2
  import pandas as pd
 
3
 
4
  class Environment:
5
 
@@ -37,6 +38,8 @@ class Environment:
37
  self.history.pop(0)
38
  self.history.append(self.data.iloc[self.t, :]['Close'] - self.data.iloc[(self.t-1), :]['Close']) # the price being traded
39
 
 
 
40
  return [self.position_value] + self.history, reward, self.done # obs, reward, done
41
 
42
 
 
1
  import numpy as np
2
  import pandas as pd
3
+ import torch
4
 
5
  class Environment:
6
 
 
38
  self.history.pop(0)
39
  self.history.append(self.data.iloc[self.t, :]['Close'] - self.data.iloc[(self.t-1), :]['Close']) # the price being traded
40
 
41
+ self.position_value = self.position_value.item()
42
+
43
  return [self.position_value] + self.history, reward, self.done # obs, reward, done
44
 
45