a-ragab-h-m committed on
Commit
7de92e3
·
verified ·
1 Parent(s): 14c20ce

Update utils/actor_utils.py

Browse files
Files changed (1) hide show
  1. utils/actor_utils.py +50 -66
utils/actor_utils.py CHANGED
@@ -1,79 +1,64 @@
1
-
2
  import torch
3
 
4
 
5
- def widen_tensor(datum, factor):
6
-
7
- if len(datum.shape) == 0:
8
- return datum
9
-
10
- L = list(datum.shape)
11
- a = [1, factor] + [1 for _ in range(len(L) - 1)]
12
 
13
- datum = datum.unsqueeze(1)
14
- datum = datum.repeat(*a)
15
 
16
- if len(L) > 1:
17
- b = [L[0] * factor] + L[1:]
18
- else:
19
- b = [L[0] * factor]
20
-
21
- datum = datum.reshape(*b)
22
- return datum
23
 
24
 
25
  def widen_data(actor, include_embeddings=True, include_projections=True):
 
 
 
26
 
27
- F = dir(actor.fleet)
28
- for s in F:
29
- x = getattr(actor.fleet, s)
30
- if isinstance(x, torch.Tensor):
31
- if len(x.shape) > 0:
32
- y = widen_tensor(x, factor=actor.sample_size)
33
- setattr(actor.fleet, s, y)
34
 
35
- G = dir(actor.graph)
36
- for s in G:
37
- x = getattr(actor.graph, s)
38
- if isinstance(x, torch.Tensor):
39
- if len(x.shape) > 0:
40
- y = widen_tensor(x, factor=actor.sample_size)
41
- setattr(actor.graph, s, y)
42
 
43
- actor.log_probs = widen_tensor(actor.log_probs, factor=actor.sample_size)
44
 
45
  if include_embeddings:
46
- actor.node_embeddings = widen_tensor(actor.node_embeddings, factor=actor.sample_size)
47
 
48
  if include_projections:
49
- def widen_projection(x, size):
50
- if len(x.shape) > 3:
51
- y = x.unsqueeze(2).repeat(1, 1, size, 1, 1)
52
- return y.reshape(x.shape[0], x.shape[1] * size, x.shape[2], x.shape[3])
53
- else:
54
- return widen_tensor(x, size)
55
-
56
- actor.node_projections = {key : widen_projection(actor.node_projections[key], actor.sample_size)
57
- for key in actor.node_projections}
58
 
 
 
 
59
 
60
 
61
  def select_data(actor, index, include_embeddings=True, include_projections=True):
62
- m = index.max().item()
63
-
64
- F = dir(actor.fleet)
65
- for s in F:
66
- x = getattr(actor.fleet, s)
67
- if isinstance(x, torch.Tensor):
68
- if (len(x.shape) > 0) and (x.shape[0] >= m):
69
- setattr(actor.fleet, s, x[index])
70
-
71
- G = dir(actor.graph)
72
- for s in G:
73
- x = getattr(actor.graph, s)
74
- if isinstance(x, torch.Tensor):
75
- if (len(x.shape) > 0) and (x.shape[0] >= m):
76
- setattr(actor.graph, s, x[index])
77
 
78
  actor.log_probs = actor.log_probs[index]
79
 
@@ -81,12 +66,11 @@ def select_data(actor, index, include_embeddings=True, include_projections=True)
81
  actor.node_embeddings = actor.node_embeddings[index]
82
 
83
  if include_projections:
84
- def select_projection(x, index):
85
- if len(x.shape) > 3:
86
- return x[:,index,:,:]
87
- else:
88
- return x[index]
89
-
90
- actor.node_projections = {key : select_projection(actor.node_projections[key], index)
91
- for key in actor.node_projections}
92
-
 
 
1
  import torch
2
 
3
 
