|
import torch |
|
|
|
|
|
def widen_tensor(tensor, factor):
    """
    Duplicate a tensor `factor` times along the batch dimension.

    Each slice along dim 0 is repeated ``factor`` times consecutively, so an
    input of shape ``(B, *rest)`` becomes ``(B * factor, *rest)`` with rows
    ordered ``[t0, t0, ..., t1, t1, ...]``. A zero-dimensional tensor has no
    batch dimension and is returned unchanged (the very same object).

    Args:
        tensor: Tensor to widen.
        factor: Number of consecutive copies of each dim-0 slice.

    Returns:
        The widened tensor, or ``tensor`` itself when it is 0-dimensional.
    """
    # Scalars carry no batch axis: hand them back untouched.
    if not tensor.dim():
        return tensor
    # Interleaved repetition along dim 0 is exactly unsqueeze(1) + repeat +
    # flatten of the first two axes.
    return tensor.repeat_interleave(factor, dim=0)
|
|
|
|
|
def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Expand the actor's data `sample_size` times to support beam sampling.

    Every non-scalar tensor attribute of ``actor.fleet`` and ``actor.graph``,
    plus ``actor.log_probs``, is duplicated along dim 0 via ``widen_tensor``.
    Optionally ``actor.node_embeddings`` and every value of
    ``actor.node_projections`` are widened as well. Projections of rank > 3
    are widened along dim 1 instead of dim 0.

    Args:
        actor: Object exposing ``sample_size``, ``fleet``, ``graph``,
            ``log_probs`` and, depending on the flags, ``node_embeddings`` and
            ``node_projections`` (a dict of tensors).
        include_embeddings: Also widen ``actor.node_embeddings``.
        include_projections: Also widen every tensor in
            ``actor.node_projections``.

    Returns:
        None. ``actor`` is modified in place.
    """
    sample_size = actor.sample_size

    def widen_attributes(obj):
        # Snapshot the items so rebinding attributes never touches the dict
        # that is being iterated.
        for name, value in list(obj.__dict__.items()):
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                setattr(obj, name, widen_tensor(value, sample_size))

    widen_attributes(actor.fleet)
    widen_attributes(actor.graph)

    actor.log_probs = widen_tensor(actor.log_probs, sample_size)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, sample_size)

    if include_projections:
        def widen_projection(x):
            # Rank > 3 projections are widened along dim 1 (each slice repeated
            # consecutively). repeat_interleave works for any rank, unlike a
            # hard-coded 5-D repeat + 4-D view, which only handles rank 4.
            # NOTE(review): dim 1 is assumed to be the batch axis for these
            # projections (e.g. (heads, batch, ...)) — confirm with producer.
            if x.dim() > 3:
                return x.repeat_interleave(sample_size, dim=1)
            return widen_tensor(x, sample_size)

        actor.node_projections = {
            key: widen_projection(value)
            for key, value in actor.node_projections.items()
        }
|
|
|
|
|
def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Select a specific subset of data using `index` (e.g., for beam search pruning).

    Rows are gathered with ``tensor[index]`` from every eligible tensor
    attribute of ``actor.fleet`` and ``actor.graph`` and from
    ``actor.log_probs``. Optionally the same is done for
    ``actor.node_embeddings`` and for every value of
    ``actor.node_projections``; projections of rank > 3 are indexed along
    dim 1 instead of dim 0.

    A fleet/graph attribute is eligible only when it is a non-scalar tensor
    whose leading dimension is strictly larger than the largest entry of
    ``index``. Other attributes are left untouched.

    Args:
        actor: Object exposing ``fleet``, ``graph``, ``log_probs`` and,
            depending on the flags, ``node_embeddings`` and
            ``node_projections`` (a dict of tensors).
        index: Integer tensor of row positions to keep. NOTE(review): the
            size guard assumes integer indices; a boolean mask is not handled
            specially — confirm callers pass integer indices.
        include_embeddings: Also select from ``actor.node_embeddings``.
        include_projections: Also select from ``actor.node_projections``.

    Returns:
        None. ``actor`` is modified in place.
    """
    # Highest row position requested, computed once. `max()` raises on an
    # empty tensor, and an empty index can be applied to any leading
    # dimension, so -1 lets every non-scalar tensor qualify in that case.
    max_index = index.max().item() if index.numel() > 0 else -1

    def select_attributes(obj):
        # Snapshot the items: attributes are rebound while we walk them.
        for name, value in list(obj.__dict__.items()):
            # Addressing position `max_index` requires size(0) > max_index;
            # a tensor of length exactly `max_index` would raise IndexError.
            if (
                isinstance(value, torch.Tensor)
                and value.dim() > 0
                and value.size(0) > max_index
            ):
                setattr(obj, name, value[index])

    select_attributes(actor.fleet)
    select_attributes(actor.graph)

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        def select_projection(x):
            # Rank > 3 projections are selected along dim 1, others along dim 0.
            if x.dim() > 3:
                return x[:, index]
            return x[index]

        actor.node_projections = {
            key: select_projection(value)
            for key, value in actor.node_projections.items()
        }
|
|