a-ragab-h-m's picture
Update utils/actor_utils.py
cb49c08 verified
import torch
def widen_tensor(tensor, factor):
    """
    Repeat every entry along the leading (batch) dimension `factor` times.

    Copies of the same entry end up next to each other, so a batch
    ``[a, b]`` widened by 3 becomes ``[a, a, a, b, b, b]``.

    Args:
        tensor: Tensor to widen. A 0-dim tensor has no batch dimension and
            is handed back unchanged (the very same object).
        factor: Number of copies to make of each batch entry.

    Returns:
        A tensor of shape ``(tensor.shape[0] * factor, *tensor.shape[1:])``,
        or `tensor` itself when it is 0-dimensional.
    """
    # Scalars carry no batch axis, so there is nothing to duplicate.
    if tensor.dim() == 0:
        return tensor

    batch, *trailing = tensor.shape

    # Insert a new axis right after the batch axis and tile only along it;
    # every original axis keeps a repeat count of 1.
    tiled = tensor.unsqueeze(1).repeat(1, factor, *([1] * len(trailing)))

    # Fold the new axis into the batch axis: (batch, factor, ...) -> (batch * factor, ...).
    return tiled.view(batch * factor, *trailing)
def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Replicate the actor's batched tensors `actor.sample_size` times, in place.

    The following are rebound on `actor` (nothing is returned):

    * every tensor attribute with at least one dimension found in the
      instance ``__dict__`` of `actor.fleet` and of `actor.graph`;
    * `actor.log_probs`;
    * `actor.node_embeddings`, when `include_embeddings` is true;
    * every value of the `actor.node_projections` dict, when
      `include_projections` is true.

    Args:
        actor: Object exposing `sample_size`, `fleet`, `graph`, `log_probs`
            and, depending on the flags, `node_embeddings` and
            `node_projections`.
        include_embeddings: Also widen `actor.node_embeddings`.
        include_projections: Also widen the values of `actor.node_projections`.
    """
    factor = actor.sample_size

    # Fleet first, then graph: widen each non-scalar tensor attribute and
    # leave every other attribute exactly as it is.
    for holder in (actor.fleet, actor.graph):
        for attr, current in holder.__dict__.items():
            if not isinstance(current, torch.Tensor):
                continue
            if current.dim() > 0:
                setattr(holder, attr, widen_tensor(current, factor))

    actor.log_probs = widen_tensor(actor.log_probs, factor)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, factor)

    if not include_projections:
        return

    widened = {}
    for key, projection in actor.node_projections.items():
        if projection.dim() > 3:
            # Layout noted by the original author: (heads, batch, graph, embed).
            # The batch axis is dim 1 here, so tile along a new axis placed
            # right after it and then fold that axis into the batch axis.
            tiled = projection.unsqueeze(2).repeat(1, 1, factor, 1, 1)
            widened[key] = tiled.view(
                projection.size(0),
                projection.size(1) * factor,
                projection.size(2),
                projection.size(3),
            )
        else:
            # Lower-rank projections are widened along dim 0 like everything else.
            widened[key] = widen_tensor(projection, factor)
    actor.node_projections = widened
def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Keep only the batch rows named by `index` across the actor's tensors, in place.

    The following are rebound on `actor` (nothing is returned):

    * every tensor attribute in the instance ``__dict__`` of `actor.fleet`
      and of `actor.graph` that has at least one dimension and whose leading
      dimension is large enough for every entry of `index` to be valid;
      tensors that are too short are left untouched;
    * `actor.log_probs`;
    * `actor.node_embeddings`, when `include_embeddings` is true;
    * every value of the `actor.node_projections` dict, when
      `include_projections` is true. Values with more than three dimensions
      are indexed along dim 1, all others along dim 0.

    Args:
        actor: Object exposing `fleet`, `graph`, `log_probs` and, depending on
            the flags, `node_embeddings` and `node_projections`.
        index: Tensor of row indices to keep (e.g. for beam search pruning).
            May be empty, in which case zero rows are kept.
            NOTE(review): assumed to be an integer index tensor -- confirm
            against callers.
        include_embeddings: Also index `actor.node_embeddings`.
        include_projections: Also index the values of `actor.node_projections`.
    """
    # Resolve the largest requested row once. `.item()` copies the value back to
    # the host, so repeating it for every attribute is wasted work. `max()` is
    # undefined on an empty tensor; an empty index asks for no rows at all, so
    # -1 lets every candidate tensor pass the size check below.
    max_index = index.max().item() if index.numel() > 0 else -1

    def select_attributes(obj):
        # Iterate over a snapshot so rebinding attributes cannot disturb the walk.
        for name, value in list(obj.__dict__.items()):
            if not isinstance(value, torch.Tensor) or value.dim() == 0:
                continue
            # Strictly greater: an index equal to size(0) is already out of range,
            # so such a tensor must be skipped rather than indexed.
            if value.size(0) > max_index:
                setattr(obj, name, value[index])

    select_attributes(actor.fleet)
    select_attributes(actor.graph)

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        def select_projection(x):
            if x.dim() > 3:
                # Higher-rank projections keep dim 0 intact and are indexed on dim 1.
                return x[:, index, :, :]
            return x[index]

        actor.node_projections = {
            key: select_projection(value)
            for key, value in actor.node_projections.items()
        }