import torch
from torch.utils.data import DataLoader


def update_baseline(actor, baseline, validation_set, record_scores=None,
                    batch_size=100, threshold=0.95):
    """
    Evaluate the actor on the validation set and update the baseline if
    performance improves.

    Parameters:
    - actor: current model being trained
    - baseline: model used as the performance reference
    - validation_set: dataset used for evaluation
    - record_scores: previously recorded baseline scores
    - batch_size: batch size for validation
    - threshold: (optional) threshold for improvement (not used in current
      implementation)

    Returns:
    - updated record_scores: the actor's validation scores if the baseline
      was replaced, otherwise the previously recorded scores
    """
    val_dataloader = DataLoader(dataset=validation_set,
                                batch_size=batch_size,
                                collate_fn=validation_set.collate)

    # Switch the actor to deterministic (greedy) decoding and evaluation mode.
    actor.greedy_search()
    actor.eval()

    # Collect the actor's cost on every validation batch.
    actor_scores = []
    for batch in val_dataloader:
        with torch.no_grad():
            actor_output = actor(batch)
            actor_cost = actor_output['total_time'].view(-1)
            actor_scores.append(actor_cost)

    actor_scores = torch.cat(actor_scores, dim=0)
    actor_score_mean = actor_scores.mean().item()

    # First call: no recorded scores yet, so the actor becomes the baseline.
    if record_scores is None:
        baseline.load_state_dict(actor.state_dict())
        return actor_scores

    baseline_score_mean = record_scores.mean().item()

    # Lower mean cost (total_time) means the actor outperforms the baseline.
    if actor_score_mean < baseline_score_mean:
        print(f"\nBaseline updated: {baseline_score_mean:.4f} → {actor_score_mean:.4f}\n",
              flush=True)
        baseline.load_state_dict(actor.state_dict())
        return actor_scores
    else:
        print(f"\nNo improvement: {actor_score_mean:.4f} ≥ {baseline_score_mean:.4f}\n",
              flush=True)
        return record_scores
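

# --- Usage sketch (illustrative only, not part of the original module) ------
# A minimal example of how update_baseline might be wired into a training
# loop. The Actor class, train_one_epoch helper, validation_set object, and
# num_epochs below are hypothetical placeholders; substitute the actual
# model, training step, and dataset used in this project.
#
#   actor = Actor(...)                      # hypothetical actor model
#   baseline = Actor(...)                   # same architecture as the actor
#   baseline.load_state_dict(actor.state_dict())
#   record_scores = None
#
#   for epoch in range(num_epochs):
#       train_one_epoch(actor, baseline)    # hypothetical training step
#       record_scores = update_baseline(
#           actor, baseline, validation_set,
#           record_scores=record_scores, batch_size=100,
#       )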