import torch
import re
import collections.abc

# `torch._six` (which used to provide `string_classes`, `int_classes`, and
# `container_abcs`) was removed from recent PyTorch releases, so define the
# equivalent alias locally.
string_classes = (str, bytes)
""" | |
Modified by Serkan Sulun | |
Filters out None samples | |
""" | |
""""Contains definitions of the methods used by the _DataLoaderIter workers to | |
collate samples fetched from dataset into Tensor(s). | |
These **needs** to be in global scope since Py2 doesn't support serializing | |
static methods. | |
""" | |
_use_shared_memory = False
r"""Whether to use shared memory in batch_collate"""

np_str_obj_array_pattern = re.compile(r'[SaUO]')

error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"

numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}
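# Illustrative note (not part of the original module): each entry of
# numpy_type_map maps a numpy dtype name to the matching torch tensor
# constructor, so a batch of numpy scalars can be collated directly, e.g.
#   numpy_type_map['float32']([1.0, 2.0])  ->  tensor([1., 2.])  (FloatTensor)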
def filter_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size,
    after filtering out None samples."""
    if isinstance(batch, (list, tuple)):
        # Drop samples that the dataset returned as None.
        batch = [i for i in batch if i is not None]
        if batch:
            elem_type = type(batch[0])
            if isinstance(batch[0], torch.Tensor):
                out = None
                if _use_shared_memory:
                    # If we're in a background process, concatenate directly
                    # into a shared-memory tensor to avoid an extra copy.
                    numel = sum(x.numel() for x in batch)
                    storage = batch[0].storage()._new_shared(numel)
                    out = batch[0].new(storage)
                return torch.stack(batch, 0, out=out)
            elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                    and elem_type.__name__ != 'string_':
                elem = batch[0]
                if elem_type.__name__ == 'ndarray':
                    # Arrays of string classes and objects cannot be collated.
                    if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                        raise TypeError(error_msg_fmt.format(elem.dtype))
                    return filter_collate([torch.from_numpy(b) for b in batch])
                if elem.shape == ():  # scalars
                    py_type = float if elem.dtype.name.startswith('float') else int
                    return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
            elif isinstance(batch[0], float):
                return torch.tensor(batch, dtype=torch.float64)
            elif isinstance(batch[0], int):
                return torch.tensor(batch)
            elif isinstance(batch[0], string_classes):
                return batch
            elif isinstance(batch[0], collections.abc.Mapping):
                return {key: filter_collate([d[key] for d in batch]) for key in batch[0]}
            elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
                return type(batch[0])(*(filter_collate(samples) for samples in zip(*batch)))
            elif isinstance(batch[0], collections.abc.Sequence):
                transposed = zip(*batch)
                return [filter_collate(samples) for samples in transposed]
            raise TypeError(error_msg_fmt.format(type(batch[0])))
        else:
            # All samples in the batch were None; return the (empty) batch.
            return batch
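
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): pass filter_collate
# as the `collate_fn` of a standard torch.utils.data.DataLoader so that
# samples returned as None by the dataset are dropped from each batch.
# `NoneSkippingDataset` below is a hypothetical toy dataset invented purely
# for illustration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import Dataset, DataLoader

    class NoneSkippingDataset(Dataset):
        """Toy dataset that returns None for every third index."""

        def __len__(self):
            return 10

        def __getitem__(self, index):
            if index % 3 == 0:  # pretend these samples are corrupt
                return None
            return torch.ones(4) * index

    loader = DataLoader(NoneSkippingDataset(), batch_size=5,
                        collate_fn=filter_collate)
    for batch in loader:
        # Each batch stacks only the non-None samples, so its first
        # dimension may be smaller than batch_size.
        print(batch.shape)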