a-ragab-h-m committed on
Commit
873cef2
·
verified ·
1 Parent(s): bf2229d

Update train_test_utils/train.py

Browse files
Files changed (1) hide show
  1. train_test_utils/train.py +8 -16
train_test_utils/train.py CHANGED
@@ -1,14 +1,10 @@
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn.utils import clip_grad_norm_
4
- from torch.utils.data import DataLoader
5
  from just_time_windows.Actor.actor import Actor
6
 
7
 
8
-
9
-
10
  def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, comparison_model=None, compute_cost_ratio=True):
11
-
12
  device = actor.device
13
 
14
  actor.train_mode()
@@ -16,7 +12,6 @@ def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, compa
16
  actor_output = actor(batch)
17
  actor_cost, log_probs = actor_output['total_time'], actor_output['log_probs']
18
 
19
-
20
  with torch.no_grad():
21
  baseline.greedy_search()
22
  baseline_output = baseline(batch)
@@ -29,19 +24,17 @@ def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, compa
29
 
30
  if gradient_clipping:
31
  for group in optimizer.param_groups:
32
- clip_grad_norm_(
33
- group['params'],
34
- 1,
35
- norm_type=2
36
- )
37
 
38
  optimizer.step()
39
 
40
- if compute_cost_ratio and (comparison_model is None):
41
- normalize = actor.apply_normalization
42
- comparison_model = Actor(model=None, num_neighbors_action=1, normalize=normalize, device=device)
43
-
44
  if compute_cost_ratio:
 
 
 
 
45
  with torch.no_grad():
46
  comp_output = comparison_model(batch)
47
  comp_cost = comp_output['total_time']
@@ -49,6 +42,5 @@ def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, compa
49
  a = comp_cost.sum().item()
50
  b = actor_cost.sum().item()
51
  return b / a
52
- else:
53
- return None
54
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn.utils import clip_grad_norm_
 
4
  from just_time_windows.Actor.actor import Actor
5
 
6
 
 
 
7
  def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, comparison_model=None, compute_cost_ratio=True):
 
8
  device = actor.device
9
 
10
  actor.train_mode()
 
12
  actor_output = actor(batch)
13
  actor_cost, log_probs = actor_output['total_time'], actor_output['log_probs']
14
 
 
15
  with torch.no_grad():
16
  baseline.greedy_search()
17
  baseline_output = baseline(batch)
 
24
 
25
  if gradient_clipping:
26
  for group in optimizer.param_groups:
27
+ params = [p for p in group['params'] if p.grad is not None]
28
+ if params:
29
+ clip_grad_norm_(params, max_norm=1, norm_type=2)
 
 
30
 
31
  optimizer.step()
32
 
 
 
 
 
33
  if compute_cost_ratio:
34
+ if comparison_model is None:
35
+ normalize = actor.apply_normalization
36
+ comparison_model = Actor(model=None, num_neighbors_action=1, normalize=normalize, device=device)
37
+
38
  with torch.no_grad():
39
  comp_output = comparison_model(batch)
40
  comp_cost = comp_output['total_time']
 
42
  a = comp_cost.sum().item()
43
  b = actor_cost.sum().item()
44
  return b / a
 
 
45
 
46
+ return None