Spaces:
Runtime error
Runtime error
PZR0033
committed on
Commit
·
7844376
1
Parent(s):
4d6998d
created agent directory containing policy code
Browse files- rl_agent/policy.py +27 -0
rl_agent/policy.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class Policy(nn.Module):
    """Two-layer fully connected network with tanh activations.

    The network maps a feature vector of size ``input_channels`` to a single
    value: ``Linear(input_channels, 2 * input_channels) -> Tanh ->
    Linear(2 * input_channels, 1) -> Tanh``. Because the last operation is a
    tanh, every output lies in the interval [-1, 1].
    """

    def __init__(self, input_channels: int = 8) -> None:
        """Build the layers.

        Args:
            input_channels: Size of the last dimension of the input tensor.
                The hidden layer is twice this wide.
        """
        super().__init__()

        self.layer1 = nn.Linear(input_channels, 2 * input_channels)
        self.tanh1 = nn.Tanh()
        # ``nn.Linear`` (capital L) is the layer class; ``nn.linear`` does not
        # exist and raised AttributeError as soon as the module was built.
        self.layer2 = nn.Linear(2 * input_channels, 1)
        self.tanh2 = nn.Tanh()

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Run the network on ``state`` and return the squashed output.

        Args:
            state: Tensor whose last dimension equals the ``input_channels``
                given at construction; any leading dimensions are preserved.

        Returns:
            Tensor with the same leading dimensions as ``state`` and a last
            dimension of size 1, with values in [-1, 1].
        """
        hidden = self.layer1(state)
        hidden = self.tanh1(hidden)
        hidden = self.layer2(hidden)
        action = self.tanh2(hidden)

        return action
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|