Update utils/actor_utils.py
utils/actor_utils.py  (+50 -66)  CHANGED
@@ -1,79 +1,64 @@
 import torch


-def widen_tensor(
-    a = [1, factor] + [1 for _ in range(len(L) - 1)]
-    b = [L[0] * factor]
-    datum = datum.reshape(*b)
-    return datum


 def widen_data(actor, include_embeddings=True, include_projections=True):
-            if len(x.shape) > 0:
-                y = widen_tensor(x, factor=actor.sample_size)
-                setattr(actor.fleet, s, y)

-        x = getattr(actor.graph, s)
-        if isinstance(x, torch.Tensor):
-            if len(x.shape) > 0:
-                y = widen_tensor(x, factor=actor.sample_size)
-                setattr(actor.graph, s, y)

-    actor.log_probs = widen_tensor(actor.log_probs,

     if include_embeddings:
-        actor.node_embeddings = widen_tensor(actor.node_embeddings,

     if include_projections:
-        def widen_projection(x
-            if
-        actor.node_projections = {key : widen_projection(actor.node_projections[key], actor.sample_size)
-                                   for key in actor.node_projections}


 def select_data(actor, index, include_embeddings=True, include_projections=True):
-            if (
-                setattr(

-        x = getattr(actor.graph, s)
-        if isinstance(x, torch.Tensor):
-            if (len(x.shape) > 0) and (x.shape[0] >= m):
-                setattr(actor.graph, s, x[index])

     actor.log_probs = actor.log_probs[index]
@@ -81,12 +66,11 @@ def select_data(actor, index, include_embeddings=True, include_projections=True)
     actor.node_embeddings = actor.node_embeddings[index]

     if include_projections:
-        def select_projection(x
-            if
-                return x[:,index
utils/actor_utils.py (updated):

import torch


def widen_tensor(tensor, factor):
    """
    Duplicate a tensor `factor` times along the batch dimension.
    """
    if tensor.dim() == 0:
        return tensor

    shape = tensor.shape
    repeat_dims = [1, factor] + [1] * (len(shape) - 1)

    expanded = tensor.unsqueeze(1).repeat(*repeat_dims)
    new_shape = [shape[0] * factor] + list(shape[1:])
    return expanded.view(*new_shape)
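As a quick illustration (not part of the commit, and assuming the module is importable as utils.actor_utils), widen_tensor repeats each batch row `factor` times in adjacent positions, because the repeat happens on the inserted dimension before the reshape:

import torch
from utils.actor_utils import widen_tensor   # hypothetical import path

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])        # shape (2, 3)
y = widen_tensor(x, factor=2)        # shape (4, 3)
# y == [[1, 2, 3],
#       [1, 2, 3],
#       [4, 5, 6],
#       [4, 5, 6]]   -- rows duplicated adjacently, not tiled as [r0, r1, r0, r1]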
def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Expand the actor's data `sample_size` times to support beam sampling.
    """

    def widen_attributes(obj):
        for name, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                setattr(obj, name, widen_tensor(value, actor.sample_size))

    widen_attributes(actor.fleet)
    widen_attributes(actor.graph)

    actor.log_probs = widen_tensor(actor.log_probs, actor.sample_size)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, actor.sample_size)

    if include_projections:
        def widen_projection(x):
            if x.dim() > 3:
                # (heads, batch, graph, embed)
                y = x.unsqueeze(2).repeat(1, 1, actor.sample_size, 1, 1)
                return y.view(x.size(0), x.size(1) * actor.sample_size, x.size(2), x.size(3))
            return widen_tensor(x, actor.sample_size)

        actor.node_projections = {
            key: widen_projection(value) for key, value in actor.node_projections.items()
        }
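For orientation, a minimal sketch of how widen_data behaves. The actor below is a SimpleNamespace stand-in, and every attribute name and shape (fleet.load, graph.coords, batch size 2, sample_size 3) is an illustrative assumption rather than something taken from the repository:

import torch
from types import SimpleNamespace
from utils.actor_utils import widen_data      # hypothetical import path

actor = SimpleNamespace(
    sample_size=3,
    fleet=SimpleNamespace(load=torch.zeros(2, 4)),        # hypothetical fleet attribute
    graph=SimpleNamespace(coords=torch.rand(2, 10, 2)),   # hypothetical graph attribute
    log_probs=torch.zeros(2),
    node_embeddings=torch.rand(2, 10, 128),
    node_projections={"glimpse_key": torch.rand(8, 2, 10, 16)},
)

widen_data(actor)
print(actor.log_probs.shape)                         # torch.Size([6])
print(actor.node_embeddings.shape)                   # torch.Size([6, 10, 128])
print(actor.node_projections["glimpse_key"].shape)   # torch.Size([8, 6, 10, 16])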
def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Select a specific subset of data using `index` (e.g., for beam search pruning).
    """

    def select_attributes(obj):
        for name, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor) and value.size(0) >= index.max().item():
                setattr(obj, name, value[index])

    select_attributes(actor.fleet)
    select_attributes(actor.graph)

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        def select_projection(x):
            if x.dim() > 3:
                return x[:, index, :, :]
            return x[index]

        actor.node_projections = {
            key: select_projection(value) for key, value in actor.node_projections.items()
        }
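A companion sketch for select_data, continuing the hypothetical actor from the widen_data example above: after widening the batch of 2 to 6 rows, an index of surviving beams prunes every widened attribute back down in lock-step.

from utils.actor_utils import select_data     # hypothetical import path, as above

index = torch.tensor([0, 4])                  # hypothetical indices of the 2 surviving beams
select_data(actor, index)

print(actor.log_probs.shape)                         # torch.Size([2])
print(actor.node_embeddings.shape)                   # torch.Size([2, 10, 128])
print(actor.node_projections["glimpse_key"].shape)   # torch.Size([8, 2, 10, 16])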