a-ragab-h-m's picture
Update train_test_utils/validation.py
d5c00b7 verified
raw
history blame
911 Bytes
import torch
from torch.utils.data import DataLoader
def validation(actor: torch.nn.Module, validation_dataset, batch_size: int) -> torch.Tensor:
    """
    Evaluate the actor model on the validation dataset.

    The actor is run in evaluation mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored before returning, so calling this from inside a
    training loop does not leave the model stuck in eval mode.

    Args:
        actor: Trained model to evaluate. It is called as ``actor(batch)`` and
            must return a mapping with a ``'total_time'`` tensor.
        validation_dataset: Dataset for validation. It must expose a ``collate``
            method, which is used as the DataLoader's ``collate_fn``.
        batch_size: Size of mini-batches used in evaluation.

    Returns:
        1-D tensor of total costs for the validation set, concatenated in
        DataLoader order (no shuffling). If the dataset yields no batches, an
        empty tensor is returned instead of raising.
    """
    # Remember the caller's mode so it can be restored afterwards;
    # eval() alone would permanently change dropout/batch-norm behaviour.
    was_training = actor.training
    actor.eval()
    try:
        val_dataloader = DataLoader(dataset=validation_dataset,
                                    batch_size=batch_size,
                                    collate_fn=validation_dataset.collate)
        scores = []
        with torch.no_grad():
            for batch in val_dataloader:
                actor_output = actor(batch)
                # reshape (not view) so non-contiguous outputs also flatten.
                cost = actor_output['total_time'].reshape(-1)
                scores.append(cost)
    finally:
        actor.train(was_training)

    if not scores:
        # torch.cat raises on an empty list; return an empty result instead.
        return torch.empty(0)
    return torch.cat(scores, dim=0)