# Source: train_test_utils/baseline.py
# Last change: "Update train_test_utils/baseline.py" (commit d58ae22, verified) by a-ragab-h-m
import torch
from torch.utils.data import DataLoader
def update_baseline(actor, baseline, validation_set, record_scores=None, batch_size=100, threshold=0.95):
    """Evaluate the actor on the validation set and refresh the baseline if it improved.

    The actor is run over every batch of ``validation_set`` with gradients
    disabled; the ``'total_time'`` entry of each output is flattened and all
    batches are concatenated into one 1-D score tensor. A lower mean score
    counts as better.

    Parameters:
    - actor: current model being trained. It must be callable on a collated
      batch, return a mapping with a ``'total_time'`` tensor, and provide
      ``greedy_search()``, ``eval()`` and ``state_dict()``.
    - baseline: model used as the performance reference; receives the actor's
      weights through ``load_state_dict`` whenever the actor wins.
    - validation_set: dataset used for evaluation; its ``collate`` attribute
      is used as the DataLoader ``collate_fn``.
    - record_scores: scores previously returned by this function for the
      current baseline, or None on the first call.
    - batch_size: batch size for validation.
    - threshold: accepted for backward compatibility; not used.

    Returns:
    - the actor's freshly computed scores if the baseline was (re)initialised
      or replaced, otherwise the unchanged ``record_scores`` object.

    Raises:
    - ValueError: if the validation set yields no batches.

    Side effects:
    - ``actor.greedy_search()`` and ``actor.eval()`` are called and not undone
      here; the caller has to switch the actor back before training resumes.
    """
    val_dataloader = DataLoader(
        dataset=validation_set,
        batch_size=batch_size,
        collate_fn=validation_set.collate,
    )

    actor.greedy_search()
    actor.eval()

    # One no_grad scope for the whole pass instead of re-entering it per batch.
    # reshape(-1) rather than view(-1): view raises on non-contiguous outputs,
    # reshape falls back to a copy and is identical for contiguous ones.
    with torch.no_grad():
        batch_costs = [actor(batch)['total_time'].reshape(-1) for batch in val_dataloader]

    if not batch_costs:
        # torch.cat on an empty list fails with an unhelpful message.
        raise ValueError("validation_set produced no batches; cannot evaluate the actor")

    actor_scores = torch.cat(batch_costs, dim=0)
    actor_score_mean = actor_scores.mean().item()

    if record_scores is None:
        # First call: there is nothing to compare against, so the actor
        # becomes the baseline unconditionally.
        baseline.load_state_dict(actor.state_dict())
        return actor_scores

    baseline_score_mean = record_scores.mean().item()
    if actor_score_mean < baseline_score_mean:
        print(f"\nBaseline updated: {baseline_score_mean:.4f} → {actor_score_mean:.4f}\n", flush=True)
        baseline.load_state_dict(actor.state_dict())
        return actor_scores

    print(f"\nNo improvement: {actor_score_mean:.4f} ≥ {baseline_score_mean:.4f}\n", flush=True)
    return record_scores