File size: 1,775 Bytes
a2a33f3 d58ae22 a2a33f3 d58ae22 a2a33f3 d58ae22 a2a33f3 d58ae22 a2a33f3 d58ae22 a2a33f3 d58ae22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from torch.utils.data import DataLoader
def update_baseline(actor, baseline, validation_set, record_scores=None, batch_size=100, threshold=0.95):
    """
    Evaluate the actor on the validation set and update the baseline if performance improves.

    The actor is switched to greedy search and eval mode, run over the whole
    validation set without gradient tracking, and its per-instance
    ``'total_time'`` outputs are collected. The mean of those costs is compared
    with the mean of ``record_scores``; a strictly lower mean counts as an
    improvement, in which case the actor's weights are copied into ``baseline``.

    Parameters:
    - actor: current model being trained; must provide ``greedy_search()`` and,
      when called on a batch, return a mapping with a ``'total_time'`` tensor
    - baseline: model used as the performance reference; receives the actor's
      ``state_dict`` on the first call and whenever the actor improves
    - validation_set: dataset used for evaluation; must expose a ``collate``
      callable that is used as the DataLoader ``collate_fn``
    - record_scores: previously recorded baseline scores (a tensor), or None
      on the first call
    - batch_size: batch size for validation
    - threshold: accepted for interface compatibility; not used by the
      current implementation

    Returns:
    - the actor's flattened validation scores when the baseline was
      (re)initialised from the actor, otherwise ``record_scores`` unchanged

    Side effects:
    - ``actor.greedy_search()`` and ``actor.eval()`` are called and are not
      reverted here; the caller is responsible for switching the actor back
      to its training configuration.
    - A status line is printed when ``record_scores`` is given.
    """
    val_dataloader = DataLoader(
        dataset=validation_set,
        batch_size=batch_size,
        collate_fn=validation_set.collate,
    )

    actor.greedy_search()
    actor.eval()

    per_batch_costs = []
    # Pure evaluation: no autograd graph is needed for any batch.
    with torch.no_grad():
        for batch in val_dataloader:
            actor_output = actor(batch)
            # reshape (rather than view) also flattens non-contiguous cost
            # tensors, e.g. a transposed or sliced model output; view(-1)
            # raises on those.
            per_batch_costs.append(actor_output['total_time'].reshape(-1))

    # NOTE: torch.cat raises if the validation set produced no batches.
    actor_scores = torch.cat(per_batch_costs, dim=0)
    actor_score_mean = actor_scores.mean().item()

    # First evaluation: there is nothing to compare against, so the actor
    # becomes the baseline unconditionally.
    if record_scores is None:
        baseline.load_state_dict(actor.state_dict())
        return actor_scores

    baseline_score_mean = record_scores.mean().item()

    # Lower mean cost is better; ties keep the existing baseline.
    if actor_score_mean < baseline_score_mean:
        print(f"\nBaseline updated: {baseline_score_mean:.4f} → {actor_score_mean:.4f}\n", flush=True)
        baseline.load_state_dict(actor.state_dict())
        return actor_scores

    print(f"\nNo improvement: {actor_score_mean:.4f} ≥ {baseline_score_mean:.4f}\n", flush=True)
    return record_scores
|