import torch

from torch.utils.data import DataLoader


def update_baseline(actor, baseline, validation_set, record_scores=None, batch_size=100, threshold=0.95):
    """
    Evaluate the actor on the validation set and update the baseline if performance improves.

    Parameters:
    - actor: current model being trained
    - baseline: model used as the performance reference
    - validation_set: dataset used for evaluation
    - record_scores: previously recorded baseline scores (None on the first call)
    - batch_size: batch size for validation
    - threshold: (optional) threshold for improvement (not used in the current implementation)

    Returns:
    - the scores to record for the baseline: the actor's validation scores if the baseline
      was initialised or updated, otherwise the unchanged record_scores
    """
    val_dataloader = DataLoader(dataset=validation_set,
                                batch_size=batch_size,
                                collate_fn=validation_set.collate)

    # Switch the actor to greedy decoding and evaluation mode for a deterministic rollout.
    actor.greedy_search()
    actor.eval()

    # Collect the per-instance cost (total time) over the whole validation set.
    actor_scores = []
    for batch in val_dataloader:
        with torch.no_grad():
            actor_output = actor(batch)
            actor_cost = actor_output['total_time'].view(-1)
            actor_scores.append(actor_cost)

    actor_scores = torch.cat(actor_scores, dim=0)
    actor_score_mean = actor_scores.mean().item()

    # First call: no recorded scores yet, so initialise the baseline from the actor.
    if record_scores is None:
        baseline.load_state_dict(actor.state_dict())
        return actor_scores

    baseline_score_mean = record_scores.mean().item()

    # Lower mean cost means better performance: copy the actor's weights into the baseline.
    if actor_score_mean < baseline_score_mean:
        print(f"\nBaseline updated: {baseline_score_mean:.4f} → {actor_score_mean:.4f}\n", flush=True)
        baseline.load_state_dict(actor.state_dict())
        return actor_scores
    else:
        print(f"\nNo improvement: {actor_score_mean:.4f} ≥ {baseline_score_mean:.4f}\n", flush=True)
        return record_scores
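

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal, self-contained example of how
# update_baseline can be called. DummyActor and DummyRouteDataset are
# hypothetical stand-ins, not real project classes; they only satisfy the
# interface the function assumes: greedy_search(), eval(),
# state_dict()/load_state_dict(), a forward pass returning a dict with
# 'total_time', and a dataset exposing a collate method.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import Dataset

    class DummyActor(nn.Module):
        """Hypothetical actor: scores a batch of instances with one linear layer."""
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 1)

        def greedy_search(self):
            # Stand-in for switching the decoder to greedy decoding.
            pass

        def forward(self, batch):
            # Expose the per-instance cost under the key update_baseline expects.
            return {'total_time': self.linear(batch).squeeze(-1)}

    class DummyRouteDataset(Dataset):
        """Hypothetical dataset of random 4-feature instances with a collate method."""
        def __init__(self, n=256):
            self.data = torch.randn(n, 4)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

        def collate(self, items):
            return torch.stack(items, dim=0)

    actor = DummyActor()
    baseline = DummyActor()
    baseline.load_state_dict(actor.state_dict())
    validation_set = DummyRouteDataset()

    # First call initialises the baseline and returns the actor's scores;
    # subsequent calls update the baseline only when the mean cost drops.
    record_scores = update_baseline(actor, baseline, validation_set, batch_size=32)
    record_scores = update_baseline(actor, baseline, validation_set,
                                    record_scores=record_scores, batch_size=32)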