a-ragab-h-m commited on
Commit
d58ae22
·
verified ·
1 Parent(s): 873cef2

Update train_test_utils/baseline.py

Browse files
Files changed (1) hide show
  1. train_test_utils/baseline.py +28 -19
train_test_utils/baseline.py CHANGED
@@ -1,10 +1,22 @@
1
  import torch
2
- import torch.nn as nn
3
- from torch.nn.utils import clip_grad_norm_
4
  from torch.utils.data import DataLoader
5
 
6
 
7
- def update_baseline(actor, baseline, validation_set, record_scores, batch_size=100, threshold=0.95):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  val_dataloader = DataLoader(dataset=validation_set,
10
  batch_size=batch_size,
@@ -17,25 +29,22 @@ def update_baseline(actor, baseline, validation_set, record_scores, batch_size=1
17
  for batch in val_dataloader:
18
  with torch.no_grad():
19
  actor_output = actor(batch)
20
- actor_cost = actor_output['total_time']
21
- actor_cost.reshape(-1)
22
- actor_scores.append(actor_cost)
23
- actor_scores = torch.cat(actor_scores, dim=0)
24
 
 
 
25
 
26
  if record_scores is None:
27
  baseline.load_state_dict(actor.state_dict())
28
- record_scores = actor_scores
29
- return record_scores
30
- else:
31
 
32
- if actor_scores.mean().item() < record_scores.mean().item():
33
- print('\n', flush=True)
34
- print('baseline updated', flush=True)
35
- print('\n', flush=True)
36
 
37
- baseline.load_state_dict(actor.state_dict())
38
- record_scores = actor_scores
39
- return record_scores
40
- else:
41
- return record_scores
 
 
 
1
  import torch
 
 
2
  from torch.utils.data import DataLoader
3
 
4
 
5
+ def update_baseline(actor, baseline, validation_set, record_scores=None, batch_size=100, threshold=0.95):
6
+ """
7
+ Evaluate the actor on the validation set and update the baseline if performance improves.
8
+
9
+ Parameters:
10
+ - actor: current model being trained
11
+ - baseline: model used as the performance reference
12
+ - validation_set: dataset used for evaluation
13
+ - record_scores: previously recorded baseline scores
14
+ - batch_size: batch size for validation
15
+ - threshold: (optional) threshold for improvement (not used in current implementation)
16
+
17
+ Returns:
18
+ - updated record_scores
19
+ """
20
 
21
  val_dataloader = DataLoader(dataset=validation_set,
22
  batch_size=batch_size,
 
29
  for batch in val_dataloader:
30
  with torch.no_grad():
31
  actor_output = actor(batch)
32
+ actor_cost = actor_output['total_time'].view(-1)
33
+ actor_scores.append(actor_cost)
 
 
34
 
35
+ actor_scores = torch.cat(actor_scores, dim=0)
36
+ actor_score_mean = actor_scores.mean().item()
37
 
38
  if record_scores is None:
39
  baseline.load_state_dict(actor.state_dict())
40
+ return actor_scores
 
 
41
 
42
+ baseline_score_mean = record_scores.mean().item()
 
 
 
43
 
44
+ if actor_score_mean < baseline_score_mean:
45
+ print(f"\nBaseline updated: {baseline_score_mean:.4f} → {actor_score_mean:.4f}\n", flush=True)
46
+ baseline.load_state_dict(actor.state_dict())
47
+ return actor_scores
48
+ else:
49
+ print(f"\nNo improvement: {actor_score_mean:.4f} ≥ {baseline_score_mean:.4f}\n", flush=True)
50
+ return record_scores