File size: 4,534 Bytes
05ca42f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
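"""
Collate functions for data types that PyTorch's default collation does not handle
(PyG and DGL graphs), plus optional padding of differently-sized tensors into one batch.
"""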
from functools import partial
from itertools import chain

import dgl
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate
import torch_geometric


def collate_pyg_fn(batch, collate_fn_map=None):
    """
    Collate PyTorch Geometric graph objects into a single batch object.

    :param batch: sequence of PyG data objects to merge.
    :param collate_fn_map: unused; accepted for compatibility with the ``collate_fn_map``
        calling convention of ``torch.utils.data._utils.collate.collate``.
    :return: the object produced by ``torch_geometric.data.Batch.from_data_list(batch)``.
    """
    batch_type = torch_geometric.data.Batch
    return batch_type.from_data_list(batch)


def collate_dgl_fn(batch, collate_fn_map=None):
    """
    Collate DGL graphs into one batched graph.

    :param batch: sequence of ``dgl.DGLGraph`` objects.
    :param collate_fn_map: unused; accepted for compatibility with the ``collate_fn_map``
        calling convention of ``torch.utils.data._utils.collate.collate``.
    :return: the batched graph returned by ``dgl.batch``.
    """
    batched_graph = dgl.batch(batch)
    return batched_graph


def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Collate a list of tensors into one batch tensor, padding them to a common shape when needed.

    Similar to ``pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True)``,
    but pads every dimension, not only the first one. This covers square Tensors of size
    ``(L x L x ...)`` as well as tensors that differ only in a trailing dimension. Each dimension
    is padded at its end up to the largest size of that dimension found in the batch.

    :param batch: list of tensors that all have the same number of dimensions (at least one).
    :param padding_value: value written into the padded positions.
    :param collate_fn_map: unused; accepted for compatibility with the ``collate_fn_map``
        calling convention of ``torch.utils.data._utils.collate.collate``.
    :return: tuple ``(padded_batch, lengths)``. ``padded_batch`` has shape
        ``(len(batch), *max_sizes)``. ``lengths`` is a 1-D tensor holding the original size of
        dimension 0 of each tensor.
    :raises ValueError: if the tensors do not all have the same number of dimensions.
    """
    lengths = torch.as_tensor([tensor.size(0) for tensor in batch])

    # Identical shapes need no padding. The old check compared only dimension 0, so tensors that
    # matched there but differed in a trailing dimension crashed inside torch.stack.
    # (An empty batch also lands here and collate_tensor_fn rejects it, as before.)
    if all(tensor.shape == batch[0].shape for tensor in batch[1:]):
        return collate_tensor_fn(batch), lengths

    ndim = batch[0].dim()
    if any(tensor.dim() != ndim for tensor in batch):
        raise ValueError(
            "cannot pad tensors with different numbers of dimensions: "
            f"{[tensor.dim() for tensor in batch]}"
        )

    # Largest size of every dimension across the batch; every tensor is padded up to this shape
    max_sizes = [max(tensor.size(dim) for tensor in batch) for dim in range(ndim)]

    padded = []
    for tensor in batch:
        # F.pad expects (left, right) pairs ordered from the LAST dimension backwards
        pad_widths = tuple(chain.from_iterable(
            (0, max_sizes[dim] - tensor.size(dim)) for dim in reversed(range(ndim))
        ))
        padded.append(
            torch.nn.functional.pad(tensor, pad_widths, mode='constant', value=padding_value)
        )

    # collate_tensor_fn stacks the now equally-shaped tensors into (batch_size, *max_sizes)
    return collate_tensor_fn(padded), lengths


# Type -> collate function map used by collate_fn: PyTorch's defaults followed by graph handlers.
# This is a merged copy; default_collate_fn_map itself is left unmodified.
COLLATE_FN_MAP = {
    **default_collate_fn_map,
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}


def collate_fn(batch, automatic_padding=False, padding_value=0):
    """
    Collate a list of samples into a batch using ``COLLATE_FN_MAP``.

    :param batch: list of samples to collate.
    :param automatic_padding: if True, tensors are collated with ``pad_collate_tensor_fn``.
        Differently-sized tensors are then padded to a common shape, and each tensor field is
        returned as a ``(padded_batch, lengths)`` tuple.
    :param padding_value: value used to fill padded positions when ``automatic_padding`` is True.
    :return: the collated batch produced by ``torch.utils.data._utils.collate.collate``.
    """
    collate_fn_map = COLLATE_FN_MAP
    if automatic_padding:
        # Build a per-call map instead of updating the module-level COLLATE_FN_MAP in place.
        # Mutating the shared dict made padding (and this call's padding_value) stick for every
        # later call, including calls made with automatic_padding=False.
        collate_fn_map = COLLATE_FN_MAP | {
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        }
    return collate(batch, collate_fn_map=collate_fn_map)


class VariableLengthSequence(torch.Tensor):
    """
    A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
    and it has an attribute called lengths, which signifies the length of each original sequence in the batch.

    Attributes:
        lengths (torch.Tensor): 1-D tensor of per-sequence lengths, stored as a plain Python attribute
            on the instance. NOTE(review): tensors returned by torch operations on this object
            (slicing, arithmetic, ``.to()``) are presumably new instances without ``lengths`` — confirm.
    """

    def __new__(cls, data: torch.Tensor, lengths: torch.Tensor) -> "VariableLengthSequence":
        """
        Creates a new VariableLengthSequence object from the given data and lengths.
        Args:
            data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
            lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).
        Returns:
            VariableLengthSequence: A new VariableLengthSequence object.
        Raises:
            AssertionError: if any of the validity checks below fails.
        """
        # Check the validity of the inputs.
        # NOTE(review): these are assert statements, so they are skipped entirely under `python -O`;
        # consider explicit raises if the validation must always run.
        assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
        assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
        assert data.dim() >= 2, "data must have at least two dimensions"
        assert lengths.dim() == 1, "lengths must have one dimension"
        assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
        # NOTE(review): lengths.min() presumably raises for an empty batch (batch_size 0) — confirm
        assert lengths.min() > 0, "lengths must be positive"
        assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"

        # Create a new tensor object of type `cls` from data.
        # NOTE(review): this goes through torch.Tensor's legacy constructor and presumably aliases
        # data's storage (no copy) — confirm; Tensor.as_subclass / _make_subclass are the documented
        # routes for building subclass instances.
        obj = super().__new__(cls, data)

        # Set the lengths attribute
        obj.lengths = lengths

        return obj

    def __repr__(self, *, tensor_contents=None) -> str:
        """
        Returns a string representation of the VariableLengthSequence object.

        The keyword-only ``tensor_contents`` argument is accepted and ignored (presumably to mirror
        the signature of ``torch.Tensor.__repr__``).
        """
        # NOTE(review): on a Tensor subclass, `self.data` may itself come back as a
        # VariableLengthSequence without `lengths`, which could break formatting — verify.
        return f"VariableLengthSequence(data={self.data}, lengths={self.lengths})"

    def __reduce_ex__(self, proto: int) -> tuple:
        """
        Enables pickling of the VariableLengthSequence object.

        Unpickling calls ``type(self)(self.data, self.lengths)``, so ``__new__`` and its validity
        checks run again. The pickle protocol number ``proto`` is ignored.
        """
        return type(self), (self.data, self.lengths)