File size: 911 Bytes
4e5517b d5c00b7 4e5517b d5c00b7 4e5517b d5c00b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
from torch.utils.data import DataLoader
def validation(actor, validation_dataset, batch_size):
    """
    Evaluate the actor model on the validation dataset.

    The actor is run in evaluation mode with gradient tracking disabled.
    Its previous train/eval mode is restored before returning (even if an
    exception is raised), so calling this from inside a training loop does
    not silently leave the model in eval mode.

    Args:
        actor: Trained model to evaluate. Called as ``actor(batch)`` and
            expected to return a mapping containing a ``'total_time'``
            tensor. Must expose ``training``, ``eval()`` and ``train(mode)``
            (as ``torch.nn.Module`` does).
        validation_dataset: Dataset for validation; must provide a
            ``collate`` method, which is used as the DataLoader's
            ``collate_fn``.
        batch_size: Size of mini-batches used in evaluation.

    Returns:
        1-D tensor of total costs for each sample in the validation set, in
        dataset order (the DataLoader is not shuffled). An empty dataset
        yields an empty tensor.
    """
    was_training = actor.training
    actor.eval()  # Set model to evaluation mode for the duration of the call
    val_dataloader = DataLoader(dataset=validation_dataset,
                                batch_size=batch_size,
                                collate_fn=validation_dataset.collate)
    scores = []
    try:
        with torch.no_grad():
            for batch in val_dataloader:
                actor_output = actor(batch)
                # reshape (not view): view raises on non-contiguous tensors.
                scores.append(actor_output['total_time'].reshape(-1))
    finally:
        # Restore the caller's mode. Note: train(mode) applies a single mode
        # recursively to all submodules.
        actor.train(was_training)
    if not scores:
        # torch.cat raises on an empty list; an empty validation set is
        # reported as an empty result instead.
        return torch.empty(0)
    return torch.cat(scores, dim=0)
|