""" | |
Define collate functions for new data types here | |
""" | |
from functools import partial | |
from itertools import chain | |
import dgl | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate | |
import torch_geometric | |


def collate_pyg_fn(batch, collate_fn_map=None):
    """
    PyG graph collation
    """
    return torch_geometric.data.Batch.from_data_list(batch)
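
# Illustrative sketch (not from the original file): how collate_pyg_fn merges a list of
# torch_geometric.data.Data objects into a single Batch; the graphs g1/g2 are made-up examples.
#   g1 = torch_geometric.data.Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 1], [1, 2]]))
#   g2 = torch_geometric.data.Data(x=torch.randn(5, 8), edge_index=torch.tensor([[0, 2], [3, 4]]))
#   big = collate_pyg_fn([g1, g2])  # one Batch with 8 nodes and a `batch` vector mapping nodes to graphs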


def collate_dgl_fn(batch, collate_fn_map=None):
    """
    DGL graph collation
    """
    return dgl.batch(batch)
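
# Illustrative sketch (not from the original file): dgl.batch concatenates the graphs' nodes and
# edges into one batched DGLGraph; per-graph sizes stay recoverable via batch_num_nodes().
#   g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])), num_nodes=3)
#   g2 = dgl.graph((torch.tensor([0]), torch.tensor([1])), num_nodes=2)
#   bg = collate_dgl_fn([g1, g2])  # bg.batch_num_nodes() -> tensor([3, 2])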


def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Similar to pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True),
    but additionally supports padding a list of square Tensors of size ``(L x L x ...)``.

    :param batch: list of Tensors to collate
    :param padding_value: value used to fill the padded positions
    :param collate_fn_map: unused; kept for compatibility with the ``collate`` dispatch signature
    :return: (padded_batch, lengths), where ``lengths`` holds each Tensor's original first-dimension size
    """
    lengths = [tensor.size(0) for tensor in batch]
    if any(length != lengths[0] for length in lengths[1:]):
        try:
            # Tensors differ only in their first dimension: pad along it with pad_sequence
            batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
        except RuntimeError:
            # Trailing dimensions differ as well: find the max size of each dimension in the batch
            max_sizes = [max(tensor.size(dim) for tensor in batch) for dim in range(batch[0].dim())]
            # Pad every dimension of every tensor up to the respective max size
            # (F.pad expects pad widths starting from the last dimension, hence the reversal)
            batch = collate_tensor_fn([
                torch.nn.functional.pad(
                    tensor, tuple(chain.from_iterable(
                        [(0, max_sizes[dim] - tensor.size(dim)) for dim in range(tensor.dim())][::-1])
                    ), mode='constant', value=padding_value) for tensor in batch
            ])
    else:
        batch = collate_tensor_fn(batch)
    lengths = torch.as_tensor(lengths)
    # Return the padded batch tensor and the original lengths
    return batch, lengths
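
# Illustrative sketch (not from the original file) of the padding behaviour:
#   seqs = [torch.ones(2, 4), torch.ones(3, 4)]    # differ only in the first dimension
#   padded, lengths = pad_collate_tensor_fn(seqs)  # padded.shape == (2, 3, 4), lengths == tensor([2, 3])
#   mats = [torch.ones(2, 2), torch.ones(3, 3)]    # square matrices, differ in every dimension
#   padded, lengths = pad_collate_tensor_fn(mats)  # padded.shape == (2, 3, 3), lengths == tensor([2, 3])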


# Join custom collate functions with the default collation map of PyTorch
COLLATE_FN_MAP = default_collate_fn_map | {
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}
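
# Illustrative sketch (not from the original file): ``collate`` dispatches each leaf of a nested
# sample by its type (including subclasses) through this map, so mixed samples batch field by field.
#   samples = [{"graph": g1, "y": 0.5}, {"graph": g2, "y": 1.0}]   # g1/g2 as in the sketches above
#   out = collate(samples, collate_fn_map=COLLATE_FN_MAP)          # {"graph": one Batch, "y": tensor of both labels}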


def collate_fn(batch, automatic_padding=False, padding_value=0):
    """
    Dispatch every element of ``batch`` through COLLATE_FN_MAP; with ``automatic_padding``,
    variable-length Tensors are padded with ``padding_value`` and returned with their lengths.
    """
    if automatic_padding:
        COLLATE_FN_MAP.update({
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        })
    return collate(batch, collate_fn_map=COLLATE_FN_MAP)
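
# Illustrative usage sketch (the ``dataset`` below is a placeholder, not defined in this file):
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=4,
#                       collate_fn=partial(collate_fn, automatic_padding=True, padding_value=0))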


# class VariableLengthSequence(torch.Tensor):
#     """
#     A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
#     and it has an attribute called lengths, which signifies the length of each original sequence in the batch.
#     """
#
#     def __new__(cls, data, lengths):
#         """
#         Creates a new VariableLengthSequence object from the given data and lengths.
#         Args:
#             data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
#             lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).
#         Returns:
#             VariableLengthSequence: A new VariableLengthSequence object.
#         """
#         # Check the validity of the inputs
#         assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
#         assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
#         assert data.dim() >= 2, "data must have at least two dimensions"
#         assert lengths.dim() == 1, "lengths must have one dimension"
#         assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
#         assert lengths.min() > 0, "lengths must be positive"
#         assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"
#
#         # Create a new tensor object from data
#         obj = super().__new__(cls, data)
#
#         # Set the lengths attribute
#         obj.lengths = lengths
#
#         return obj


# class VariableLengthSequence(torch.Tensor):
#     _lengths = torch.Tensor()
#
#     def __new__(cls, data, lengths, *args, **kwargs):
#         self = super().__new__(cls, data, *args, **kwargs)
#         self.lengths = lengths
#         return self
#
#     def clone(self, *args, **kwargs):
#         return VariableLengthSequence(super().clone(*args, **kwargs), self.lengths.clone())
#
#     def new_empty(self, *size):
#         return VariableLengthSequence(super().new_empty(*size), self.lengths)
#
#     def to(self, *args, **kwargs):
#         return VariableLengthSequence(super().to(*args, **kwargs), self.lengths.to(*args, **kwargs))
#
#     def __format__(self, format_spec):
#         # Convert self to a string or a number here, depending on what you need
#         return self.item().__format__(format_spec)
#
#     @property
#     def lengths(self):
#         return self._lengths
#
#     @lengths.setter
#     def lengths(self, lengths):
#         self._lengths = lengths
#
#     def cpu(self, *args, **kwargs):
#         return VariableLengthSequence(super().cpu(*args, **kwargs), self.lengths.cpu(*args, **kwargs))
#
#     def cuda(self, *args, **kwargs):
#         return VariableLengthSequence(super().cuda(*args, **kwargs), self.lengths.cuda(*args, **kwargs))
#
#     def pin_memory(self):
#         return VariableLengthSequence(super().pin_memory(), self.lengths.pin_memory())
#
#     def share_memory_(self):
#         super().share_memory_()
#         self.lengths.share_memory_()
#         return self
#
#     def detach_(self, *args, **kwargs):
#         super().detach_(*args, **kwargs)
#         self.lengths.detach_(*args, **kwargs)
#         return self
#
#     def detach(self, *args, **kwargs):
#         return VariableLengthSequence(super().detach(*args, **kwargs), self.lengths.detach(*args, **kwargs))
#
#     def record_stream(self, *args, **kwargs):
#         super().record_stream(*args, **kwargs)
#         self.lengths.record_stream(*args, **kwargs)
#         return self
#
#     @classmethod
#     def __torch_function__(cls, func, types, args=(), kwargs=None):
#         return super().__torch_function__(func, types, args, kwargs) \
#             if cls.lengths is not None else torch.Tensor.__torch_function__(func, types, args, kwargs)
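

if __name__ == "__main__":
    # Minimal sanity-check sketch (added for illustration, not part of the original module):
    # collate two variable-length samples with automatic padding enabled.
    samples = [
        {"seq": torch.arange(2, dtype=torch.float32), "label": 0},
        {"seq": torch.arange(5, dtype=torch.float32), "label": 1},
    ]
    batched = collate_fn(samples, automatic_padding=True)
    padded, lengths = batched["seq"]
    print(padded.shape, lengths)  # expected: torch.Size([2, 5]) tensor([2, 5])
    print(batched["label"])       # expected: tensor([0, 1])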