libokj's picture
Upload 358 files
05ca42f
raw
history blame
4.53 kB
"""
Define collate functions for new data types here
"""
from functools import partial
from itertools import chain
import dgl
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate
import torch_geometric
def collate_pyg_fn(batch, collate_fn_map=None):
    """
    Merge a list of PyTorch Geometric data objects into one ``torch_geometric.data.Batch``.

    :param batch: list of PyG data objects to merge.
    :param collate_fn_map: unused; accepted so the function matches the signature expected of
        entries in a ``collate_fn_map``.
    :return: the result of ``torch_geometric.data.Batch.from_data_list(batch)``.
    """
    batch_cls = torch_geometric.data.Batch
    merged = batch_cls.from_data_list(batch)
    return merged
def collate_dgl_fn(batch, collate_fn_map=None):
    """
    Combine a list of DGL graphs into a single batched graph.

    :param batch: list of DGL graphs.
    :param collate_fn_map: unused; accepted so the function matches the signature expected of
        entries in a ``collate_fn_map``.
    :return: the result of ``dgl.batch(batch)``.
    """
    batched_graph = dgl.batch(batch)
    return batched_graph
def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Collate a list of Tensors into one batch Tensor, padding them to a common shape when needed.

    Similar to pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True),
    but additionally supports Tensors that differ in more than their first dimension, e.g. square
    Tensors of size ``(L x L x ...)``.

    The strategy depends on how the shapes in ``batch`` differ:

    * all shapes identical -> stacked as-is with ``collate_tensor_fn``;
    * only the first dimension differs -> ``pad_sequence(..., batch_first=True)``;
    * any other difference -> every dimension is right-padded to the per-dimension maximum
      over the batch, then the results are stacked.

    :param batch: list of Tensors with the same number of dimensions.
    :param padding_value: value written into padded positions.
    :param collate_fn_map: unused; accepted so the function matches the signature expected of
        entries in a ``collate_fn_map``.
    :return: ``(padded_batch, lengths)``, where ``lengths`` is a 1-D Tensor holding each input's
        original size along dimension 0. A batch of 0-d (scalar) Tensors has no dimension to pad
        or measure, so it is only stacked and returned without lengths.
    """
    # Scalars would crash on tensor.size(0); stack them as plain values instead.
    if batch and batch[0].dim() == 0:
        return collate_tensor_fn(batch)

    lengths = torch.as_tensor([tensor.size(0) for tensor in batch])
    shapes = [tensor.shape for tensor in batch]

    if all(shape == shapes[0] for shape in shapes[1:]):
        # Every Tensor already has the same shape: nothing to pad.
        padded_batch = collate_tensor_fn(batch)
    elif all(shape[1:] == shapes[0][1:] for shape in shapes[1:]):
        # Only the leading (sequence) dimension differs, which is exactly what pad_sequence handles.
        padded_batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
    else:
        # Shapes differ beyond dim 0 (e.g. square L x L Tensors, or equal lengths but different
        # trailing sizes): pad every dimension up to the largest size seen in the batch.
        max_sizes = [max(tensor.size(dim) for tensor in batch) for dim in range(batch[0].dim())]
        padded_batch = collate_tensor_fn([
            torch.nn.functional.pad(
                tensor,
                # F.pad takes (left, right) pairs starting from the LAST dimension.
                tuple(chain.from_iterable(
                    (0, max_sizes[dim] - tensor.size(dim)) for dim in reversed(range(tensor.dim()))
                )),
                mode='constant',
                value=padding_value,
            )
            for tensor in batch
        ])

    # Return the padded batch tensor and the lengths
    return padded_batch, lengths
# Extend PyTorch's default collation map with the graph collate functions defined in this module.
# Unpacking into a new dict leaves torch's own default_collate_fn_map untouched.
COLLATE_FN_MAP = {
    **default_collate_fn_map,
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}
def collate_fn(batch, automatic_padding=False, padding_value=0):
    """
    Collate a list of samples into a batch using ``COLLATE_FN_MAP`` (PyTorch defaults plus PyG/DGL graphs).

    :param batch: list of samples to collate.
    :param automatic_padding: if True, Tensors are collated with ``pad_collate_tensor_fn`` so that
        Tensors of different sizes can be batched together; otherwise PyTorch's default Tensor
        collation is used.
    :param padding_value: fill value passed to ``pad_collate_tensor_fn`` when ``automatic_padding`` is True.
    :return: the collated batch.
    """
    collate_fn_map = COLLATE_FN_MAP
    if automatic_padding:
        # Build a per-call map instead of updating COLLATE_FN_MAP in place. Mutating the shared
        # module-level map would make padding, with this padding_value, stick for every later call,
        # including calls with automatic_padding=False.
        collate_fn_map = COLLATE_FN_MAP | {
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        }
    return collate(batch, collate_fn_map=collate_fn_map)
class VariableLengthSequence(torch.Tensor):
    """
    A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
    and it has an attribute called lengths, which signifies the length of each original sequence in the batch.

    NOTE(review): ``lengths`` is set only by ``__new__`` (and therefore on unpickling). Tensors returned by torch
    operations on an instance presumably do not carry it; confirm before relying on ``lengths`` downstream.
    """

    def __new__(cls, data, lengths):
        """
        Creates a new VariableLengthSequence object from the given data and lengths.

        Args:
            data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
            lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).
        Returns:
            VariableLengthSequence: A new VariableLengthSequence object.
        Raises:
            TypeError: If data or lengths is not a torch.Tensor.
            ValueError: If data has fewer than two dimensions, lengths is not 1-D, their batch sizes differ,
                or any length is non-positive or larger than ``data.size(1)``.
        """
        # Validate with explicit exceptions rather than ``assert``, which is stripped under ``python -O``.
        if not isinstance(data, torch.Tensor):
            raise TypeError("data must be a torch.Tensor")
        if not isinstance(lengths, torch.Tensor):
            raise TypeError("lengths must be a torch.Tensor")
        if data.dim() < 2:
            raise ValueError("data must have at least two dimensions")
        if lengths.dim() != 1:
            raise ValueError("lengths must have one dimension")
        if data.size(0) != lengths.size(0):
            raise ValueError("data and lengths must have the same batch size")
        # min()/max() are undefined on an empty tensor, so only check values for a non-empty batch.
        if lengths.numel() > 0:
            if not bool(lengths.min() > 0):
                raise ValueError("lengths must be positive")
            if bool(lengths.max() > data.size(1)):
                raise ValueError("lengths must not exceed the max length of data")
        # Create a new tensor object from data
        obj = super().__new__(cls, data)
        # Set the lengths attribute
        obj.lengths = lengths
        return obj

    def __repr__(self, *, tensor_contents=None):
        """
        Returns a string representation of the VariableLengthSequence object.
        """
        # Format a plain, detached torch.Tensor view of the values. ``self.data`` on a subclass is expected to be
        # re-wrapped by the default ``__torch_function__`` into a VariableLengthSequence without ``lengths``,
        # whose own repr would then fail. getattr keeps repr usable on instances that lack ``lengths``.
        values = self.as_subclass(torch.Tensor).detach()
        return f"VariableLengthSequence(data={values}, lengths={getattr(self, 'lengths', None)})"

    def __reduce_ex__(self, proto):
        """
        Enables pickling of the VariableLengthSequence object by rebuilding it from its values and lengths.
        """
        # Pickle a plain torch.Tensor rather than a subclass instance lacking ``lengths``, so pickling does not
        # recurse back into this method and unpickling re-runs __new__ with (data, lengths).
        return type(self), (self.as_subclass(torch.Tensor).detach(), self.lengths)