import torch | |
from torch.utils.data import DataLoader | |
def validation(actor, validation_dataset, batch_size):
    """
    Evaluate the actor model on the validation dataset.

    The actor is put into evaluation mode for the duration of the pass.
    If it was in training mode beforehand, it is switched back afterwards,
    even if an exception is raised. This keeps a mid-training validation
    call from leaving dropout / batch-norm in eval mode for later steps.

    Args:
        actor: Model to evaluate. Called as ``actor(batch)``; the result is
            indexed with ``'total_time'`` to obtain the cost tensor.
        validation_dataset: Dataset for validation. Its ``collate`` method
            is used as the DataLoader's ``collate_fn``.
        batch_size: Size of mini-batches used in evaluation.

    Returns:
        Tensor of total costs for each sample in the validation set. Each
        batch's ``'total_time'`` output is flattened to 1-D, and the batches
        are concatenated in dataset order (the DataLoader does not shuffle).
        An empty dataset yields an empty tensor instead of raising.
    """
    # Remember the caller's mode so it can be restored after evaluation.
    was_training = getattr(actor, "training", False)
    actor.eval()  # Set model to evaluation mode
    val_dataloader = DataLoader(dataset=validation_dataset,
                                batch_size=batch_size,
                                collate_fn=validation_dataset.collate)
    scores = []
    try:
        with torch.no_grad():
            for batch in val_dataloader:
                actor_output = actor(batch)
                # reshape (not view) so non-contiguous outputs are accepted too
                scores.append(actor_output['total_time'].reshape(-1))
    finally:
        if was_training:
            actor.train()
    if not scores:
        # torch.cat raises on an empty list; an empty dataset has no costs.
        return torch.empty(0)
    return torch.cat(scores, dim=0)