# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
import itertools
from collections.abc import Iterable, Iterator, Sequence
from typing import Union

import numpy as np
import torch
# Input types accepted by `get_batches`: a single (optionally already batched)
# tensor/array, or a list of variable-length 1D tensors/arrays.
ContextType = Union[
    torch.Tensor,
    np.ndarray,
    list[torch.Tensor],
    list[np.ndarray],
]
def _batched_slice(full_batch, full_meta: list[dict] | None, batch_size: int) -> Iterator[tuple[Sequence, list[dict]]]: | |
if len(full_batch) <= batch_size: | |
yield full_batch, full_meta if full_meta is not None else [{} for _ in range(len(full_batch))] | |
else: | |
for i in range(0, len(full_batch), batch_size): | |
batch = full_batch[i : i + batch_size] | |
yield batch, (full_meta[i : i + batch_size] if full_meta is not None else [{} for _ in range(len(batch))]) | |
def _batched(iterable: Iterable, n: int): | |
it = iter(iterable) | |
while batch := tuple(itertools.islice(it, n)): | |
yield batch | |
def _batch_pad_iterable(iterable: Iterable[tuple[torch.Tensor, dict]], batch_size: int):
    """Group ``(sample, meta)`` pairs into NaN-left-padded batches.

    Each sample must be a non-empty 1D tensor. Within a batch, every sample is
    left-padded with NaN up to the length of the longest sample, then the batch
    is stacked into a ``(batch, max_len)`` tensor.

    Args:
        iterable: pairs of (1D tensor, per-sample metadata dict).
        batch_size: maximum number of samples per batch.

    Yields:
        Tuples of (stacked padded tensor, list of per-sample meta dicts).
    """
    for batch in _batched(iterable, batch_size):
        max_len = max(len(sample) for sample, _ in batch)
        padded, meta = [], []
        for sample, sample_meta in batch:
            assert isinstance(sample, torch.Tensor)
            assert sample.ndim == 1
            assert len(sample) > 0, "Each sample needs to have a length > 0"
            # Match the sample's dtype so padding cannot promote the batch
            # (e.g. float16 samples stay float16 instead of becoming float32).
            padding = torch.full(
                size=(max_len - len(sample),),
                fill_value=torch.nan,
                dtype=sample.dtype,
                device=sample.device,
            )
            padded.append(torch.cat((padding, sample)))
            meta.append(sample_meta)
        yield torch.stack(padded), meta
def get_batches(context: ContextType, batch_size: int):
    """Normalize ``context`` into an iterator of ``(batch_tensor, meta)`` pairs.

    Dispatch by input type:
      * ``torch.Tensor``: a 1D tensor is treated as a single sample; the
        (sample, time) tensor is sliced into batches of ``batch_size`` rows.
      * ``np.ndarray``: same as above, with each slice converted to a
        ``torch.Tensor`` (float32).
      * any other iterable of 1D sequences: each element is converted to a
        tensor and batches are NaN-left-padded to a common length.

    Args:
        context: input time series, see ``ContextType``.
        batch_size: maximum number of samples per batch.

    Raises:
        ValueError: if ``context`` is not one of the supported types.
    """
    if isinstance(context, torch.Tensor):
        if context.ndim == 1:
            context = context.unsqueeze(0)
        assert context.ndim == 2
        return _batched_slice(context, None, batch_size)
    if isinstance(context, np.ndarray):
        if context.ndim == 1:
            context = np.expand_dims(context, axis=0)
        assert context.ndim == 2
        return map(
            lambda x: (torch.Tensor(x[0]), x[1]),
            _batched_slice(context, None, batch_size),
        )
    if isinstance(context, Iterable):
        # Use an empty dict as the per-sample meta so the iterable path is
        # consistent with the tensor/ndarray paths (was None, which leaked
        # `[None, ...]` meta lists out of `_batch_pad_iterable`).
        return _batch_pad_iterable(((torch.Tensor(x), {}) for x in context), batch_size)
    raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")