File size: 911 Bytes
4e5517b
 
 
 
 
d5c00b7
 
 
 
 
 
 
 
 
 
 
 
4e5517b
 
 
 
 
 
d5c00b7
 
 
 
 
4e5517b
d5c00b7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from torch.utils.data import DataLoader


def validation(actor, validation_dataset, batch_size):
    """
    Evaluate the actor model on the validation dataset.

    The actor is switched to evaluation mode via ``actor.eval()`` and is
    left in that mode on return; a caller that resumes training must call
    ``actor.train()`` itself.

    Args:
        actor: Trained model to evaluate. Called as ``actor(batch)`` and
            expected to return a mapping containing a ``'total_time'``
            tensor.
        validation_dataset: Dataset for validation. If it defines a
            ``collate`` method, that method builds the mini-batches;
            otherwise PyTorch's default collation is used.
        batch_size: Size of mini-batches used in evaluation

    Returns:
        1-D tensor of total costs, one entry per element of each batch's
        flattened ``'total_time'`` output, in dataset order (no shuffling).
        An empty dataset yields an empty tensor.
    """
    actor.eval()  # Set model to evaluation mode (affects e.g. Dropout/BatchNorm)
    val_dataloader = DataLoader(
        dataset=validation_dataset,
        batch_size=batch_size,
        # Use the dataset's own collate if it has one; None selects the
        # DataLoader's default collate function.
        collate_fn=getattr(validation_dataset, 'collate', None),
    )

    scores = []

    with torch.no_grad():
        for batch in val_dataloader:
            actor_output = actor(batch)
            # reshape (not view): view raises on non-contiguous tensors,
            # reshape flattens them to the same values.
            cost = actor_output['total_time'].reshape(-1)
            scores.append(cost)

    if not scores:
        # torch.cat rejects an empty list; report "no samples" explicitly.
        return torch.empty(0)
    return torch.cat(scores, dim=0)