a-ragab-h-m commited on
Commit
fffc83c
·
verified ·
1 Parent(s): 79eead4

Delete baseline.py

Browse files
Files changed (1) hide show
  1. baseline.py +0 -41
baseline.py DELETED
@@ -1,41 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.nn.utils import clip_grad_norm_
4
- from torch.utils.data import DataLoader
5
-
6
-
7
- def update_baseline(actor, baseline, validation_set, record_scores, batch_size=100, threshold=0.95):
8
-
9
- val_dataloader = DataLoader(dataset=validation_set,
10
- batch_size=batch_size,
11
- collate_fn=validation_set.collate)
12
-
13
- actor.greedy_search()
14
- actor.eval()
15
-
16
- actor_scores = []
17
- for batch in val_dataloader:
18
- with torch.no_grad():
19
- actor_output = actor(batch)
20
- actor_cost = actor_output['total_time']
21
- actor_cost.reshape(-1)
22
- actor_scores.append(actor_cost)
23
- actor_scores = torch.cat(actor_scores, dim=0)
24
-
25
-
26
- if record_scores is None:
27
- baseline.load_state_dict(actor.state_dict())
28
- record_scores = actor_scores
29
- return record_scores
30
- else:
31
-
32
- if actor_scores.mean().item() < record_scores.mean().item():
33
- print('\n', flush=True)
34
- print('baseline updated', flush=True)
35
- print('\n', flush=True)
36
-
37
- baseline.load_state_dict(actor.state_dict())
38
- record_scores = actor_scores
39
- return record_scores
40
- else:
41
- return record_scores