# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

import datasets
import torch

from .standard_adapter import _batch_pad_iterable

DEF_TARGET_COLUMN = "target"


def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
    """Prepare a Hugging Face dataset and return it with a per-sample mapping function.

    Supported ``hf_kwargs``:
        target_column: name of the column holding the target sequence (default: "target").
        meta_columns: iterable of column names to pass through as metadata (default: ()).
    """
    target_col = hf_kwargs.get("target_column", DEF_TARGET_COLUMN)
    meta_columns = hf_kwargs.get("meta_columns", ())

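    # Keep only the target column and the requested meta columns; cast the target
    # to float32 sequences so that with_format("torch") yields tensors per sample.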
    columns_to_pass = [target_col] + list(meta_columns)
    remove_cols = [col for col in dataset.column_names if col not in columns_to_pass]
    dataset = (
        dataset.with_format("torch")
        .remove_columns(remove_cols)
        .cast_column(target_col, datasets.Sequence(datasets.Value("float32")))
    )

    def yield_batch_tuples(sample: dict) -> tuple[torch.Tensor, dict]:
        """Map a single sample to a ``(context_tensor, meta_dict)`` tuple."""
        context_data = sample[target_col]
        # Collapse singleton dimensions (e.g. a [1, T] target) to a 1D series.
        if context_data.ndim > 1:
            context_data = context_data.squeeze()
        assert context_data.ndim == 1, f"Expected a 1D target sequence, got shape {tuple(context_data.shape)}"
        meta = {k: sample[k] for k in meta_columns if k in sample}
        meta["length"] = len(context_data)
        return context_data, meta

    return dataset, yield_batch_tuples


def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
    """Return padded batches of ``(context, meta)`` tuples built from ``hf_dataset``."""
    dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
    return _batch_pad_iterable(map(map_func, dataset), batch_size)
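

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the column values and batch size
    # below are assumptions, not part of the adapter API, and _batch_pad_iterable
    # is assumed to return an iterable of padded batches). Run as a module,
    # e.g. ``python -m <package>.<this_module>``, so the relative import resolves.
    demo = datasets.Dataset.from_dict(
        {
            "target": [[1.0, 2.0, 3.0], [4.0, 5.0]],
            "item_id": ["series_a", "series_b"],
        }
    )
    for batch in get_hfdata_batches(demo, batch_size=2, meta_columns=("item_id",)):
        print(batch)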