import torch


def widen_tensor(tensor, factor, dim=0):
    """
    Duplicate a tensor `factor` times along dimension `dim` (the batch
    dimension by default).

    Each slice along `dim` is repeated `factor` times consecutively, so for
    dim=0 the rows come out as [r0, r0, ..., r1, r1, ...]. Zero-dimensional
    tensors have no dimension to widen and are returned unchanged.

    Args:
        tensor: Tensor to widen.
        factor: Number of copies of each slice along `dim`.
        dim: Dimension to widen along; negative values count from the end.

    Returns:
        A new tensor whose size along `dim` is multiplied by `factor`.
    """
    if tensor.dim() == 0:
        return tensor
    if dim < 0:
        dim += tensor.dim()
    shape = list(tensor.shape)
    # Insert a size-1 axis right after `dim`, tile it `factor` times, then
    # merge it back into `dim`; this keeps the copies of a slice adjacent.
    expanded = tensor.unsqueeze(dim + 1)
    repeat_dims = [1] * expanded.dim()
    repeat_dims[dim + 1] = factor
    # `repeat` returns a freshly allocated contiguous tensor, so `view` is safe.
    expanded = expanded.repeat(*repeat_dims)
    new_shape = shape[:dim] + [shape[dim] * factor] + shape[dim + 1:]
    return expanded.view(*new_shape)


def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Expand the actor's data `sample_size` times to support beam sampling.

    Every non-scalar tensor attribute of `actor.fleet` and `actor.graph`,
    plus `actor.log_probs`, is replaced in place by its widened copy (see
    `widen_tensor`). Optionally `actor.node_embeddings` and every value of
    the `actor.node_projections` dict are widened too.

    Args:
        actor: Object exposing `sample_size`, `fleet`, `graph`, `log_probs`,
            `node_embeddings` and `node_projections` (a dict of tensors).
        include_embeddings: Also widen `actor.node_embeddings`.
        include_projections: Also widen every tensor in
            `actor.node_projections`.
    """
    factor = actor.sample_size

    def widen_attributes(obj):
        # Snapshot the items so reassigning attributes cannot disturb the
        # iteration over the instance dict.
        for name, value in list(vars(obj).items()):
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                setattr(obj, name, widen_tensor(value, factor))

    def widen_projection(x):
        # Tensors with more than 3 dims are presumably laid out as
        # (heads, batch, graph, embed, ...) -- batch is dim 1, not dim 0.
        # TODO confirm the layout against the code that builds projections.
        if x.dim() > 3:
            return widen_tensor(x, factor, dim=1)
        return widen_tensor(x, factor)

    widen_attributes(actor.fleet)
    widen_attributes(actor.graph)

    actor.log_probs = widen_tensor(actor.log_probs, factor)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, factor)

    if include_projections:
        actor.node_projections = {
            key: widen_projection(value)
            for key, value in actor.node_projections.items()
        }


def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Select a specific subset of data using `index` (e.g., for beam search pruning).

    Tensor attributes of `actor.fleet` and `actor.graph` are re-indexed along
    their leading dimension, but only when that dimension is long enough for
    every entry of `index` to be in range; shorter tensors are left
    untouched. `actor.log_probs` is always re-indexed, as are (optionally)
    `actor.node_embeddings` and the values of `actor.node_projections`.

    Args:
        actor: Object exposing `fleet`, `graph`, `log_probs`,
            `node_embeddings` and `node_projections` (a dict of tensors).
        index: Index tensor; assumed to be a 1-D integer (long) tensor of
            row indices -- verify against callers.
        include_embeddings: Also re-index `actor.node_embeddings`.
        include_projections: Also re-index every tensor in
            `actor.node_projections`.
    """
    # Minimum leading size a tensor needs for all of `index` to be in range
    # (largest index + 1). Computed once because `.item()` forces a sync;
    # an empty index selects nothing and fits any tensor.
    required = int(index.max().item()) + 1 if index.numel() > 0 else 0

    def select_attributes(obj):
        # Snapshot the items so reassigning attributes cannot disturb the
        # iteration over the instance dict.
        for name, value in list(vars(obj).items()):
            if (
                isinstance(value, torch.Tensor)
                and value.dim() > 0
                and value.size(0) >= required
            ):
                setattr(obj, name, value[index])

    def select_projection(x):
        # Same presumed (heads, batch, ...) layout as in `widen_data`:
        # select along dim 1 for tensors with more than 3 dims.
        if x.dim() > 3:
            return x[:, index]
        return x[index]

    select_attributes(actor.fleet)
    select_attributes(actor.graph)

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        actor.node_projections = {
            key: select_projection(value)
            for key, value in actor.node_projections.items()
        }