Upload 3 files
Browse files- baseline.py +41 -0
- train.py +54 -0
- validation.py +24 -0
baseline.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.utils import clip_grad_norm_
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
|
7 |
+
def update_baseline(actor, baseline, validation_set, record_scores, batch_size=100, threshold=0.95):
    """Evaluate *actor* greedily on the validation set and promote it to baseline
    whenever it beats the best (lowest) mean cost seen so far.

    Args:
        actor: policy model; must expose ``greedy_search()``, ``eval()``,
            ``state_dict()`` and be callable on a collated batch, returning a
            dict with a ``'total_time'`` cost tensor.
        baseline: model whose weights are overwritten via ``load_state_dict``
            when the actor improves on the record.
        validation_set: dataset exposing a ``collate`` function for DataLoader.
        record_scores: per-instance costs of the current best actor, or None on
            the first call (the actor is then promoted unconditionally).
        batch_size: validation batch size.
        threshold: currently unused — kept for interface compatibility.
            NOTE(review): presumably intended as an acceptance margin; confirm.

    Returns:
        1-D tensor of per-instance costs of whichever model is now the baseline.
    """
    val_dataloader = DataLoader(dataset=validation_set,
                                batch_size=batch_size,
                                collate_fn=validation_set.collate)

    # Deterministic decoding so validation scores are comparable across calls.
    actor.greedy_search()
    actor.eval()

    actor_scores = []
    for batch in val_dataloader:
        with torch.no_grad():
            actor_cost = actor(batch)['total_time']
        # BUG FIX: reshape is not in-place — the original called
        # ``actor_cost.reshape(-1)`` and discarded the result, leaving
        # per-batch costs possibly 2-D and the concatenation mis-shaped
        # (validation.py does this correctly).
        actor_scores.append(actor_cost.reshape(-1))
    actor_scores = torch.cat(actor_scores, dim=0)

    # First call, or strict improvement on the mean cost: promote the actor.
    if record_scores is None or actor_scores.mean().item() < record_scores.mean().item():
        if record_scores is not None:
            # Original code only announces updates after the first call.
            print('\n', flush=True)
            print('baseline updated', flush=True)
            print('\n', flush=True)
        baseline.load_state_dict(actor.state_dict())
        return actor_scores
    return record_scores
|
train.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn.utils import clip_grad_norm_
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from just_time_windows.Actor.actor import Actor
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, comparison_model=None, compute_cost_ratio=True, max_grad_norm=1):
    """Run one REINFORCE-with-baseline update step on *batch*.

    The loss is the policy-gradient estimator
    ``mean((actor_cost - baseline_cost).detach() * log_probs)``: the greedy
    baseline rollout serves as the advantage baseline and carries no gradient.

    Args:
        actor: trainable policy; exposes ``device``, ``train_mode()``,
            ``train()``, ``apply_normalization`` and is callable on a batch,
            returning a dict with ``'total_time'`` and ``'log_probs'`` tensors.
        baseline: frozen greedy-rollout model used for the advantage.
        batch: one collated training batch.
        optimizer: torch optimizer over the actor's parameters.
        gradient_clipping: clip each param group's gradient norm before stepping.
        comparison_model: optional reference model for the cost ratio. When
            None and *compute_cost_ratio* is True, a default greedy Actor is
            built — NOTE: rebuilt on every call; pass one in to avoid the cost.
        compute_cost_ratio: whether to return actor cost / comparison cost.
        max_grad_norm: L2 clipping threshold used when *gradient_clipping*
            (generalizes the previously hard-coded value of 1).

    Returns:
        float ratio of the actor's total cost to the comparison model's total
        cost (ratio < 1 means the actor beats the comparison policy), or None
        when *compute_cost_ratio* is False.
    """
    device = actor.device

    actor.train_mode()
    actor.train()
    actor_output = actor(batch)
    actor_cost, log_probs = actor_output['total_time'], actor_output['log_probs']

    # Greedy baseline rollout — no gradients flow through the advantage.
    with torch.no_grad():
        baseline.greedy_search()
        baseline_output = baseline(batch)
        baseline_cost = baseline_output['total_time']

    loss = ((actor_cost - baseline_cost).detach() * log_probs).mean()

    optimizer.zero_grad()
    loss.backward()

    if gradient_clipping:
        for group in optimizer.param_groups:
            clip_grad_norm_(
                group['params'],
                max_grad_norm,
                norm_type=2
            )

    optimizer.step()

    if compute_cost_ratio and (comparison_model is None):
        normalize = actor.apply_normalization
        comparison_model = Actor(model=None, num_neighbors_action=1, normalize=normalize, device=device)

    if compute_cost_ratio:
        with torch.no_grad():
            comp_output = comparison_model(batch)
            comp_cost = comp_output['total_time']
        # actor total / comparison total (original computed b / a).
        return actor_cost.sum().item() / comp_cost.sum().item()
    else:
        return None
|
54 |
+
|
validation.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn.utils import clip_grad_norm_
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
|
7 |
+
|
8 |
+
def validation(actor, validation_dataset, batch_size):
    """Score *actor* on the whole validation set.

    Iterates the dataset in batches (collated by the dataset's own ``collate``
    function), reads the ``'total_time'`` cost from each actor output, and
    returns all per-instance costs flattened into a single 1-D tensor.
    """
    loader = DataLoader(dataset=validation_dataset,
                        batch_size=batch_size,
                        collate_fn=validation_dataset.collate)

    # Pure evaluation — no gradients needed anywhere in the sweep.
    with torch.no_grad():
        batch_costs = [actor(b)['total_time'].reshape(-1) for b in loader]

    return torch.cat(batch_costs, dim=0)
|