a-ragab-h-m committed on
Commit
4e5517b
·
verified ·
1 Parent(s): e57fbcf

Upload 3 files

Browse files
Files changed (3) hide show
  1. baseline.py +41 -0
  2. train.py +54 -0
  3. validation.py +24 -0
baseline.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils import clip_grad_norm_
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
def update_baseline(actor, baseline, validation_set, record_scores, batch_size=100, threshold=0.95):
    """Score ``actor`` on the validation set and promote it to baseline if better.

    The actor is switched to greedy search and eval mode (and is left in that
    state on return), then run over ``validation_set`` without gradient
    tracking. The ``'total_time'`` entry of each output is flattened and all
    batches are concatenated into a single 1-D tensor of scores.

    The baseline's weights are overwritten with the actor's when either
    ``record_scores`` is ``None`` (first call) or the mean of the new scores is
    strictly lower than the mean of ``record_scores``.

    Args:
        actor: Callable model exposing ``greedy_search()``, ``eval()`` and
            ``state_dict()``; calling it on a batch must return a mapping
            with a ``'total_time'`` tensor.
        baseline: Model exposing ``load_state_dict()``.
        validation_set: Dataset accepted by ``DataLoader`` that also provides
            a ``collate`` callable used as ``collate_fn``.
        record_scores: Scores returned by the previous call, or ``None``.
        batch_size: Batch size used for the validation pass.
        threshold: Accepted for interface compatibility; not used here.

    Returns:
        The new 1-D score tensor if the baseline was updated, otherwise the
        unchanged ``record_scores``.
    """
    val_dataloader = DataLoader(
        dataset=validation_set,
        batch_size=batch_size,
        collate_fn=validation_set.collate,
    )

    actor.greedy_search()
    actor.eval()

    actor_scores = []
    for batch in val_dataloader:
        with torch.no_grad():
            actor_cost = actor(batch)['total_time']
        # Tensor.reshape is not in-place: its result must be kept, otherwise
        # the scores are never flattened before concatenation.
        actor_scores.append(actor_cost.reshape(-1))
    actor_scores = torch.cat(actor_scores, dim=0)

    first_call = record_scores is None
    if not first_call and not (actor_scores.mean().item() < record_scores.mean().item()):
        # No improvement: keep the current baseline and its recorded scores.
        return record_scores

    if not first_call:
        # Only a real improvement is announced; the initial copy is silent.
        print('\n', flush=True)
        print('baseline updated', flush=True)
        print('\n', flush=True)

    baseline.load_state_dict(actor.state_dict())
    return actor_scores
train.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils import clip_grad_norm_
4
+ from torch.utils.data import DataLoader
5
+ from just_time_windows.Actor.actor import Actor
6
+
7
+
8
+
9
+
10
def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, comparison_model=None, compute_cost_ratio=True):
    """Run one optimisation step of ``actor`` against a greedy ``baseline``.

    The actor is put in training mode and run on ``batch``. The baseline is
    put in greedy-search and eval mode and run on the same batch without
    gradient tracking. The loss is the mean over the batch of
    ``(actor_cost - baseline_cost).detach() * log_probs``.

    Args:
        actor: Model being trained. Must expose ``device``, ``train_mode()``
            and ``train()``, and return a mapping with ``'total_time'`` and
            ``'log_probs'`` when called on ``batch``. ``apply_normalization``
            is read only when a default comparison model has to be built.
        baseline: Model exposing ``greedy_search()`` and ``eval()`` that
            returns a mapping with ``'total_time'`` when called on ``batch``.
        batch: Input passed unchanged to every model.
        optimizer: Optimizer holding the actor's parameters.
        gradient_clipping: If true, clip the gradient 2-norm of each
            parameter group to 1 before stepping.
        comparison_model: Optional reference model for the cost ratio. When
            ``None`` and a ratio is requested, an ``Actor`` is constructed
            with ``model=None`` and ``num_neighbors_action=1``.
        compute_cost_ratio: Whether to compute and return the cost ratio.

    Returns:
        ``sum(actor_cost) / sum(comparison_cost)`` as a Python float when
        ``compute_cost_ratio`` is true, otherwise ``None``. The actor cost
        used is the one measured before the optimizer step.
    """
    device = actor.device

    actor.train_mode()
    actor.train()
    actor_output = actor(batch)
    actor_cost, log_probs = actor_output['total_time'], actor_output['log_probs']

    with torch.no_grad():
        baseline.greedy_search()
        # The baseline is only ever evaluated, never trained. Put it in eval
        # mode explicitly, mirroring how the actor's mode is always set
        # before use; modules default to training mode and load_state_dict
        # does not change that.
        baseline.eval()
        baseline_cost = baseline(batch)['total_time']

    # The advantage is a constant weight; gradients flow only via log_probs.
    advantage = (actor_cost - baseline_cost).detach()
    loss = (advantage * log_probs).mean()

    optimizer.zero_grad()
    loss.backward()

    if gradient_clipping:
        for group in optimizer.param_groups:
            clip_grad_norm_(group['params'], 1, norm_type=2)

    optimizer.step()

    if not compute_cost_ratio:
        return None

    if comparison_model is None:
        comparison_model = Actor(
            model=None,
            num_neighbors_action=1,
            normalize=actor.apply_normalization,
            device=device,
        )

    with torch.no_grad():
        comp_cost = comparison_model(batch)['total_time']

    return actor_cost.sum().item() / comp_cost.sum().item()
54
+
validation.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn.utils import clip_grad_norm_
5
+ from torch.utils.data import DataLoader
6
+
7
+
8
def validation(actor, validation_dataset, batch_size):
    """Return the actor's flattened ``'total_time'`` costs over a dataset.

    The dataset is iterated in order through a ``DataLoader`` that uses the
    dataset's own ``collate`` as ``collate_fn``. Each batch is passed to
    ``actor`` with gradient tracking disabled. The actor's mode is not changed
    here; the caller decides how it is configured.

    Args:
        actor: Callable returning a mapping with a ``'total_time'`` tensor.
        validation_dataset: Dataset accepted by ``DataLoader`` that provides
            a ``collate`` callable.
        batch_size: Number of instances per batch.

    Returns:
        A 1-D tensor concatenating the flattened costs of every batch.
    """
    loader = DataLoader(
        dataset=validation_dataset,
        batch_size=batch_size,
        collate_fn=validation_dataset.collate,
    )

    flat_costs = []
    for instances in loader:
        with torch.no_grad():
            result = actor(instances)
        flat_costs.append(result['total_time'].reshape(-1))

    return torch.cat(flat_costs, dim=0)