"""
Define collate functions for new data types here
"""
from functools import partial
from itertools import chain
import dgl
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate
import torch_geometric


def collate_pyg_fn(batch, collate_fn_map=None):
    """
    Collate a list of PyG data objects into a single batched graph.
    """
    return torch_geometric.data.Batch.from_data_list(batch)


def collate_dgl_fn(batch, collate_fn_map=None):
    """
    Collate a list of DGL graphs into a single batched graph.
    """
    return dgl.batch(batch)
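
# A minimal usage sketch for the two graph collate functions (toy graphs; not part of the original file):
#
#     d1 = torch_geometric.data.Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 1], [1, 2]]))
#     d2 = torch_geometric.data.Data(x=torch.randn(5, 8), edge_index=torch.tensor([[0, 2], [3, 4]]))
#     pyg_batch = collate_pyg_fn([d1, d2])        # pyg_batch.num_graphs == 2
#
#     g1 = dgl.rand_graph(3, 4)                   # 3 nodes, 4 edges
#     g2 = dgl.rand_graph(5, 2)
#     dgl_batch = collate_dgl_fn([g1, g2])        # dgl_batch.batch_size == 2, dgl_batch.num_nodes() == 8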


def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Similar to ``pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True)``,
    but additionally supports padding a list of square Tensors of size ``(L x L x ...)``.

    :param batch: list of Tensors to collate; sizes may differ in the first and, possibly, later dimensions
    :param padding_value: value used to fill the padded positions
    :param collate_fn_map: unused here; accepted for compatibility with ``torch.utils.data._utils.collate.collate``
    :return: tuple ``(padded_batch, lengths)``, where ``lengths`` holds the original first-dimension size of each Tensor
    """
    lengths = [tensor.size(0) for tensor in batch]
    if any(length != lengths[0] for length in lengths[1:]):
        try:
            # Only the first dimension differs and all trailing dimensions match: pad_sequence is sufficient
            batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
        except RuntimeError:
            # Trailing dimensions differ as well (e.g. square L x L Tensors):
            # find the max size of each dimension across the batch
            max_sizes = [max(tensor.size(dim) for tensor in batch) for dim in range(batch[0].dim())]
            # Pad every dimension of every tensor up to the respective max size with padding_value;
            # F.pad expects pad widths ordered from the last dimension backwards, hence the [::-1]
            batch = collate_tensor_fn([
                torch.nn.functional.pad(
                    tensor,
                    tuple(chain.from_iterable(
                        [(0, max_sizes[dim] - tensor.size(dim)) for dim in range(tensor.dim())][::-1])),
                    mode='constant', value=padding_value,
                ) for tensor in batch
            ])
    else:
        batch = collate_tensor_fn(batch)
    lengths = torch.as_tensor(lengths)
    # Return the padded batch tensor together with the original lengths
    return batch, lengths
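
# A minimal sketch of what pad_collate_tensor_fn produces (illustrative values; not part of the original file):
#
#     seqs = [torch.ones(2), torch.ones(4)]
#     padded, lengths = pad_collate_tensor_fn(seqs)    # padded.shape == (2, 4), lengths == tensor([2, 4])
#
#     mats = [torch.ones(2, 2), torch.ones(3, 3)]      # square L x L case, where pad_sequence would raise
#     padded, lengths = pad_collate_tensor_fn(mats)    # padded.shape == (2, 3, 3), lengths == tensor([2, 3])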


# Join custom collate functions with the default collation map of PyTorch
COLLATE_FN_MAP = default_collate_fn_map | {
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}
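
# A minimal sketch of how ``collate`` dispatches through this map on nested structures
# (toy graphs and labels; not part of the original file):
#
#     batch = [
#         {'graph': dgl.rand_graph(3, 4), 'label': torch.tensor(0)},
#         {'graph': dgl.rand_graph(5, 2), 'label': torch.tensor(1)},
#     ]
#     out = collate(batch, collate_fn_map=COLLATE_FN_MAP)
#     # out['graph'] is a single batched DGLGraph, out['label'] is a Tensor of shape (2,)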


def collate_fn(batch, automatic_padding=False, padding_value=0):
    """
    Drop-in replacement for the default DataLoader collate function that also handles PyG graphs,
    DGL graphs and, optionally, automatic padding of variable-length Tensors.
    """
    collate_fn_map = COLLATE_FN_MAP
    if automatic_padding:
        # Use a per-call copy so enabling padding here does not mutate the module-level map
        collate_fn_map = COLLATE_FN_MAP | {
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        }
    return collate(batch, collate_fn_map=collate_fn_map)
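
# A minimal DataLoader usage sketch (``my_dataset`` is a hypothetical dataset; not part of the original file):
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(
#         my_dataset,                              # e.g. a Dataset yielding dicts of Tensors and/or graphs
#         batch_size=32,
#         collate_fn=partial(collate_fn, automatic_padding=True, padding_value=0),
#     )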


# class VariableLengthSequence(torch.Tensor):
#     """
#     A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor,
#     and it has an attribute called lengths, which signifies the length of each original sequence in the batch.
#     """
#
#     def __new__(cls, data, lengths):
#         """
#         Creates a new VariableLengthSequence object from the given data and lengths.
#         Args:
#             data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *).
#             lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,).
#         Returns:
#             VariableLengthSequence: A new VariableLengthSequence object.
#         """
#         # Check the validity of the inputs
#         assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
#         assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor"
#         assert data.dim() >= 2, "data must have at least two dimensions"
#         assert lengths.dim() == 1, "lengths must have one dimension"
#         assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size"
#         assert lengths.min() > 0, "lengths must be positive"
#         assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data"
#
#         # Create a new tensor object from data
#         obj = super().__new__(cls, data)
#
#         # Set the lengths attribute
#         obj.lengths = lengths
#
#         return obj


# class VariableLengthSequence(torch.Tensor):
#     _lengths = torch.Tensor()
#
#     def __new__(cls, data, lengths, *args, **kwargs):
#         self = super().__new__(cls, data, *args, **kwargs)
#         self.lengths = lengths
#         return self
#
#     def clone(self, *args, **kwargs):
#         return VariableLengthSequence(super().clone(*args, **kwargs), self.lengths.clone())
#
#     def new_empty(self, *size):
#         return VariableLengthSequence(super().new_empty(*size), self.lengths)
#
#     def to(self, *args, **kwargs):
#         return VariableLengthSequence(super().to(*args, **kwargs), self.lengths.to(*args, **kwargs))
#
#     def __format__(self, format_spec):
#         # Convert self to a string or a number here, depending on what you need
#         return self.item().__format__(format_spec)
#
#     @property
#     def lengths(self):
#         return self._lengths
#
#     @lengths.setter
#     def lengths(self, lengths):
#         self._lengths = lengths
#
#     def cpu(self, *args, **kwargs):
#         return VariableLengthSequence(super().cpu(*args, **kwargs), self.lengths.cpu(*args, **kwargs))
#
#     def cuda(self, *args, **kwargs):
#         return VariableLengthSequence(super().cuda(*args, **kwargs), self.lengths.cuda(*args, **kwargs))
#
#     def pin_memory(self):
#         return VariableLengthSequence(super().pin_memory(), self.lengths.pin_memory())
#
#     def share_memory_(self):
#         super().share_memory_()
#         self.lengths.share_memory_()
#         return self
#
#     def detach_(self, *args, **kwargs):
#         super().detach_(*args, **kwargs)
#         self.lengths.detach_(*args, **kwargs)
#         return self
#
#     def detach(self, *args, **kwargs):
#         return VariableLengthSequence(super().detach(*args, **kwargs), self.lengths.detach(*args, **kwargs))
#
#     def record_stream(self, *args, **kwargs):
#         super().record_stream(*args, **kwargs)
#         self.lengths.record_stream(*args, **kwargs)
#         return self
#
#     @classmethod
#     def __torch_function__(cls, func, types, args=(), kwargs=None):
#         return super().__torch_function__(func, types, args, kwargs) \
#             if cls.lengths is not None else torch.Tensor.__torch_function__(func, types, args, kwargs)