from dataclasses import dataclass
from typing import Any, Callable, Collection, List, Literal, Optional, Union, AbstractSet

from starfish.data_ingest.splitter.base_splitter import TextSplitter


class TokenTextSplitter(TextSplitter):
"""Splits text into chunks using a tokenizer, with configurable chunk size and overlap."""

    def __init__(
self,
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
chunk_size: int = 400,
chunk_overlap: int = 20,
**kwargs: Any,
) -> None:
"""Initialize the token splitter.
Args:
encoding_name: Name of the encoding to use
model_name: Optional model name to get encoding for
allowed_special: Special tokens to allow
disallowed_special: Special tokens to disallow
chunk_size: Maximum number of tokens per chunk
chunk_overlap: Number of overlapping tokens between chunks
"""
super().__init__(**kwargs)
self._tokenizer = self._get_tokenizer(encoding_name, model_name)
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap

    def _get_tokenizer(self, encoding_name: str, model_name: Optional[str]) -> Any:
"""Get tokenizer instance."""
try:
import tiktoken
return tiktoken.encoding_for_model(model_name) if model_name else tiktoken.get_encoding(encoding_name)
        except ImportError as exc:
            raise ImportError("tiktoken package required. Install with `pip install tiktoken`.") from exc

    def split_text(self, text: str) -> List[str]:
"""Split text into chunks based on tokenization."""
tokenizer = Tokenizer(
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self._chunk_size,
decode=self._tokenizer.decode,
encode=lambda t: self._tokenizer.encode(
t,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
),
)
return split_text_on_tokens(text=text, tokenizer=tokenizer)


@dataclass(frozen=True)
class Tokenizer:
"""Tokenizer data class."""
chunk_overlap: int
"""Overlap in tokens between chunks"""
tokens_per_chunk: int
"""Maximum number of tokens per chunk"""
    decode: Callable[[List[int]], str]
    """Function to decode a list of token ids to a string"""
    encode: Callable[[str], List[int]]
    """Function to encode a string to a list of token ids"""


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
"""Split incoming text and return chunks using tokenizer."""
splits: List[str] = []
    # Encode the full text once, then walk a sliding window over the token ids.
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
splits.append(tokenizer.decode(chunk_ids))
if cur_idx == len(input_ids):
break
        # Advance the window, keeping `chunk_overlap` tokens from the end of the previous chunk.
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
return splits
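

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It assumes the
    # `tiktoken` package is installed and that the base TextSplitter can be
    # constructed without extra required arguments; adjust to your setup.
    splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=50, chunk_overlap=5)
    sample_text = "Token-based splitting keeps every chunk within a fixed token budget. " * 20
    chunks = splitter.split_text(sample_text)
    print(f"Produced {len(chunks)} chunks; first chunk starts with: {chunks[0][:60]!r}")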