from dataclasses import dataclass
from typing import AbstractSet, Any, Callable, Collection, List, Literal, Optional, Union

from starfish.data_ingest.splitter.base_splitter import TextSplitter

class TokenTextSplitter(TextSplitter):
    """Splits text into chunks using a tokenizer, with configurable chunk size and overlap."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        chunk_size: int = 400,
        chunk_overlap: int = 20,
        **kwargs: Any,
    ) -> None:
"""Initialize the token splitter. | |
Args: | |
encoding_name: Name of the encoding to use | |
model_name: Optional model name to get encoding for | |
allowed_special: Special tokens to allow | |
disallowed_special: Special tokens to disallow | |
chunk_size: Maximum number of tokens per chunk | |
chunk_overlap: Number of overlapping tokens between chunks | |
""" | |
        super().__init__(**kwargs)
        self._tokenizer = self._get_tokenizer(encoding_name, model_name)
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap

    def _get_tokenizer(self, encoding_name: str, model_name: Optional[str]) -> Any:
        """Get a tiktoken tokenizer instance."""
        try:
            import tiktoken
        except ImportError as e:
            raise ImportError("tiktoken package required. Install with `pip install tiktoken`.") from e
        # Prefer the model-specific encoding when a model name is provided.
        return tiktoken.encoding_for_model(model_name) if model_name else tiktoken.get_encoding(encoding_name)

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks based on tokenization."""
        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=lambda t: self._tokenizer.encode(
                t,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            ),
        )
        return split_text_on_tokens(text=text, tokenizer=tokenizer)
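
# Example usage (a sketch added for illustration, not from the original
# module; assumes `tiktoken` is installed and that the `TextSplitter` base
# class requires no extra arguments; `long_document` is a placeholder):
#
#     splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=400, chunk_overlap=20)
#     chunks = splitter.split_text(long_document)
#     # Each chunk decodes from at most 400 tokens; consecutive chunks
#     # share 20 tokens of overlap.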


@dataclass
class Tokenizer:
    """Tokenizer data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk"""
    decode: Callable[[List[int]], str]
    """Function to decode a list of token ids to a string"""
    encode: Callable[[str], List[int]]
    """Function to encode a string to a list of token ids"""


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        if cur_idx == len(input_ids):
            break
        # Slide the window forward by chunk size minus overlap, so consecutive
        # chunks share `chunk_overlap` tokens. Assumes chunk_overlap is smaller
        # than tokens_per_chunk; otherwise the window never advances.
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
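

if __name__ == "__main__":
    # Minimal demonstration of the sliding-window logic (a sketch added for
    # illustration, not part of the original module; it runs provided the
    # module's own imports resolve). A toy whitespace tokenizer stands in
    # for tiktoken; `toy_encode`/`toy_decode` are hypothetical helpers.
    _vocab = {}
    _words: List[str] = []

    def toy_encode(s: str) -> List[int]:
        # Assign each distinct whitespace-separated word a fresh integer id.
        ids = []
        for word in s.split():
            if word not in _vocab:
                _vocab[word] = len(_words)
                _words.append(word)
            ids.append(_vocab[word])
        return ids

    def toy_decode(ids: List[int]) -> str:
        return " ".join(_words[i] for i in ids)

    demo_tokenizer = Tokenizer(chunk_overlap=2, tokens_per_chunk=5, decode=toy_decode, encode=toy_encode)
    sample = " ".join(f"w{i}" for i in range(12))
    for chunk in split_text_on_tokens(text=sample, tokenizer=demo_tokenizer):
        print(chunk)
    # With 12 tokens, a window of 5, and an overlap of 2 (stride 3), this
    # prints w0..w4, w3..w7, w6..w10, w9..w11: each consecutive pair of
    # chunks shares exactly 2 tokens.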