Update run.py
Browse files
run.py
CHANGED
@@ -92,7 +92,11 @@ actor = Actor(model=model, num_movers=num_movers,
|
|
92 |
device=device, normalize=False)
|
93 |
actor.train_mode()
|
94 |
|
95 |
-
baseline_model = Model(
|
|
|
|
|
|
|
|
|
96 |
baseline_actor = Actor(model=baseline_model, num_movers=num_movers,
|
97 |
num_neighbors_encoder=num_neighbors_encoder,
|
98 |
num_neighbors_action=num_neighbors_action,
|
|
|
92 |
device=device, normalize=False)
|
93 |
actor.train_mode()
|
94 |
|
95 |
+
baseline_model = Model(
|
96 |
+
input_size=input_size,
|
97 |
+
embedding_size=embedding_size,
|
98 |
+
decoder_input_size=params["decoder_input_size"]
|
99 |
+
)
|
100 |
baseline_actor = Actor(model=baseline_model, num_movers=num_movers,
|
101 |
num_neighbors_encoder=num_neighbors_encoder,
|
102 |
num_neighbors_action=num_neighbors_action,
|