import torch


def widen_tensor(tensor, factor):
    """
    Expands a tensor by repeating it along a new batch dimension.
    """
    if tensor.ndim == 0:
        return tensor

    shape = list(tensor.shape)
    repeat_dims = [1, factor] + [1] * (tensor.ndim - 1)
    expanded = tensor.unsqueeze(1).repeat(*repeat_dims)

    new_shape = [shape[0] * factor] + shape[1:]
    return expanded.reshape(*new_shape)
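
# Illustration of widen_tensor's row ordering: the `factor` copies of each
# batch element stay adjacent, e.g. with factor=2,
#
#   widen_tensor(torch.tensor([[0, 1, 2], [3, 4, 5]]), 2)
#   # -> tensor([[0, 1, 2],
#   #            [0, 1, 2],
#   #            [3, 4, 5],
#   #            [3, 4, 5]])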


def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Expands the actor's fleet, graph, and optionally embeddings/projections
    for use in beam search by repeating the batch dimension `sample_size` times.
    """
    sample_size = actor.sample_size

    # Fleet tensors
    for name, tensor in vars(actor.fleet).items():
        if isinstance(tensor, torch.Tensor) and tensor.ndim > 0:
            widened = widen_tensor(tensor, sample_size)
            setattr(actor.fleet, name, widened)

    # Graph tensors
    for name, tensor in vars(actor.graph).items():
        if isinstance(tensor, torch.Tensor) and tensor.ndim > 0:
            widened = widen_tensor(tensor, sample_size)
            setattr(actor.graph, name, widened)

    actor.log_probs = widen_tensor(actor.log_probs, sample_size)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, sample_size)

    if include_projections:
        def widen_projection(tensor, size):
            if tensor.ndim > 3:
                # Special case for shape: (n_heads, B, G, D) → (n_heads, B * size, G, D)
                tensor = tensor.unsqueeze(2).repeat(1, 1, size, 1, 1)
                return tensor.reshape(tensor.shape[0], tensor.shape[1] * size, tensor.shape[3], tensor.shape[4])
            return widen_tensor(tensor, size)

        actor.node_projections = {
            key: widen_projection(tensor, sample_size)
            for key, tensor in actor.node_projections.items()
        }
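
# Shape summary for widen_data with S = actor.sample_size:
#   fleet/graph tensors, log_probs, node_embeddings:  (B, ...)           -> (B * S, ...)
#   node_projections values with more than 3 dims:    (n_heads, B, G, D) -> (n_heads, B * S, G, D)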


def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Selects a subset of the beam based on indices, usually used to keep top-k paths in beam search.
    """
    index = index.long()
    max_index = index.max().item()

    # Select from fleet
    for name, tensor in vars(actor.fleet).items():
        if isinstance(tensor, torch.Tensor) and tensor.ndim > 0 and tensor.shape[0] > max_index:
            setattr(actor.fleet, name, tensor[index])

    # Select from graph
    for name, tensor in vars(actor.graph).items():
        if isinstance(tensor, torch.Tensor) and tensor.ndim > 0 and tensor.shape[0] > max_index:
            setattr(actor.graph, name, tensor[index])

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        def select_projection(tensor):
            if tensor.ndim > 3:
                return tensor[:, index, :, :]
            return tensor[index]

        actor.node_projections = {
            key: select_projection(tensor)
            for key, tensor in actor.node_projections.items()
        }
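

if __name__ == "__main__":
    # Minimal usage sketch. The attribute layout below (fleet/graph namespaces,
    # sample_size, log_probs, node_embeddings, node_projections) is assumed from
    # the names used above; the real actor object may differ.
    from types import SimpleNamespace

    batch, beam = 2, 3
    actor = SimpleNamespace(
        sample_size=beam,
        fleet=SimpleNamespace(position=torch.zeros(batch, 4)),
        graph=SimpleNamespace(coords=torch.rand(batch, 5, 2)),
        log_probs=torch.zeros(batch),
        node_embeddings=torch.rand(batch, 5, 8),
        node_projections={"key": torch.rand(2, batch, 5, 8)},  # (n_heads, B, G, D)
    )

    widen_data(actor)                                  # batch dim: 2 -> 6
    print(actor.node_embeddings.shape)                 # torch.Size([6, 5, 8])
    print(actor.node_projections["key"].shape)         # torch.Size([2, 6, 5, 8])

    keep = torch.tensor([0, 2, 4])                     # indices of the beams to keep
    select_data(actor, keep)                           # batch dim: 6 -> 3
    print(actor.node_embeddings.shape)                 # torch.Size([3, 5, 8])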