from itertools import chain
from typing import List

import torch
from datasets import IterableDataset

from .prompt_tokenizers import PromptTokenizingStrategy


class TokenizedPromptDataset(IterableDataset):
    """Iterable dataset that applies a PromptTokenizingStrategy to each example of a wrapped dataset."""

    def __init__(
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: IterableDataset,
    ):
        self.prompt_tokenizer = prompt_tokenizer
        self.dataset = dataset

    def __iter__(self):
        # Tokenize every example in the wrapped dataset, not just the first one
        for example in self.dataset:
            yield self.prompt_tokenizer.tokenize_prompt(example)


class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant-length chunks of tokens from a stream of text examples.

    Args:
        tokenizer (Tokenizer): The tokenizer used for processing the data.
        datasets (List[IterableDataset]): Datasets yielding text examples.
        infinite (bool): If True, the iterator restarts once the datasets are exhausted; otherwise it stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in the buffer.
        chars_per_token (float): Estimated number of characters per token, used to size the text buffer.

    See the usage sketch at the end of this module.
    """

    def __init__(
        self,
        tokenizer,
        datasets,
        infinite=False,
        seq_length=2048,
        num_of_sequences=1024,
        chars_per_token=3.6,
    ):
        self.tokenizer = tokenizer
        if tokenizer.eos_token_id is None:
            raise ValueError("tokenizer must define an eos_token_id to use as the concatenation token")
        # EOS token used to join examples when packing
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length
        self.infinite = infinite
        self.current_size = 0
        # Approximate character budget for one packing pass
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        # Stream examples from all wrapped datasets as a single sequence
        iterator = chain.from_iterable(self.datasets)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.max_buffer_size:
                    break
                try:
                    # Examples are expected to be raw text strings
                    buffer.append(next(iterator))
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        # Restart from the beginning of the datasets
                        iterator = chain.from_iterable(self.datasets)
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                # Join examples with the concatenation (EOS) token
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            # Emit fixed-length chunks; a trailing remainder shorter than seq_length is dropped
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    self.current_size += 1
                    yield {
                        "input_ids": torch.LongTensor(input_ids),
                        "labels": torch.LongTensor(input_ids),
                        "attention_mask": torch.ones(len(input_ids), dtype=torch.long),
                    }
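

# A minimal usage sketch (illustrative only), assuming a HuggingFace-style
# tokenizer with an eos_token_id and iterable datasets that yield raw text
# strings; `tokenizer` and `text_streams` are hypothetical arguments, not
# names defined elsewhere in this module.
def _example_packing_loop(tokenizer, text_streams):
    """Illustrative sketch: pack streams of text into fixed-length samples."""
    packed = ConstantLengthDataset(tokenizer, text_streams, seq_length=2048)
    for sample in packed:
        # Each sample is a dict of LongTensors ("input_ids", "labels",
        # "attention_mask"), each of length seq_length.
        yield sample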