4
def widen_tensor(tensor, factor):
    """
    Duplicate a tensor `factor` times along the batch dimension.

    Each batch entry is repeated `factor` times consecutively, so a
    tensor of shape (B, ...) becomes (B * factor, ...) laid out as
    [x0, x0, ..., x1, x1, ...]. Zero-dim tensors are returned unchanged.
    """
    # A scalar has no batch dimension to widen.
    if tensor.dim() == 0:
        return tensor

    # repeat_interleave along dim 0 performs the same
    # unsqueeze(1) -> repeat -> reshape sequence in a single call.
    return tensor.repeat_interleave(factor, dim=0)
 
 
 
 
17
 
18
 
19
def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Expand the actor's data `actor.sample_size` times along the batch
    dimension to support beam sampling.

    Every non-scalar tensor attribute of `actor.fleet` and `actor.graph`
    is widened in place, along with `actor.log_probs` and, optionally,
    `actor.node_embeddings` and `actor.node_projections`.

    Args:
        actor: object exposing `sample_size`, `fleet`, `graph`,
            `log_probs`, `node_embeddings`, and `node_projections`.
        include_embeddings: also widen `actor.node_embeddings`.
        include_projections: also widen every tensor in the
            `actor.node_projections` dict.
    """
    size = actor.sample_size

    def widen_attributes(obj):
        # Widen every non-scalar tensor attribute of `obj` in place.
        # Snapshot the items first: setattr can mutate obj.__dict__
        # while we iterate (e.g. via a custom __setattr__), which would
        # raise RuntimeError mid-loop.
        for name, value in list(obj.__dict__.items()):
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                setattr(obj, name, widen_tensor(value, size))

    widen_attributes(actor.fleet)
    widen_attributes(actor.graph)

    actor.log_probs = widen_tensor(actor.log_probs, size)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, size)

    if include_projections:
        def widen_projection(x):
            # 4-D projections: assumed (heads, batch, graph, embed) with
            # the batch axis at dim 1 — TODO(review): confirm layout.
            if x.dim() > 3:
                y = x.unsqueeze(2).repeat(1, 1, size, 1, 1)
                return y.view(x.size(0), x.size(1) * size, x.size(2), x.size(3))
            # Lower-rank projections widen along dim 0 like any batch tensor.
            return widen_tensor(x, size)

        actor.node_projections = {
            key: widen_projection(value)
            for key, value in actor.node_projections.items()
        }
48
 
49
 
50
def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Keep only the rows selected by `index` (e.g. beam-search pruning).

    Tensor attributes of `actor.fleet` and `actor.graph` whose leading
    dimension can accommodate the largest index are sliced in place,
    along with `actor.log_probs` and, optionally, `actor.node_embeddings`
    and `actor.node_projections`.

    Args:
        actor: object exposing `fleet`, `graph`, `log_probs`,
            `node_embeddings`, and `node_projections`.
        index: 1-D index tensor selecting rows along the batch dimension.
        include_embeddings: also slice `actor.node_embeddings`.
        include_projections: also slice every tensor in the
            `actor.node_projections` dict.
    """
    # Hoist the scalar extraction; recomputing it per attribute is wasted work.
    max_index = index.max().item()

    def select_attributes(obj):
        # Slice every eligible tensor attribute of `obj` in place.
        # Snapshot items: setattr can mutate obj.__dict__ mid-iteration.
        for name, value in list(obj.__dict__.items()):
            # dim() > 0 guard restored: size(0) raises IndexError on a
            # 0-dim tensor (the pre-refactor code checked len(x.shape) > 0).
            # NOTE(review): the historical threshold is `>= max_index`,
            # not `> max_index`; preserved for compatibility.
            if (
                isinstance(value, torch.Tensor)
                and value.dim() > 0
                and value.size(0) >= max_index
            ):
                setattr(obj, name, value[index])

    select_attributes(actor.fleet)
    select_attributes(actor.graph)

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        def select_projection(x):
            # 4-D projections: batch axis assumed at dim 1 — TODO(review): confirm.
            if x.dim() > 3:
                return x[:, index, :, :]
            return x[index]

        actor.node_projections = {
            key: select_projection(value)
            for key, value in actor.node_projections.items()
        }