a-ragab-h-m's picture
Rename validation.py to train_test_utils/validation.py
1e69fd6 verified
raw
history blame
632 Bytes
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
def validation(actor, validation_dataset, batch_size):
    """Evaluate ``actor`` on every instance of ``validation_dataset``.

    The dataset is iterated once, in order (no shuffling), in mini-batches
    assembled by the dataset's own ``collate`` method. Each batch is passed
    to ``actor`` with autograd disabled. The ``'total_time'`` entry of the
    returned mapping is flattened and collected.

    Args:
        actor: Callable that takes a collated batch and returns a mapping
            containing a ``'total_time'`` tensor.
        validation_dataset: Map-style dataset exposing a ``collate`` method,
            which is used as the DataLoader ``collate_fn``.
        batch_size: Number of instances per batch.

    Returns:
        1-D tensor of the flattened ``'total_time'`` values of all batches,
        concatenated in dataset order. If the dataset yields no batches, an
        empty 1-D tensor (default dtype, CPU) is returned. Without this guard,
        ``torch.cat`` would raise on an empty list.

    NOTE(review): ``actor`` is not switched to ``eval()`` here. If it has
    dropout/batch-norm or a train/eval-dependent decoding mode, the caller
    must set the mode before calling this. Confirm against callers.
    """
    val_dataloader = DataLoader(dataset=validation_dataset,
                                batch_size=batch_size,
                                shuffle=False,  # scores must stay aligned with dataset order
                                collate_fn=validation_dataset.collate)
    scores = []
    # One no_grad scope for the whole pass: evaluation never needs an autograd graph.
    with torch.no_grad():
        for batch in val_dataloader:
            actor_output = actor(batch)
            cost = actor_output['total_time']
            scores.append(cost.reshape(-1))
    if not scores:
        # Empty dataset: torch.cat([]) would raise, so return an empty result instead.
        return torch.empty(0)
    return torch.cat(scores, dim=0)