libokj's picture
Upload 110 files
6ae852e
"""
Define collate functions for new data types here
"""
from functools import partial
from itertools import chain
import dgl
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate
import torch_geometric
def collate_pyg_fn(batch, collate_fn_map=None):
    """
    Merge a list of PyG data objects into a single ``torch_geometric.data.Batch``.

    :param batch: list of PyG data objects to merge
    :param collate_fn_map: not used by this function; accepted as a keyword so it can be called like the
        other type-specific collate handlers
    :return: the object produced by ``Batch.from_data_list``
    """
    batch_cls = torch_geometric.data.Batch
    return batch_cls.from_data_list(batch)
def collate_dgl_fn(batch, collate_fn_map=None):
    """
    Merge a list of DGL graphs into a single batched graph.

    :param batch: list of DGL graphs to merge
    :param collate_fn_map: not used by this function; accepted as a keyword so it can be called like the
        other type-specific collate handlers
    :return: the batched graph produced by ``dgl.batch``
    """
    batched_graph = dgl.batch(batch)
    return batched_graph
def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Collate tensors of possibly different sizes into one padded batch tensor.

    Similar to pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True),
    but additionally supports padding a list of square Tensors of size ``(L x L x ...)`` and, more generally,
    tensors whose sizes differ in any dimension.

    :param batch: list of tensors; every tensor needs at least one dimension (``size(0)`` is read) and all
        tensors are expected to have the same number of dimensions
    :param padding_value: value written into the padded positions
    :param collate_fn_map: not used by this function; accepted as a keyword so it can be called like the
        other type-specific collate handlers
    :return: padded_batch, lengths -- the stacked (and, if needed, padded) batch tensor, and a tensor holding
        the original size of dimension 0 of every element
    """
    lengths = [tensor.size(0) for tensor in batch]
    shapes = [tuple(tensor.shape) for tensor in batch]

    if all(shape == shapes[0] for shape in shapes[1:]):
        # Every tensor already has the same shape: plain stacking, nothing to pad
        batch = collate_tensor_fn(batch)
    elif all(shape[1:] == shapes[0][1:] for shape in shapes[1:]):
        # Only the leading dimension differs: this is exactly the case pad_sequence handles
        batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
    else:
        # Trailing dimensions differ as well. The decision is made from the shapes rather than by catching
        # an error from pad_sequence, because pad_sequence copies with broadcasting and would silently
        # replicate a size-1 trailing dimension instead of padding it.
        # Find the max size of each dimension in the batch
        max_sizes = [max(tensor.size(dim) for tensor in batch) for dim in range(batch[0].dim())]
        # Pad every dimension of all tensors at its end up to the respective max size.
        # torch.nn.functional.pad expects (left, right) pairs ordered from the last dimension to the first.
        batch = collate_tensor_fn([
            torch.nn.functional.pad(
                tensor,
                tuple(chain.from_iterable(
                    (0, max_sizes[dim] - tensor.size(dim)) for dim in reversed(range(tensor.dim()))
                )),
                mode='constant',
                value=padding_value,
            )
            for tensor in batch
        ])

    lengths = torch.as_tensor(lengths)

    # Return the padded batch tensor and the lengths
    return batch, lengths
# Build a new map from PyTorch's default collation map plus the graph handlers of this module
# (dict unpacking creates a fresh dict, so PyTorch's own default map is left untouched)
COLLATE_FN_MAP = {
    **default_collate_fn_map,
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}
def collate_fn(batch, automatic_padding=False, padding_value=0):
    """
    Collate a batch with PyTorch's ``collate`` using the collation map of this module.

    :param batch: list of samples to collate
    :param automatic_padding: if True, tensors are collated with ``pad_collate_tensor_fn``, so each tensor
        field becomes a ``(padded_batch, lengths)`` pair instead of a plain stacked tensor
    :param padding_value: value used to fill padded positions when ``automatic_padding`` is True
    :return: the collated batch as returned by ``collate``
    """
    collate_fn_map = COLLATE_FN_MAP
    if automatic_padding:
        # Override the tensor handler on a per-call copy. Updating COLLATE_FN_MAP in place would leave
        # padding (and this padding_value) enabled for every later call, including those made with
        # automatic_padding=False.
        collate_fn_map = {
            **COLLATE_FN_MAP,
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        }
    return collate(batch, collate_fn_map=collate_fn_map)
# class VariableLengthSequence(torch.Tensor):
# """
# A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
# and it has an attribute called lengths, which signifies the length of each original sequence in the batch.
# """
#
# def __new__(cls, data, lengths):
# """
# Creates a new VariableLengthSequence object from the given data and lengths.
# Args:
# data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
# lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).
# Returns:
# VariableLengthSequence: A new VariableLengthSequence object.
# """
# # Check the validity of the inputs
# assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
# assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
# assert data.dim() >= 2, "data must have at least two dimensions"
# assert lengths.dim() == 1, "lengths must have one dimension"
# assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
# assert lengths.min() > 0, "lengths must be positive"
# assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"
#
# # Create a new tensor object from data
# obj = super().__new__(cls, data)
#
# # Set the lengths attribute
# obj.lengths = lengths
#
# return obj
# class VariableLengthSequence(torch.Tensor):
# _lengths = torch.Tensor()
#
# def __new__(cls, data, lengths, *args, **kwargs):
# self = super().__new__(cls, data, *args, **kwargs)
# self.lengths = lengths
# return self
#
# def clone(self, *args, **kwargs):
# return VariableLengthSequence(super().clone(*args, **kwargs), self.lengths.clone())
#
# def new_empty(self, *size):
# return VariableLengthSequence(super().new_empty(*size), self.lengths)
#
# def to(self, *args, **kwargs):
# return VariableLengthSequence(super().to(*args, **kwargs), self.lengths.to(*args, **kwargs))
#
# def __format__(self, format_spec):
# # Convert self to a string or a number here, depending on what you need
# return self.item().__format__(format_spec)
#
# @property
# def lengths(self):
# return self._lengths
#
# @lengths.setter
# def lengths(self, lengths):
# self._lengths = lengths
#
# def cpu(self, *args, **kwargs):
# return VariableLengthSequence(super().cpu(*args, **kwargs), self.lengths.cpu(*args, **kwargs))
#
# def cuda(self, *args, **kwargs):
# return VariableLengthSequence(super().cuda(*args, **kwargs), self.lengths.cuda(*args, **kwargs))
#
# def pin_memory(self):
# return VariableLengthSequence(super().pin_memory(), self.lengths.pin_memory())
#
# def share_memory_(self):
# super().share_memory_()
# self.lengths.share_memory_()
# return self
#
# def detach_(self, *args, **kwargs):
# super().detach_(*args, **kwargs)
# self.lengths.detach_(*args, **kwargs)
# return self
#
# def detach(self, *args, **kwargs):
# return VariableLengthSequence(super().detach(*args, **kwargs), self.lengths.detach(*args, **kwargs))
#
# def record_stream(self, *args, **kwargs):
# super().record_stream(*args, **kwargs)
# self.lengths.record_stream(*args, **kwargs)
# return self
# @classmethod
# def __torch_function__(cls, func, types, args=(), kwargs=None):
# return super().__torch_function__(func, types, args, kwargs) \
# if cls.lengths is not None else torch.Tensor.__torch_function__(func, types, args, kwargs)