Spaces:
Runtime error
Runtime error
File size: 632 Bytes
4e5517b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
def validation(actor, validation_dataset, batch_size):
    """Run ``actor`` over ``validation_dataset`` and collect per-sample costs.

    The dataset is iterated in order (no shuffling) in batches of
    ``batch_size``, using the dataset's own ``collate`` method to build each
    batch. For every batch the actor is called without gradient tracking, and
    the ``'total_time'`` entry of its output is flattened and concatenated.

    NOTE(review): the actor's train/eval mode is not changed here; callers
    presumably call ``actor.eval()`` beforehand if dropout/batch-norm or
    mode-dependent decoding matters -- confirm at the call site.

    Args:
        actor: Callable taking a collated batch and returning a mapping that
            contains a ``'total_time'`` tensor.
        validation_dataset: Dataset exposing a ``collate`` method usable as a
            ``DataLoader`` ``collate_fn``.
        batch_size: Number of samples per batch.

    Returns:
        A 1-D tensor holding every ``'total_time'`` value, flattened per batch
        and concatenated in dataset order. An empty dataset yields an empty
        1-D tensor instead of raising from ``torch.cat``.
    """
    val_dataloader = DataLoader(dataset=validation_dataset,
                                batch_size=batch_size,
                                collate_fn=validation_dataset.collate)
    scores = []
    # One gradient-free context for the whole evaluation pass.
    with torch.no_grad():
        for batch in val_dataloader:
            actor_output = actor(batch)
            cost = actor_output['total_time']
            scores.append(cost.reshape(-1))
    if not scores:
        # torch.cat raises on an empty list; report "no scores" explicitly.
        return torch.empty(0)
    return torch.cat(scores, dim=0)
|