File size: 632 Bytes
4e5517b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader


def validation(actor, validation_dataset, batch_size):
    """Run ``actor`` over ``validation_dataset`` without gradients and collect costs.

    The dataset is iterated in order (no shuffling) in mini-batches built by
    ``validation_dataset.collate``. Each batch is passed to ``actor``. The
    ``'total_time'`` entry of the actor's output is flattened to 1-D and kept.

    Args:
        actor: Callable taking a collated batch and returning a mapping with a
            ``'total_time'`` tensor. Presumably this holds one cost per batch
            element; confirm against the actor implementation.
        validation_dataset: Dataset usable by ``DataLoader`` that also exposes a
            ``collate`` method used as the ``collate_fn``.
        batch_size: Number of samples per batch.

    Returns:
        A 1-D tensor with all per-batch costs concatenated in dataset order.
        If the dataset produces no batches, an empty 1-D tensor
        (``torch.empty(0)``) is returned instead of raising.
    """
    val_dataloader = DataLoader(dataset=validation_dataset,
                                batch_size=batch_size,
                                collate_fn=validation_dataset.collate)

    scores = []
    # One no_grad context for the whole pass; nothing here needs autograd.
    with torch.no_grad():
        for batch in val_dataloader:
            actor_output = actor(batch)
            cost = actor_output['total_time']
            # Flatten so scalar or (batch, 1) outputs concatenate cleanly.
            scores.append(cost.reshape(-1))

    # torch.cat raises on an empty list, so handle an empty dataset explicitly.
    if not scores:
        return torch.empty(0)

    return torch.cat(scores, dim=0)