"""
Define collate functions for new data types here
"""
from functools import partial
from itertools import chain

import dgl
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate

import torch_geometric


def collate_pyg_fn(batch, collate_fn_map=None):
    """
    Collate a list of PyG graphs into a single disjoint torch_geometric Batch.
    """
    return torch_geometric.data.Batch.from_data_list(batch)
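

# Hedged usage sketch (not part of the original module): collating two small
# PyG graphs; the feature sizes and edges below are illustrative.
def _example_collate_pyg():
    g1 = torch_geometric.data.Data(x=torch.randn(3, 4),
                                   edge_index=torch.tensor([[0, 1], [1, 2]]))
    g2 = torch_geometric.data.Data(x=torch.randn(2, 4),
                                   edge_index=torch.tensor([[0], [1]]))
    batch = collate_pyg_fn([g1, g2])
    # batch.x is (5, 4); batch.batch maps each node to its source graph
    return batch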


def collate_dgl_fn(batch, collate_fn_map=None):
    """
    Collate a list of DGL graphs into a single batched DGLGraph.
    """
    return dgl.batch(batch)
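

# Hedged usage sketch (not part of the original module): batching two small
# DGL graphs into one disjoint graph; the edges below are illustrative.
def _example_collate_dgl():
    g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])), num_nodes=3)
    g2 = dgl.graph((torch.tensor([0]), torch.tensor([1])), num_nodes=2)
    batched = collate_dgl_fn([g1, g2])
    # batched.batch_size == 2; g2's node IDs are shifted to follow g1's
    return batched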


def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Similar to pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True),
    but additionally supports padding a list of square Tensors of size ``(L x L x ...)``.

    :param batch: list of Tensors whose first (and possibly further) dimensions vary in size
    :param padding_value: value used to fill the padded positions
    :param collate_fn_map: unused; kept for compatibility with the collate dispatch signature
    :return: padded_batch, lengths
    """
    lengths = [tensor.size(0) for tensor in batch]
    if any(element != lengths[0] for element in lengths[1:]):
        try:
            # Trailing dimensions match across the batch, so pad_sequence can pad dim 0 alone
            batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
        except RuntimeError:
            # Trailing dimensions differ too (e.g. square L x L matrices);
            # find the max size of each dimension across the batch
            max_sizes = [max(tensor.size(dim) for tensor in batch) for dim in range(batch[0].dim())]
            # Pad every dimension of every tensor up to the respective max size;
            # F.pad expects pad widths starting from the last dimension, hence the reversal
            batch = collate_tensor_fn([
                torch.nn.functional.pad(
                    tensor,
                    tuple(chain.from_iterable(
                        [(0, max_sizes[dim] - tensor.size(dim)) for dim in range(tensor.dim())][::-1]
                    )),
                    mode='constant', value=padding_value
                ) for tensor in batch
            ])
    else:
        batch = collate_tensor_fn(batch)
    lengths = torch.as_tensor(lengths)
    # Return the padded batch tensor and the original sequence lengths
    return batch, lengths
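

# Hedged usage sketch (not part of the original module): the two padding paths.
def _example_pad_collate():
    # Variable-length 1-D sequences: trailing dims match, so the pad_sequence path runs
    seqs = [torch.ones(2), torch.ones(4)]
    padded, lengths = pad_collate_tensor_fn(seqs)
    # padded has shape (2, 4); lengths == tensor([2, 4])

    # Square matrices of different sizes: pad_sequence raises RuntimeError,
    # so every dimension is padded to its batch-wide max instead
    mats = [torch.ones(2, 2), torch.ones(3, 3)]
    padded_mats, mat_lengths = pad_collate_tensor_fn(mats)
    # padded_mats has shape (2, 3, 3); mat_lengths == tensor([2, 3])
    return padded, lengths, padded_mats, mat_lengths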


# Join the custom collate functions with the default collation map of PyTorch
COLLATE_FN_MAP = default_collate_fn_map | {
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}


def collate_fn(batch, automatic_padding=False, padding_value=0):
    collate_fn_map = COLLATE_FN_MAP
    if automatic_padding:
        # Build a per-call copy of the map so the module-level COLLATE_FN_MAP
        # is not permanently mutated by a single call with automatic_padding
        collate_fn_map = COLLATE_FN_MAP | {
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        }
    return collate(batch, collate_fn_map=collate_fn_map)
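

# Hedged usage sketch (not part of the original module): wiring collate_fn
# into a DataLoader so variable-length tensors are padded per batch. The toy
# dataset below is illustrative.
def _example_dataloader():
    from torch.utils.data import DataLoader
    dataset = [torch.ones(n) for n in (2, 3, 5)]
    loader = DataLoader(dataset, batch_size=3,
                        collate_fn=partial(collate_fn, automatic_padding=True))
    padded, lengths = next(iter(loader))
    # padded has shape (3, 5); lengths == tensor([2, 3, 5])
    return padded, lengths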


class VariableLengthSequence(torch.Tensor):
    """
    A custom PyTorch Tensor class similar to PackedSequence, except that it can be used
    directly as a batch tensor and carries a ``lengths`` attribute giving the length of
    each original sequence in the batch.
    """

    def __new__(cls, data, lengths):
        """
        Creates a new VariableLengthSequence object from the given data and lengths.

        Args:
            data (torch.Tensor): The batch-collated tensor of shape (batch_size, max_length, *).
            lengths (torch.Tensor): The lengths of the original sequences, of shape (batch_size,).

        Returns:
            VariableLengthSequence: A new VariableLengthSequence object.
        """
        # Check the validity of the inputs
        assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
        assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
        assert data.dim() >= 2, "data must have at least two dimensions"
        assert lengths.dim() == 1, "lengths must have one dimension"
        assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
        assert lengths.min() > 0, "lengths must be positive"
        assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"
        # Create a new tensor object from data
        obj = super().__new__(cls, data)
        # Attach the per-sequence lengths
        obj.lengths = lengths
        return obj

    def __repr__(self, *, tensor_contents=None):
        """
        Returns a string representation of the VariableLengthSequence object.
        """
        return f"VariableLengthSequence(data={self.data}, lengths={self.lengths})"

    def __reduce_ex__(self, proto):
        """
        Enables pickling of the VariableLengthSequence object.
        """
        return type(self), (self.data, self.lengths)
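

# Hedged usage sketch (not part of the original module): pairing the padded
# batch from pad_collate_tensor_fn with its lengths in one tensor-like object.
def _example_variable_length_sequence():
    padded, lengths = pad_collate_tensor_fn([torch.ones(2), torch.ones(4)])
    seq = VariableLengthSequence(padded, lengths)
    # seq behaves like a (2, 4) tensor and additionally exposes seq.lengths
    return seq, seq.lengths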