# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
import itertools
from collections.abc import Iterable, Iterator, Sequence
from typing import Union

import numpy as np
import torch

ContextType = Union[
    torch.Tensor,
    np.ndarray,
    list[torch.Tensor],
    list[np.ndarray],
]


def _batched_slice(full_batch, full_meta: list[dict] | None, batch_size: int) -> Iterator[tuple[Sequence, list[dict]]]:
    """Slice an indexable batch into chunks of at most `batch_size`, filling missing metadata with empty dicts."""
    if len(full_batch) <= batch_size:
        yield full_batch, (full_meta if full_meta is not None else [{} for _ in range(len(full_batch))])
    else:
        for i in range(0, len(full_batch), batch_size):
            batch = full_batch[i : i + batch_size]
            yield batch, (full_meta[i : i + batch_size] if full_meta is not None else [{} for _ in range(len(batch))])
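# A minimal illustration (hypothetical values, not from the original source):
# _batched_slice(torch.zeros(5, 8), None, batch_size=2) yields row chunks of
# shapes (2, 8), (2, 8), and (1, 8), each paired with a list of empty dicts.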


def _batched(iterable: Iterable, n: int):
    """Yield successive tuples of at most `n` items from `iterable`."""
    it = iter(iterable)
    while batch := tuple(itertools.islice(it, n)):
        yield batch
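# A minimal illustration (not part of the original module):
# list(_batched(range(5), 2)) == [(0, 1), (2, 3), (4,)]; the trailing tuple
# may be shorter than n, mirroring itertools.batched from Python 3.12+.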


def _batch_pad_iterable(iterable: Iterable[tuple[torch.Tensor, dict]], batch_size: int):
    """Group `(sample, meta)` pairs into batches and left-pad samples with NaN to a common length."""
    for batch in _batched(iterable, batch_size):
        max_len = max(len(el[0]) for el in batch)
        padded_batch = []
        meta = []
        for el in batch:
            sample = el[0]
            assert isinstance(sample, torch.Tensor)
            assert sample.ndim == 1
            assert len(sample) > 0, "Each sample needs to have a length > 0"
            # Left-pad with NaN so every sample in the batch has length max_len.
            padding = torch.full(size=(max_len - len(sample),), fill_value=torch.nan, device=sample.device)
            padded_batch.append(torch.cat((padding, sample)))
            meta.append(el[1])
        yield torch.stack(padded_batch), meta
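# A minimal illustration of the padding behavior (hypothetical values):
#   pairs = [(torch.tensor([1.0, 2.0]), {}), (torch.tensor([3.0, 4.0, 5.0]), {})]
#   next(_batch_pad_iterable(iter(pairs), batch_size=2))
#   -> (tensor([[nan, 1., 2.], [3., 4., 5.]]), [{}, {}])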


def get_batches(context: ContextType, batch_size: int):
    """Normalize a supported context into an iterator of `(tensor_batch, meta)` pairs."""
    batches = None
    if isinstance(context, torch.Tensor):
        if context.ndim == 1:
            context = context.unsqueeze(0)
        assert context.ndim == 2
        batches = _batched_slice(context, None, batch_size)
    elif isinstance(context, np.ndarray):
        if context.ndim == 1:
            context = np.expand_dims(context, axis=0)
        assert context.ndim == 2
        # Convert each numpy slice to a tensor before yielding.
        batches = map(lambda x: (torch.Tensor(x[0]), x[1]), _batched_slice(context, None, batch_size))
    elif isinstance(context, (list, Iterable)):
        # Variable-length samples are converted to tensors and left-padded per batch;
        # metadata is an empty dict per sample, consistent with the other branches.
        batches = _batch_pad_iterable(map(lambda x: (torch.Tensor(x), {}), context), batch_size)
    if batches is None:
        raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
    return batches
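

# A minimal usage sketch (illustrative only, not part of the original module):
# get_batches accepts a 2D tensor/array or an iterable of 1D sequences.
if __name__ == "__main__":
    # 2D tensor: sliced into batches without padding.
    for batch, meta in get_batches(torch.randn(5, 8), batch_size=2):
        print(batch.shape, meta)  # e.g. torch.Size([2, 8]) [{}, {}]

    # List of variable-length tensors: left-padded with NaN per batch.
    for batch, meta in get_batches([torch.randn(4), torch.randn(6)], batch_size=2):
        print(batch.shape, meta)  # torch.Size([2, 6]) [{}, {}]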