a-ragab-h-m commited on
Commit
d5c00b7
·
verified ·
1 Parent(s): d58ae22

Update train_test_utils/validation.py

Browse files
Files changed (1) hide show
  1. train_test_utils/validation.py +18 -12
train_test_utils/validation.py CHANGED
@@ -1,24 +1,30 @@
1
-
2
  import torch
3
- import torch.nn as nn
4
- from torch.nn.utils import clip_grad_norm_
5
  from torch.utils.data import DataLoader
6
 
7
 
8
  def validation(actor, validation_dataset, batch_size):
9
-
 
 
 
 
 
 
 
 
 
 
 
10
  val_dataloader = DataLoader(dataset=validation_dataset,
11
  batch_size=batch_size,
12
  collate_fn=validation_dataset.collate)
13
 
14
  scores = []
15
- for batch in val_dataloader:
16
- with torch.no_grad():
17
- actor_output = actor(batch)
18
- cost = actor_output['total_time']
19
 
20
- scores.append(cost.reshape(-1))
21
-
22
- scores = torch.cat(scores, dim=0)
 
 
23
 
24
- return scores
 
 
1
  import torch
 
 
2
  from torch.utils.data import DataLoader
3
 
4
 
5
  def validation(actor, validation_dataset, batch_size):
6
+ """
7
+ Evaluate the actor model on the validation dataset.
8
+
9
+ Args:
10
+ actor: Trained model to evaluate
11
+ validation_dataset: Dataset for validation
12
+ batch_size: Size of mini-batches used in evaluation
13
+
14
+ Returns:
15
+ Tensor of total costs for each sample in the validation set
16
+ """
17
+ actor.eval() # Set model to evaluation mode
18
  val_dataloader = DataLoader(dataset=validation_dataset,
19
  batch_size=batch_size,
20
  collate_fn=validation_dataset.collate)
21
 
22
  scores = []
 
 
 
 
23
 
24
+ with torch.no_grad():
25
+ for batch in val_dataloader:
26
+ actor_output = actor(batch)
27
+ cost = actor_output['total_time'].view(-1)
28
+ scores.append(cost)
29
 
30
+ return torch.cat(scores, dim=0)