from itertools import chain
from typing import List

import torch
from datasets import IterableDataset

from .prompt_tokenizers import PromptTokenizingStrategy


# We want this to be a wrapper for an existing dataset that we have loaded.
# Let's use the concept of middlewares to wrap each dataset, for example:
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# Let's also check to ensure we don't truncate an item in the middle; we'll use
# the collators later on to pad the datasets. A usage sketch lives at the bottom
# of this file.
class TokenizedPromptDataset(IterableDataset):
def __init__(
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
):
self.prompt_tokenizer = prompt_tokenizer
self.dataset = dataset

    def __iter__(self):
        # Tokenize prompts lazily as the wrapped dataset is iterated, rather than
        # stopping after a single example.
        for example in self.dataset:
            yield self.prompt_tokenizer.tokenize_prompt(example)


class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
seq_length (int): Length of token sequences to return.
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
"""
def __init__(
self,
tokenizer,
datasets,
infinite=False,
seq_length=2048,
num_of_sequences=1024,
chars_per_token=3.6,
):
self.tokenizer = tokenizer
        # Packed examples are separated by the tokenizer's EOS token.
        self.concat_token_id = tokenizer.eos_token_id
self.datasets: List[IterableDataset] = datasets
self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        # Chain the wrapped datasets so we iterate over individual examples
        # rather than over the dataset objects themselves.
        iterator = chain.from_iterable(self.datasets)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.max_buffer_size:
break
try:
buffer.append(next(iterator))
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
                        iterator = chain.from_iterable(self.datasets)
else:
more_examples = False
break
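            # Tokenize the buffered raw-text examples in one batch and join them into a
            # single stream of token ids separated by the concat (EOS) token.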
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
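            # Slice the token stream into fixed-length chunks; a short trailing chunk is
            # dropped rather than padded (the collators handle any padding later).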
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) == self.seq_length:
self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                        # Packed chunks contain no padding, so the attention mask is all ones.
                        "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
                    }
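

# The block below is a minimal usage sketch, not part of the library. It assumes the
# `transformers` package is installed and the "gpt2" tokenizer can be loaded; ToyStrategy
# is a hypothetical stand-in for a concrete PromptTokenizingStrategy, and the text
# examples are placeholders rather than real training data.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")

    class ToyStrategy:
        """Duck-typed stand-in exposing the same tokenize_prompt() hook."""

        def tokenize_prompt(self, prompt):
            return demo_tokenizer(prompt["text"])

    # Wrap a tiny in-memory dataset and pull one tokenized prompt from it.
    tokenized = TokenizedPromptDataset(ToyStrategy(), [{"text": "hello world"}])
    print(next(iter(tokenized))["input_ids"])

    # ConstantLengthDataset, as written above, re-tokenizes raw text examples and packs
    # them into fixed-length chunks separated by the EOS token.
    packed = ConstantLengthDataset(
        demo_tokenizer,
        [["hello world"] * 64],
        seq_length=32,
        num_of_sequences=4,
    )
    for chunk in packed:
        print(chunk["input_ids"].shape)  # torch.Size([32])
        break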