# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
import datasets
import torch

from .standard_adapter import _batch_pad_iterable

DEF_TARGET_COLUMN = "target"

def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
    """Prepare a HuggingFace dataset and build a per-sample map function.

    Keeps only the target column and any requested metadata columns, and casts
    the target to a float32 sequence.
    """
    target_col = hf_kwargs.get("target_column", DEF_TARGET_COLUMN)
    meta_columns = hf_kwargs.get("meta_columns", ())
    columns_to_pass = [target_col] + list(meta_columns)
    # Drop every column that is neither the target nor requested metadata.
    remove_cols = [col for col in dataset.column_names if col not in columns_to_pass]
    dataset = (
        dataset.with_format("torch")
        .remove_columns(remove_cols)
        .cast_column(target_col, datasets.Sequence(datasets.Value("float32")))
    )

    def yield_batch_tuples(sample: dict) -> tuple[torch.Tensor, dict]:
        context_data = sample[target_col]
        # Squeeze away singleton dimensions; each sample must be a 1-D series.
        if context_data.ndim > 1:
            context_data = context_data.squeeze()
        assert context_data.ndim == 1
        # Carry the requested metadata plus the unpadded series length.
        meta = {k: sample[k] for k in meta_columns if k in sample}
        meta["length"] = len(context_data)
        return context_data, meta

    return dataset, yield_batch_tuples
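
# Note on the downstream contract (an inference from the call site below, not
# a documented guarantee): `_batch_pad_iterable` from `.standard_adapter` is
# assumed to consume the (tensor, meta) pairs produced above and to yield
# fixed-size batches in which variable-length series are padded to a common
# length; the `meta["length"]` recorded above lets consumers mask the padding.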
def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
    """Turn a HuggingFace dataset into an iterable of padded batches."""
    dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
    return _batch_pad_iterable(map(map_func, dataset), batch_size)
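
# Usage sketch (illustrative, not part of the module's API). It assumes the
# module is run inside its package so the relative import resolves, e.g.
# `python -m <package>.<this_module>`, and that `_batch_pad_iterable` yields
# one item per batch; the exact batch structure is defined in
# standard_adapter.py and may differ from what is printed here.
if __name__ == "__main__":
    # Tiny in-memory dataset with variable-length series and one extra
    # column ("unused") that _get_hf_map should drop.
    demo = datasets.Dataset.from_dict(
        {
            "target": [[1.0, 2.0, 3.0], [4.0, 5.0]],
            "item_id": ["a", "b"],
            "unused": [0, 1],
        }
    )
    for batch in get_hfdata_batches(demo, batch_size=2, meta_columns=("item_id",)):
        print(batch)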