PZR0033 commited on
Commit
7844376
·
1 Parent(s): 4d6998d

created agent directory containing policy code

Browse files
Files changed (1) hide show
  1. rl_agent/policy.py +27 -0
rl_agent/policy.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Policy(nn.Module):
7
+ def __init__(self, input_channels=8):
8
+
9
+ super(Policy, self).__init__()
10
+
11
+ self.layer1 = nn.Linear(input_channels, 2 * input_channels)
12
+ self.tanh1 = nn.Tanh()
13
+ self.layer2 = nn.linear(2 * input_channels, 1)
14
+ self.tanh2 = nn.Tanh()
15
+
16
+ def forward(self, state):
17
+
18
+ hidden = self.layer1(state)
19
+ hidden = self.tanh1(hidden)
20
+ hidden = self.layer2(hidden)
21
+ action = self.tanh2(hidden)
22
+
23
+ return action
24
+
25
+
26
+
27
+