""" | |
Define collate functions for new data types here | |
""" | |
from functools import partial | |
from itertools import chain | |
import dgl | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate | |
import torch_geometric | |


def collate_pyg_fn(batch, collate_fn_map=None):
    """
    PyG graph collation: merge a list of Data objects into one disjoint Batch.
    """
    # collate_fn_map is accepted for PyTorch's collate dispatch API but unused here
    return torch_geometric.data.Batch.from_data_list(batch)


def collate_dgl_fn(batch, collate_fn_map=None):
    """
    DGL graph collation: merge a list of DGLGraphs into one batched graph.
    """
    return dgl.batch(batch)
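

# Usage sketch (not part of the original module): collating a small list of
# DGL graphs; the PyG path is symmetric via collate_pyg_fn. The toy graph
# below is an illustrative assumption, not taken from the source.
def _example_collate_graphs():
    # A 3-node graph with edges 0->1 and 1->2
    g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    batched = collate_dgl_fn([g, g])
    assert batched.batch_size == 2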


def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Similar to ``pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True)``,
    but additionally supports padding a list of square Tensors of size ``(L x L x ...)``.

    :param batch: list of Tensors whose first (and possibly later) dimensions may differ in size
    :param padding_value: value used to fill the padded positions
    :param collate_fn_map: accepted for compatibility with PyTorch's collate dispatch; unused here
    :return: tuple of (padded_batch, lengths), where lengths holds each sample's original size along dim 0
    """
    lengths = [tensor.size(0) for tensor in batch]
    if any(element != lengths[0] for element in lengths[1:]):
        try:
            # Tensors differ only in their first dimension, so pad_sequence can handle them
            batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
        except RuntimeError:
            # Tensors also differ in trailing dimensions (e.g. square L x L matrices),
            # so find the max size of each dimension across the batch
            max_sizes = [max(tensor.size(dim) for tensor in batch) for dim in range(batch[0].dim())]
            # Pad every dimension of every tensor up to the respective max size.
            # F.pad expects (left, right) pairs starting from the LAST dimension,
            # hence the reversed per-dimension list before flattening.
            batch = collate_tensor_fn([
                torch.nn.functional.pad(
                    tensor, tuple(chain.from_iterable(
                        [(0, max_sizes[dim] - tensor.size(dim)) for dim in range(tensor.dim())][::-1])
                    ), mode='constant', value=padding_value) for tensor in batch
            ])
    else:
        batch = collate_tensor_fn(batch)
    lengths = torch.as_tensor(lengths)
    # Return the padded batch tensor and the original lengths
    return batch, lengths
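

# Quick sketch of pad_collate_tensor_fn on toy inputs (illustrative shapes,
# not from the source): ragged 1-D sequences take the pad_sequence path,
# while square matrices of different sizes trigger the per-dimension fallback.
def _example_pad_collate():
    # Ragged 1-D sequences: padded along dim 0, lengths recorded
    padded, lengths = pad_collate_tensor_fn([torch.ones(2), torch.ones(4)])
    assert padded.shape == (2, 4) and lengths.tolist() == [2, 4]
    # 2x2 and 3x3 matrices: pad_sequence fails, so every dimension is
    # padded to its batch-wide max before stacking
    padded, lengths = pad_collate_tensor_fn([torch.ones(2, 2), torch.ones(3, 3)])
    assert padded.shape == (2, 3, 3) and lengths.tolist() == [2, 3]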


# Join custom collate functions with the default collation map of PyTorch
# (the dict union operator requires Python 3.9+)
COLLATE_FN_MAP = default_collate_fn_map | {
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}


def collate_fn(batch, automatic_padding=False, padding_value=0):
    """
    Extended collate function that dispatches on COLLATE_FN_MAP; if
    automatic_padding is set, variable-length Tensors are padded rather
    than stacked.
    """
    if automatic_padding:
        COLLATE_FN_MAP.update({
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        })
    return collate(batch, collate_fn_map=COLLATE_FN_MAP)
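

# Sketch of plugging collate_fn into a DataLoader so variable-length samples
# are padded automatically; the toy dataset is a hypothetical stand-in.
def _example_dataloader():
    from torch.utils.data import DataLoader
    data = [torch.ones(n) for n in (2, 3, 5)]
    loader = DataLoader(data, batch_size=3,
                        collate_fn=partial(collate_fn, automatic_padding=True))
    padded, lengths = next(iter(loader))
    assert padded.shape == (3, 5) and lengths.tolist() == [2, 3, 5]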


class VariableLengthSequence(torch.Tensor):
    """
    A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
    and it has an attribute called lengths, which signifies the length of each original sequence in the batch.
    """

    def __new__(cls, data, lengths):
        """
        Creates a new VariableLengthSequence object from the given data and lengths.

        Args:
            data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
            lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).

        Returns:
            VariableLengthSequence: A new VariableLengthSequence object.
        """
        # Check the validity of the inputs
        assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
        assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
        assert data.dim() >= 2, "data must have at least two dimensions"
        assert lengths.dim() == 1, "lengths must have one dimension"
        assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
        assert lengths.min() > 0, "lengths must be positive"
        assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"
        # Create a new tensor object from data
        obj = super().__new__(cls, data)
        # Set the lengths attribute
        obj.lengths = lengths
        return obj

    def __repr__(self, *, tensor_contents=None):
        """
        Returns a string representation of the VariableLengthSequence object.
        """
        return f"VariableLengthSequence(data={self.data}, lengths={self.lengths})"

    def __reduce_ex__(self, proto):
        """
        Enables pickling of the VariableLengthSequence object.
        """
        return type(self), (self.data, self.lengths)
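

# Sketch (illustrative, not from the source): wrapping the padded output of
# pad_collate_tensor_fn in a VariableLengthSequence, then recovering the
# per-sample lengths for pack_padded_sequence.
def _example_variable_length_sequence():
    padded, lengths = pad_collate_tensor_fn([torch.ones(2, 4), torch.ones(3, 4)])
    seq = VariableLengthSequence(padded, lengths)
    assert seq.shape == (2, 3, 4) and seq.lengths.tolist() == [2, 3]
    # Unlike PackedSequence, seq is a plain batch tensor, but the lengths
    # still travel with it, so downstream code can repack when needed.
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        seq, seq.lengths, batch_first=True, enforce_sorted=False)
    return packed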