vrp-shanghai-transformer / utils /beam_search_utils.py
a-ragab-h-m's picture
Update utils/beam_search_utils.py
99f1377 verified
import torch
def widen_tensor(tensor, factor):
    """
    Repeat every entry along the first axis `factor` times in a row.

    A tensor with leading size N comes back with leading size N * factor;
    the copies of one entry sit next to each other. All remaining axes keep
    their sizes. Zero-dimensional tensors are returned unchanged (the very
    same object), since they have no first axis to repeat.
    """
    # Scalars have nothing to widen.
    if tensor.ndim == 0:
        return tensor

    leading, *trailing = tensor.shape

    # Open a fresh axis right behind the first one and tile only along it.
    tiled = tensor.unsqueeze(1).repeat(1, factor, *([1] * len(trailing)))

    # Merge the fresh axis back into the first axis.
    return tiled.reshape(leading * factor, *trailing)
def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Prepare an actor for beam search by widening its batched state in place.

    Every tensor attribute (with at least one dimension) found on
    `actor.fleet` and `actor.graph`, plus `actor.log_probs`, is passed
    through `widen_tensor` with `actor.sample_size` as the factor.
    `actor.node_embeddings` is widened when `include_embeddings` is true, and
    each tensor in the `actor.node_projections` dict is widened when
    `include_projections` is true. Nothing is returned.
    """
    beam_width = actor.sample_size

    def _widen_attributes(holder):
        # Non-tensor attributes and zero-dimensional tensors are left alone.
        for attr, value in vars(holder).items():
            if not isinstance(value, torch.Tensor) or value.ndim == 0:
                continue
            setattr(holder, attr, widen_tensor(value, beam_width))

    # Fleet tensors, then graph tensors.
    _widen_attributes(actor.fleet)
    _widen_attributes(actor.graph)

    actor.log_probs = widen_tensor(actor.log_probs, beam_width)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, beam_width)

    if not include_projections:
        return

    widened = {}
    for key, projection in actor.node_projections.items():
        if projection.ndim > 3:
            # Projections with more than three dimensions are widened along
            # axis 1 instead of axis 0. Per the original note the expected
            # layout is (n_heads, B, G, D) -> (n_heads, B * size, G, D).
            tiled = projection.unsqueeze(2).repeat(1, 1, beam_width, 1, 1)
            heads, batch, _, nodes, depth = tiled.shape
            widened[key] = tiled.reshape(heads, batch * beam_width, nodes, depth)
        else:
            widened[key] = widen_tensor(projection, beam_width)
    actor.node_projections = widened
def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Selects a subset of the beam based on indices, usually used to keep top-k paths in beam search.

    The actor is modified in place and nothing is returned.

    Args:
        actor: Object exposing `fleet`, `graph` and `log_probs`, plus
            `node_embeddings` / `node_projections` (a dict of tensors) when
            the corresponding flag is set.
        index: Tensor of positions to keep along the first axis. It is cast
            to long and must be non-empty, because its maximum is taken.
        include_embeddings: Also index `actor.node_embeddings`.
        include_projections: Also index every tensor in
            `actor.node_projections`.
    """
    index = index.long()
    max_index = index.max().item()

    def _select_attributes(holder):
        # Snapshot the items so rebinding attributes cannot disturb the loop.
        for name, value in list(vars(holder).items()):
            if not isinstance(value, torch.Tensor):
                continue
            # Zero-dimensional tensors have no first axis; reading shape[0]
            # on them would raise IndexError. widen_data skips them too, so
            # they are left untouched here as well.
            if value.ndim == 0:
                continue
            # Only tensors whose first axis is long enough for every
            # requested position are indexed; shorter ones are kept as-is.
            if value.shape[0] > max_index:
                setattr(holder, name, value[index])

    # Select from fleet, then from graph
    _select_attributes(actor.fleet)
    _select_attributes(actor.graph)

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        def select_projection(tensor):
            # Tensors with more than three dimensions are indexed along
            # axis 1 (mirroring how widen_data widens them); all others
            # along axis 0.
            if tensor.ndim > 3:
                return tensor[:, index, :, :]
            return tensor[index]

        actor.node_projections = {
            key: select_projection(tensor)
            for key, tensor in actor.node_projections.items()
        }