"""Unified text processing for TTS with smart chunking.""" | |
import re | |
import time | |
from typing import AsyncGenerator, Dict, List, Tuple | |
from loguru import logger | |
from ...core.config import settings | |
from ...structures.schemas import NormalizationOptions | |
from .normalizer import normalize_text | |
from .phonemizer import phonemize | |
from .vocabulary import tokenize | |
# Pre-compiled regex patterns for performance | |
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))") | |


def process_text_chunk(
    text: str, language: str = "a", skip_phonemize: bool = False
) -> List[int]:
    """Process a chunk of text through normalization, phonemization, and tokenization.

    Args:
        text: Text chunk to process
        language: Language code for phonemization
        skip_phonemize: If True, treat input as phonemes and skip normalization/phonemization

    Returns:
        List of token IDs
    """
    start_time = time.time()

    if skip_phonemize:
        # Input is already phonemes, just tokenize
        tokens = tokenize(text)
    else:
        # Normal text processing pipeline: phonemize (text is already normalized upstream), then tokenize
        phonemes = phonemize(text, language, normalize=False)
        tokens = tokenize(phonemes)

    total_time = time.time() - start_time
    logger.debug(
        f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'"
    )

    return tokens
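

# Illustrative usage (a sketch, not part of the module API; assumes the phonemizer
# and vocabulary configured for language "a", and an IPA string that is purely
# illustrative):
#
#     token_ids = process_text_chunk("Hello there, world!", language="a")
#     # -> flat List[int] of vocabulary token IDs for the whole chunk
#
#     # Pre-phonemized input skips phonemization and goes straight to tokenize():
#     token_ids = process_text_chunk("həlˈO wˈɜːld", skip_phonemize=True)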


async def yield_chunk(
    text: str, tokens: List[int], chunk_count: int
) -> Tuple[str, List[int]]:
    """Log and return a chunk (text, tokens) with consistent logging."""
    logger.debug(
        f"Yielding chunk {chunk_count}: '{text[:50]}{'...' if len(text) > 50 else ''}' ({len(tokens)} tokens)"
    )
    return text, tokens


def process_text(text: str, language: str = "a") -> List[int]:
    """Process text into token IDs.

    Args:
        text: Text to process
        language: Language code for phonemization

    Returns:
        List of token IDs
    """
    if not isinstance(text, str):
        text = str(text) if text is not None else ""

    text = text.strip()
    if not text:
        return []

    return process_text_chunk(text, language)
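

# Behaviour sketch for the wrapper above (illustrative; the token lists themselves
# depend on the loaded vocabulary and phonemizer):
#
#     process_text("  Hello.  ")   # stripped, then tokenized -> List[int]
#     process_text(None)           # coerced to ""  -> []
#     process_text(42)             # coerced to "42" -> tokens for "42"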


def get_sentence_info(
    text: str, custom_phonemes_list: Dict[str, str]
) -> List[Tuple[str, List[int], int]]:
    """Process all sentences and return (text, tokens, token_count) info for each."""
    sentences = re.split(r"([.!?;:])(?=\s|$)", text)
    phoneme_length, min_value = len(custom_phonemes_list), 0

    results = []
    for i in range(0, len(sentences), 2):
        sentence = sentences[i].strip()

        # Restore any custom phoneme placeholders that landed in this sentence
        for replaced in range(min_value, phoneme_length):
            current_id = f"</|custom_phonemes_{replaced}|/>"
            if current_id in sentence:
                sentence = sentence.replace(
                    current_id, custom_phonemes_list.pop(current_id)
                )
                min_value += 1

        punct = sentences[i + 1] if i + 1 < len(sentences) else ""
        if not sentence:
            continue

        full = sentence + punct
        tokens = process_text_chunk(full)
        results.append((full, tokens, len(tokens)))

    return results
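

# Example of the returned structure (hypothetical token ID lists and counts, shown
# only to illustrate the (text, tokens, token_count) tuples):
#
#     get_sentence_info("Hi there. How are you?", {})
#     # -> [("Hi there.", [...], 12), ("How are you?", [...], 15)]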


def handle_custom_phonemes(s: re.Match[str], phonemes_list: Dict[str, str]) -> str:
    """Replace a custom phoneme match with a placeholder ID and stash the original."""
    latest_id = f"</|custom_phonemes_{len(phonemes_list)}|/>"
    phonemes_list[latest_id] = s.group(0).strip()
    return latest_id
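

# Sketch of the custom-phoneme flow (assuming the [display text](/phonemes/) markup
# that CUSTOM_PHONEMES matches; the IPA string here is purely illustrative):
#
#     stash: Dict[str, str] = {}
#     marked = CUSTOM_PHONEMES.sub(
#         lambda m: handle_custom_phonemes(m, stash), "I said [hello](/həlˈO/) to everyone."
#     )
#     # marked == "I said </|custom_phonemes_0|/> to everyone."
#     # stash  == {"</|custom_phonemes_0|/>": "[hello](/həlˈO/)"}
#     # get_sentence_info() later swaps each placeholder back before tokenizing.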


async def smart_split(
    text: str,
    max_tokens: int = settings.absolute_max_tokens,
    lang_code: str = "a",
    normalization_options: NormalizationOptions = NormalizationOptions(),
) -> AsyncGenerator[Tuple[str, List[int]], None]:
    """Build optimal chunks targeting the configured target token range, never exceeding max_tokens."""
    start_time = time.time()
    chunk_count = 0
    logger.info(f"Starting smart split for {len(text)} chars")

    custom_phoneme_list = {}

    # Normalize text
    if settings.advanced_text_normalization and normalization_options.normalize:
        if lang_code in ["a", "b", "en-us", "en-gb"]:
            text = CUSTOM_PHONEMES.sub(
                lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
            )
            text = normalize_text(text, normalization_options)
        else:
            logger.info(
                "Skipping text normalization as it is only supported for English"
            )

    # Process all sentences
    sentences = get_sentence_info(text, custom_phoneme_list)

    current_chunk = []
    current_tokens = []
    current_count = 0

    for sentence, tokens, count in sentences:
        # Handle sentences that exceed max tokens
        if count > max_tokens:
            # Yield current chunk if any
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                chunk_count += 1
                logger.debug(
                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
                )
                yield chunk_text, current_tokens
                current_chunk = []
                current_tokens = []
                current_count = 0

            # Split long sentence on commas
            clauses = re.split(r"([,])", sentence)
            clause_chunk = []
            clause_tokens = []
            clause_count = 0

            for j in range(0, len(clauses), 2):
                clause = clauses[j].strip()
                comma = clauses[j + 1] if j + 1 < len(clauses) else ""

                if not clause:
                    continue

                full_clause = clause + comma

                tokens = process_text_chunk(full_clause)
                count = len(tokens)

                # If adding clause keeps us under max and not optimal yet
                if (
                    clause_count + count <= max_tokens
                    and clause_count + count <= settings.target_max_tokens
                ):
                    clause_chunk.append(full_clause)
                    clause_tokens.extend(tokens)
                    clause_count += count
                else:
                    # Yield clause chunk if we have one
                    if clause_chunk:
                        chunk_text = " ".join(clause_chunk)
                        chunk_count += 1
                        logger.debug(
                            f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({clause_count} tokens)"
                        )
                        yield chunk_text, clause_tokens
                    clause_chunk = [full_clause]
                    clause_tokens = tokens
                    clause_count = count

            # Don't forget last clause chunk
            if clause_chunk:
                chunk_text = " ".join(clause_chunk)
                chunk_count += 1
                logger.debug(
                    f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({clause_count} tokens)"
                )
                yield chunk_text, clause_tokens

        # Regular sentence handling
        elif (
            current_count >= settings.target_min_tokens
            and current_count + count > settings.target_max_tokens
        ):
            # If we have a good-sized chunk and adding the next sentence would exceed
            # the target, yield the current chunk and start a new one
            chunk_text = " ".join(current_chunk)
            chunk_count += 1
            logger.info(
                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
            )
            yield chunk_text, current_tokens
            current_chunk = [sentence]
            current_tokens = tokens
            current_count = count
        elif current_count + count <= settings.target_max_tokens:
            # Keep building chunk while under target max
            current_chunk.append(sentence)
            current_tokens.extend(tokens)
            current_count += count
        elif (
            current_count + count <= max_tokens
            and current_count < settings.target_min_tokens
        ):
            # Only exceed target max if we haven't reached minimum size yet
            current_chunk.append(sentence)
            current_tokens.extend(tokens)
            current_count += count
        else:
            # Yield current chunk and start new one
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                chunk_count += 1
                logger.info(
                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
                )
                yield chunk_text, current_tokens
            current_chunk = [sentence]
            current_tokens = tokens
            current_count = count

    # Don't forget the last chunk
    if current_chunk:
        chunk_text = " ".join(current_chunk)
        chunk_count += 1
        logger.info(
            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
        )
        yield chunk_text, current_tokens

    total_time = time.time() - start_time
    logger.info(
        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
    )
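

# Minimal manual-test driver (a sketch, not part of the module's public API). Because
# of the relative imports above, it must be run as a module, e.g.
# `python -m <package>.text_processor` (package path hypothetical), with the package's
# settings, phonemizer, and vocabulary available.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        sample = (
            "First sentence. Second sentence! "
            "A third, much longer sentence, broken into clauses, for chunking?"
        )
        async for chunk_text, chunk_tokens in smart_split(sample, lang_code="a"):
            print(f"{len(chunk_tokens):4d} tokens :: {chunk_text!r}")

    asyncio.run(_demo())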