import torch


def widen_tensor(tensor, factor):
    """ 
    Duplicate a tensor `factor` times along the batch dimension.
    """
    if tensor.dim() == 0:
        return tensor

    shape = tensor.shape
    repeat_dims = [1, factor] + [1] * (len(shape) - 1)

    expanded = tensor.unsqueeze(1).repeat(*repeat_dims)
    new_shape = [shape[0] * factor] + list(shape[1:])
    return expanded.view(*new_shape)
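
# Illustrative example (values are made up): widen_tensor(torch.arange(2), 3)
# returns tensor([0, 0, 0, 1, 1, 1]); each batch element is duplicated
# `factor` times consecutively, matching torch.repeat_interleave on dim 0.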


def widen_data(actor, include_embeddings=True, include_projections=True):
    """
    Expand the actor's data `sample_size` times to support beam sampling.
    """

    def widen_attributes(obj):
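        # Widen every batched tensor attribute in place; non-tensor
        # attributes and 0-d tensors are skipped.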
        for name, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                setattr(obj, name, widen_tensor(value, actor.sample_size))

    widen_attributes(actor.fleet)
    widen_attributes(actor.graph)

    actor.log_probs = widen_tensor(actor.log_probs, actor.sample_size)

    if include_embeddings:
        actor.node_embeddings = widen_tensor(actor.node_embeddings, actor.sample_size)

    if include_projections:
        def widen_projection(x):
            if x.dim() > 3:
                # Multi-head projections are shaped (heads, batch, graph, embed);
                # widen along the batch axis (dim 1) rather than dim 0.
                y = x.unsqueeze(2).repeat(1, 1, actor.sample_size, 1, 1)
                return y.view(x.size(0), x.size(1) * actor.sample_size, x.size(2), x.size(3))
            return widen_tensor(x, actor.sample_size)

        actor.node_projections = {
            key: widen_projection(value) for key, value in actor.node_projections.items()
        }
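
# Shape sketch under assumed sizes: with sample_size = 4, a (2, 10) fleet
# tensor widens to (8, 10), and a (heads, 2, graph, embed) projection
# widens to (heads, 8, graph, embed).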


def select_data(actor, index, include_embeddings=True, include_projections=True):
    """
    Select a specific subset of data using `index` (e.g., for beam search pruning).
    """

    def select_attributes(obj):
        # Keep only the rows picked by `index`. A tensor is eligible only when
        # its batch dimension can actually be indexed, i.e. size(0) must be
        # strictly greater than index.max().
        for name, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor) and value.dim() > 0 and value.size(0) > index.max().item():
                setattr(obj, name, value[index])

    select_attributes(actor.fleet)
    select_attributes(actor.graph)

    actor.log_probs = actor.log_probs[index]

    if include_embeddings:
        actor.node_embeddings = actor.node_embeddings[index]

    if include_projections:
        def select_projection(x):
            # Multi-head projections index the batch on dim 1
            # (heads, batch, graph, embed); everything else on dim 0.
            if x.dim() > 3:
                return x[:, index, :, :]
            return x[index]

        actor.node_projections = {
            key: select_projection(value) for key, value in actor.node_projections.items()
        }
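

if __name__ == "__main__":
    # Minimal smoke test, sketched against stub objects rather than the real
    # actor class; the fleet/graph attribute names below are hypothetical.
    class _Stub:
        pass

    actor = _Stub()
    actor.sample_size = 3
    actor.fleet = _Stub()
    actor.fleet.capacity = torch.ones(2, 4)                  # (batch, vehicles)
    actor.graph = _Stub()
    actor.graph.coords = torch.rand(2, 5, 2)                 # (batch, nodes, xy)
    actor.log_probs = torch.zeros(2)
    actor.node_embeddings = torch.rand(2, 5, 8)
    actor.node_projections = {"q": torch.rand(4, 2, 5, 8)}   # (heads, batch, graph, embed)

    widen_data(actor)
    assert actor.log_probs.shape == (6,)
    assert actor.node_projections["q"].shape == (4, 6, 5, 8)

    # Prune back to one beam per instance (beams 0 and 3).
    select_data(actor, torch.tensor([0, 3]))
    assert actor.log_probs.shape == (2,)
    assert actor.node_projections["q"].shape == (4, 2, 5, 8)
    print("widen/select round-trip OK")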