import torch


def widen_tensor(tensor, factor, dim=0):
    """Repeat every slice of ``tensor`` along ``dim`` ``factor`` times in place.

    Copies of one slice are adjacent in the result, i.e. for ``dim=0`` the
    rows come out as ``[t[0], t[0], ..., t[1], t[1], ...]`` and the size of
    ``dim`` is multiplied by ``factor``.

    Args:
        tensor: Tensor to widen. 0-dim tensors are returned unchanged.
        factor: Number of copies of each slice.
        dim: Dimension to widen (negative values count from the end).
            Defaults to 0, the behaviour of the original two-argument form.

    Returns:
        A new tensor whose ``dim`` size is ``tensor.shape[dim] * factor``,
        or ``tensor`` itself when it is 0-dim.

    Raises:
        IndexError: If ``dim`` is out of range for ``tensor``.
    """
    if tensor.ndim == 0:
        return tensor
    if not -tensor.ndim <= dim < tensor.ndim:
        raise IndexError(
            f"dim {dim} is out of range for a tensor with {tensor.ndim} dimensions"
        )
    dim %= tensor.ndim

    # Insert a size-1 axis right after `dim`, tile only that axis, then fold
    # it back into `dim` so the copies of each slice end up adjacent.
    repeats = [1] * (tensor.ndim + 1)
    repeats[dim + 1] = factor
    widened_shape = list(tensor.shape)
    widened_shape[dim] *= factor
    return tensor.unsqueeze(dim + 1).repeat(*repeats).reshape(widened_shape)


def _widen_tensor_attrs(holder, factor):
    """Replace every non-scalar tensor attribute of ``holder`` with its widened copy."""
    # Snapshot the items so reassigning attributes cannot disturb the iteration.
    for name, value in list(vars(holder).items()):
        if isinstance(value, torch.Tensor) and value.ndim > 0:
            setattr(holder, name, widen_tensor(value, factor))


def _select_tensor_attrs(holder, index, max_index):
    """Index (along dim 0) every tensor attribute of ``holder`` large enough for ``index``."""
    for name, value in list(vars(holder).items()):
        if (
            isinstance(value, torch.Tensor)
            # 0-dim tensors have no leading dimension; widen_data skips them
            # too, so they must not be touched (shape[0] would raise).
            and value.ndim > 0
            # Tensors too short for the largest index are left untouched.
            and value.shape[0] > max_index
        ):
            setattr(holder, name, value[index])


def widen_data(actor, include_embeddings=True, include_projections=True):
    """Widen the actor's state in place for beam search.

    Every non-scalar tensor attribute of ``actor.fleet`` and ``actor.graph``,
    plus ``actor.log_probs``, is repeated ``actor.sample_size`` times along
    its leading dimension (see ``widen_tensor``).

    Args:
        actor: Object exposing ``sample_size``, ``fleet``, ``graph``,
            ``log_probs`` and, when the flags below are set,
            ``node_embeddings`` / ``node_projections``.
        include_embeddings: Also widen ``actor.node_embeddings``.
        include_projections: Also widen each tensor in the
            ``actor.node_projections`` dict. Tensors with more than three
            dimensions are widened along dim 1 (per the original comment
            they are laid out as ``(n_heads, B, G, D)``); all others along
            dim 0.

    Returns:
        None. ``actor`` is modified in place.
    """
    sample_size = actor.sample_size

    _widen_tensor_attrs(actor.fleet, sample_size)
    _widen_tensor_attrs(actor.graph, sample_size)

    actor.log_probs = widen_tensor(actor.log_probs, sample_size)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, sample_size)

    if include_projections:
        actor.node_projections = {
            key: widen_tensor(
                tensor, sample_size, dim=1 if tensor.ndim > 3 else 0
            )
            for key, tensor in actor.node_projections.items()
        }


def select_data(actor, index, include_embeddings=True, include_projections=True):
    """Keep only the entries picked by ``index``, in place.

    Usually used to keep the top-k paths in beam search.

    Tensor attributes of ``actor.fleet`` and ``actor.graph`` are indexed
    along dim 0 when their leading dimension is larger than the largest
    index; shorter tensors, 0-dim tensors and non-tensor attributes are
    left untouched. ``actor.log_probs`` is always indexed.

    Args:
        actor: Object exposing ``fleet``, ``graph``, ``log_probs`` and, when
            the flags below are set, ``node_embeddings`` /
            ``node_projections``.
        index: Integer tensor of positions to keep; it is cast to ``long``.
            An empty index produces empty selections.
        include_embeddings: Also index ``actor.node_embeddings``.
        include_projections: Also index each tensor in the
            ``actor.node_projections`` dict, along dim 1 for tensors with
            more than three dimensions and along dim 0 otherwise.

    Returns:
        None. ``actor`` is modified in place.
    """
    index = index.long()
    # `max()` raises on an empty tensor; -1 makes every non-scalar tensor
    # eligible so an empty selection simply yields empty batches.
    max_index = index.max().item() if index.numel() > 0 else -1

    _select_tensor_attrs(actor.fleet, index, max_index)
    _select_tensor_attrs(actor.graph, index, max_index)

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        actor.node_projections = {
            key: tensor[:, index] if tensor.ndim > 3 else tensor[index]
            for key, tensor in actor.node_projections.items()
        }