|
import torch |
|
|
|
|
|
def widen_tensor(tensor, factor):
    """Repeat each batch element `factor` times along dim 0.

    Produces the "consecutive" widening used by beam search:
    ``[a, b] -> [a, a, b, b]`` for ``factor=2``, i.e. a tensor of shape
    ``(B, ...)`` becomes ``(B * factor, ...)`` with every row duplicated
    in place rather than the whole batch tiled.

    Args:
        tensor: Tensor to widen. A 0-dim (scalar) tensor is returned
            unchanged, since it has no batch dimension to repeat.
        factor: Number of copies of each batch element (int).

    Returns:
        A new tensor of shape ``(tensor.shape[0] * factor, *tensor.shape[1:])``,
        or `tensor` itself when it is 0-dim.
    """
    if tensor.ndim == 0:
        # Scalars (e.g. bookkeeping counters) carry no batch dim; widening
        # them would change their rank, so pass them through untouched.
        return tensor

    # Equivalent to unsqueeze(1).repeat(1, factor, 1, ...).reshape(-1, ...):
    # each element along dim 0 is repeated `factor` times consecutively.
    return tensor.repeat_interleave(factor, dim=0)
|
|
|
|
|
def widen_data(actor, include_embeddings=True, include_projections=True):
    """Widen the actor's state for beam search.

    Every tensor attribute of ``actor.fleet`` and ``actor.graph`` (plus
    ``actor.log_probs`` and, optionally, the node embeddings and
    projections) has its batch dimension repeated ``actor.sample_size``
    times so that each original instance spawns that many beam candidates.

    Args:
        actor: Object exposing ``sample_size``, ``fleet``, ``graph``,
            ``log_probs`` and (depending on the flags) ``node_embeddings``
            and ``node_projections`` (a dict of tensors).
        include_embeddings: Also widen ``actor.node_embeddings``.
        include_projections: Also widen every tensor in
            ``actor.node_projections``.
    """
    factor = actor.sample_size

    # Fleet first, then graph — widen every non-scalar tensor attribute
    # in place on its container.
    for container in (actor.fleet, actor.graph):
        for name, value in vars(container).items():
            if isinstance(value, torch.Tensor) and value.ndim > 0:
                setattr(container, name, widen_tensor(value, factor))

    actor.log_probs = widen_tensor(actor.log_probs, factor)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, factor)

    if include_projections:

        def _widen_projection(value):
            # Projections with more than three dims (e.g. multi-head
            # caches) keep dim 0 intact and are widened along dim 1;
            # everything else is widened along the batch dim as usual.
            if value.ndim > 3:
                tiled = value.unsqueeze(2).repeat(1, 1, factor, 1, 1)
                return tiled.reshape(
                    tiled.shape[0],
                    tiled.shape[1] * factor,
                    tiled.shape[3],
                    tiled.shape[4],
                )
            return widen_tensor(value, factor)

        actor.node_projections = {
            key: _widen_projection(value)
            for key, value in actor.node_projections.items()
        }
|
|
|
|
|
def select_data(actor, index, include_embeddings=True, include_projections=True):
    """Keep only the beam entries picked by `index`.

    Indexes the batch dimension of every eligible tensor attribute of
    ``actor.fleet`` and ``actor.graph`` (plus ``actor.log_probs`` and,
    optionally, node embeddings/projections) with `index` — typically the
    top-k survivors of a beam-search step. Entries may repeat, so a beam
    can be duplicated.

    Args:
        actor: Object exposing ``fleet``, ``graph``, ``log_probs`` and
            (depending on the flags) ``node_embeddings`` and
            ``node_projections`` (a dict of tensors).
        index: 1-D integer tensor of beam indices to keep (cast to long).
        include_embeddings: Also index ``actor.node_embeddings``.
        include_projections: Also index every tensor in
            ``actor.node_projections``.
    """
    index = index.long()
    max_index = index.max().item()

    for container in (actor.fleet, actor.graph):
        for name, value in vars(container).items():
            # Guard ndim > 0 before touching shape[0]: widen_data shows
            # these containers may hold 0-dim tensors, and shape[0] on a
            # scalar raises IndexError. Skip tensors whose batch dim is
            # too small to be indexed as well.
            if (
                isinstance(value, torch.Tensor)
                and value.ndim > 0
                and value.shape[0] > max_index
            ):
                setattr(container, name, value[index])

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:

        def _select_projection(value):
            # Mirrors widen_data: >3-dim projections carry the beam on
            # dim 1, everything else on dim 0.
            if value.ndim > 3:
                return value[:, index, :, :]
            return value[index]

        actor.node_projections = {
            key: _select_projection(value)
            for key, value in actor.node_projections.items()
        }
|
|