"""
Define collate functions for new data types here
"""
from functools import partial
from itertools import chain

import dgl
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate
import torch_geometric


def collate_pyg_fn(batch, collate_fn_map=None):
    """
    PyG graph collation
    """
    return torch_geometric.data.Batch.from_data_list(batch)


def collate_dgl_fn(batch, collate_fn_map=None):
    """
    DGL graph collation
    """
    return dgl.batch(batch)
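
# Usage sketch (illustrative only; the tiny graph below is a made-up example, not part of
# this module): both graph collate functions merge a list of graphs into one batched graph.
#
#   >>> g = torch_geometric.data.Data(x=torch.randn(4, 8),
#   ...                               edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
#   >>> collate_pyg_fn([g, g]).num_graphs
#   2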


def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Similar to pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True),
    but additionally supports padding a list of square Tensors of size ``(L x L x ...)``.
    :param batch:
    :param padding_value:
    :param collate_fn_map:
    :return: padded_batch, lengths
    """
    lengths = [tensor.size(0) for tensor in batch]
    if any(element != lengths[0] for element in lengths[1:]):
        try:
            # Tensors differ only in their first dimension: pad_sequence handles this case directly
            batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
        except RuntimeError:
            # Trailing dimensions differ as well (e.g. square L x L tensors):
            # find the max size of each dimension across the batch
            max_sizes = [max([tensor.size(dim) for tensor in batch]) for dim in range(batch[0].dim())]
            # Pad every dimension of every tensor up to its respective max size with the padding value, then stack
            batch = collate_tensor_fn([
                torch.nn.functional.pad(
                    tensor, tuple(chain.from_iterable(
                        [(0, max_sizes[dim] - tensor.size(dim)) for dim in range(tensor.dim())][::-1])
                    ), mode='constant', value=padding_value) for tensor in batch
            ])
    else:
        batch = collate_tensor_fn(batch)

    lengths = torch.as_tensor(lengths)
    # Return the padded batch tensor and the lengths
    return batch, lengths
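
# Illustrative examples (assumed shapes, not taken from the codebase): when only the first
# dimension differs, pad_sequence does the padding; when trailing dimensions differ as well
# (e.g. square L x L matrices), every dimension is padded to the batch-wise maximum.
#
#   >>> padded, lengths = pad_collate_tensor_fn([torch.ones(3, 5), torch.ones(2, 5)])
#   >>> padded.shape, lengths.tolist()
#   (torch.Size([2, 3, 5]), [3, 2])
#
#   >>> padded, lengths = pad_collate_tensor_fn([torch.ones(3, 3), torch.ones(2, 2)])
#   >>> padded.shape, lengths.tolist()
#   (torch.Size([2, 3, 3]), [3, 2])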


# Merge the custom collate functions into PyTorch's default collation map (dict union, Python 3.9+)
COLLATE_FN_MAP = default_collate_fn_map | {
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}


def collate_fn(batch, automatic_padding=False, padding_value=0):
    """
    Default collate extended with graph support and optional padding of variable-length tensors.
    """
    collate_fn_map = COLLATE_FN_MAP
    if automatic_padding:
        # Use a per-call copy so enabling padding does not permanently mutate the module-level map
        collate_fn_map = COLLATE_FN_MAP | {
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        }
    return collate(batch, collate_fn_map=collate_fn_map)

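# Sketch of the intended wiring (``my_dataset`` below is a placeholder, not defined in this
# module): pass collate_fn to a DataLoader, optionally with automatic padding enabled.
#
#   >>> from torch.utils.data import DataLoader
#   >>> loader = DataLoader(my_dataset, batch_size=32,
#   ...                     collate_fn=partial(collate_fn, automatic_padding=True))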

# class VariableLengthSequence(torch.Tensor):
#     """
#     A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
#     and it has an attribute called lengths, which signifies the length of each original sequence in the batch.
#     """
#
#     def __new__(cls, data, lengths):
#         """
#         Creates a new VariableLengthSequence object from the given data and lengths.
#         Args:
#             data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
#             lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).
#         Returns:
#             VariableLengthSequence: A new VariableLengthSequence object.
#         """
#         # Check the validity of the inputs
#         assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
#         assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
#         assert data.dim() >= 2, "data must have at least two dimensions"
#         assert lengths.dim() == 1, "lengths must have one dimension"
#         assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
#         assert lengths.min() > 0, "lengths must be positive"
#         assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"
#
#         # Create a new tensor object from data
#         obj = super().__new__(cls, data)
#
#         # Set the lengths attribute
#         obj.lengths = lengths
#
#         return obj


# class VariableLengthSequence(torch.Tensor):
#     _lengths = torch.Tensor()
#
#     def __new__(cls, data, lengths, *args, **kwargs):
#         self = super().__new__(cls, data, *args, **kwargs)
#         self.lengths = lengths
#         return self
#
#     def clone(self, *args, **kwargs):
#         return VariableLengthSequence(super().clone(*args, **kwargs), self.lengths.clone())
#
#     def new_empty(self, *size):
#         return VariableLengthSequence(super().new_empty(*size), self.lengths)
#
#     def to(self, *args, **kwargs):
#         return VariableLengthSequence(super().to(*args, **kwargs), self.lengths.to(*args, **kwargs))
#
#     def __format__(self, format_spec):
#         # Convert self to a string or a number here, depending on what you need
#         return self.item().__format__(format_spec)
#
#     @property
#     def lengths(self):
#         return self._lengths
#
#     @lengths.setter
#     def lengths(self, lengths):
#         self._lengths = lengths
#
#     def cpu(self, *args, **kwargs):
#         return VariableLengthSequence(super().cpu(*args, **kwargs), self.lengths.cpu(*args, **kwargs))
#
#     def cuda(self, *args, **kwargs):
#         return VariableLengthSequence(super().cuda(*args, **kwargs), self.lengths.cuda(*args, **kwargs))
#
#     def pin_memory(self):
#         return VariableLengthSequence(super().pin_memory(), self.lengths.pin_memory())
#
#     def share_memory_(self):
#         super().share_memory_()
#         self.lengths.share_memory_()
#         return self
#
#     def detach_(self, *args, **kwargs):
#         super().detach_(*args, **kwargs)
#         self.lengths.detach_(*args, **kwargs)
#         return self
#
#     def detach(self, *args, **kwargs):
#         return VariableLengthSequence(super().detach(*args, **kwargs), self.lengths.detach(*args, **kwargs))
#
#     def record_stream(self, *args, **kwargs):
#         super().record_stream(*args, **kwargs)
#         self.lengths.record_stream(*args, **kwargs)
#         return self


    # @classmethod
    # def __torch_function__(cls, func, types, args=(), kwargs=None):
    #     return super().__torch_function__(func, types, args, kwargs) \
    #         if cls.lengths is not None else torch.Tensor.__torch_function__(func, types, args, kwargs)