import torch
from torch.utils.data import DataLoader


def validation(actor, validation_dataset, batch_size):
    """Evaluate ``actor`` on ``validation_dataset`` and collect per-sample costs.

    The model is put in evaluation mode and gradients are disabled for the
    duration of the call. If the model was in training mode on entry, training
    mode is restored before returning (also when evaluation raises). This
    prevents a call from inside a training loop from leaving the model in
    eval mode.

    Args:
        actor: Model to evaluate. It is called as ``actor(batch)`` and must
            return a mapping containing a ``'total_time'`` tensor.
        validation_dataset: Dataset to iterate. If it exposes a ``collate``
            attribute, that is used as the DataLoader ``collate_fn``;
            otherwise PyTorch's default collation is used.
        batch_size: Number of samples per mini-batch.

    Returns:
        1-D tensor built by flattening each batch's ``'total_time'`` output
        and concatenating the results in dataset order. An empty float tensor
        is returned when the dataset yields no batches.
    """
    # Build the loader before touching the model's mode, so a construction
    # error cannot leave the model stuck in eval mode.
    val_dataloader = DataLoader(
        dataset=validation_dataset,
        batch_size=batch_size,
        collate_fn=getattr(validation_dataset, "collate", None),
    )

    # Remember the incoming mode so it can be restored afterwards; objects
    # without a ``training`` flag are treated as not needing restoration.
    was_training = getattr(actor, "training", False)
    actor.eval()

    scores = []
    try:
        with torch.no_grad():
            for batch in val_dataloader:
                actor_output = actor(batch)
                # reshape (rather than view) also accepts non-contiguous outputs.
                scores.append(actor_output["total_time"].reshape(-1))
    finally:
        if was_training:
            # Module.train() applies recursively; any per-submodule modes
            # were already overwritten by eval() above.
            actor.train()

    if not scores:
        # torch.cat raises on an empty list, so handle "no batches" explicitly.
        return torch.empty(0)
    return torch.cat(scores, dim=0)