Delete train.py
Browse files
train.py
DELETED
@@ -1,54 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
from torch.nn.utils import clip_grad_norm_
|
4 |
-
from torch.utils.data import DataLoader
|
5 |
-
from just_time_windows.Actor.actor import Actor
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, comparison_model=None, compute_cost_ratio=True):
|
11 |
-
|
12 |
-
device = actor.device
|
13 |
-
|
14 |
-
actor.train_mode()
|
15 |
-
actor.train()
|
16 |
-
actor_output = actor(batch)
|
17 |
-
actor_cost, log_probs = actor_output['total_time'], actor_output['log_probs']
|
18 |
-
|
19 |
-
|
20 |
-
with torch.no_grad():
|
21 |
-
baseline.greedy_search()
|
22 |
-
baseline_output = baseline(batch)
|
23 |
-
baseline_cost = baseline_output['total_time']
|
24 |
-
|
25 |
-
loss = ((actor_cost - baseline_cost).detach() * log_probs).mean()
|
26 |
-
|
27 |
-
optimizer.zero_grad()
|
28 |
-
loss.backward()
|
29 |
-
|
30 |
-
if gradient_clipping:
|
31 |
-
for group in optimizer.param_groups:
|
32 |
-
clip_grad_norm_(
|
33 |
-
group['params'],
|
34 |
-
1,
|
35 |
-
norm_type=2
|
36 |
-
)
|
37 |
-
|
38 |
-
optimizer.step()
|
39 |
-
|
40 |
-
if compute_cost_ratio and (comparison_model is None):
|
41 |
-
normalize = actor.apply_normalization
|
42 |
-
comparison_model = Actor(model=None, num_neighbors_action=1, normalize=normalize, device=device)
|
43 |
-
|
44 |
-
if compute_cost_ratio:
|
45 |
-
with torch.no_grad():
|
46 |
-
comp_output = comparison_model(batch)
|
47 |
-
comp_cost = comp_output['total_time']
|
48 |
-
|
49 |
-
a = comp_cost.sum().item()
|
50 |
-
b = actor_cost.sum().item()
|
51 |
-
return b / a
|
52 |
-
else:
|
53 |
-
return None
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|