"""Unified text processing for TTS with smart chunking."""

import re
import time
from typing import AsyncGenerator, Dict, List, Tuple

from loguru import logger

from ...core.config import settings
from ...structures.schemas import NormalizationOptions
from .normalizer import normalize_text
from .phonemizer import phonemize
from .vocabulary import tokenize

# Pre-compiled regex patterns for performance
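# CUSTOM_PHONEMES matches inline custom-phoneme markup of the form "[text](/phonemes/)"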
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")


def process_text_chunk(
    text: str, language: str = "a", skip_phonemize: bool = False
) -> List[int]:
    """Process a chunk of text through normalization, phonemization, and tokenization.

    Args:
        text: Text chunk to process
        language: Language code for phonemization
        skip_phonemize: If True, treat input as phonemes and skip normalization/phonemization

    Returns:
        List of token IDs
    """
    start_time = time.time()

    if skip_phonemize:
        # Input is already phonemes, just tokenize
        tokens = tokenize(text)
    else:
        # Normal text processing pipeline; text is expected to be normalized upstream
        phonemes = phonemize(text, language, normalize=False)
        tokens = tokenize(phonemes)

    total_time = time.time() - start_time
    logger.debug(
        f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'"
    )

    return tokens


async def yield_chunk(
    text: str, tokens: List[int], chunk_count: int
) -> Tuple[str, List[int]]:
    """Yield a chunk with consistent logging."""
    logger.debug(
        f"Yielding chunk {chunk_count}: '{text[:50]}{'...' if len(text) > 50 else ''}' ({len(tokens)} tokens)"
    )
    return text, tokens


def process_text(text: str, language: str = "a") -> List[int]:
    """Process text into token IDs.

    Args:
        text: Text to process
        language: Language code for phonemization

    Returns:
        List of token IDs
    """
    if not isinstance(text, str):
        text = str(text) if text is not None else ""

    text = text.strip()
    if not text:
        return []

    return process_text_chunk(text, language)


def get_sentence_info(
    text: str, custom_phonemes_list: Dict[str, str], lang_code: str = "a"
) -> List[Tuple[str, List[int], int]]:
    """Split text into sentences and return (text, tokens, token_count) for each."""
    sentences = re.split(r"([.!?;:])(?=\s|$)", text)
    phoneme_length, min_value = len(custom_phonemes_list), 0

    results = []
    for i in range(0, len(sentences), 2):
        sentence = sentences[i].strip()
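        # Restore any custom phoneme placeholders that were substituted into this sentence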
        for replaced in range(min_value, phoneme_length):
            current_id = f"</|custom_phonemes_{replaced}|/>"
            if current_id in sentence:
                sentence = sentence.replace(
                    current_id, custom_phonemes_list.pop(current_id)
                )
                min_value += 1

        punct = sentences[i + 1] if i + 1 < len(sentences) else ""

        if not sentence:
            continue

        full = sentence + punct
        tokens = process_text_chunk(full, lang_code)
        results.append((full, tokens, len(tokens)))

    return results


def handle_custom_phonemes(s: re.Match[str], phonemes_list: Dict[str, str]) -> str:
    """Replace matched custom phoneme markup with a placeholder ID, storing the original for later restoration."""
    latest_id = f"</|custom_phonemes_{len(phonemes_list)}|/>"
    phonemes_list[latest_id] = s.group(0).strip()
    return latest_id


async def smart_split(
    text: str,
    max_tokens: int = settings.absolute_max_tokens,
    lang_code: str = "a",
    normalization_options: NormalizationOptions = NormalizationOptions(),
) -> AsyncGenerator[Tuple[str, List[int]], None]:
    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
    start_time = time.time()
    chunk_count = 0
    logger.info(f"Starting smart split for {len(text)} chars")

    custom_phoneme_list = {}
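    # Maps placeholder IDs to the original custom phoneme markup (restored in get_sentence_info)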

    # Normalize text
    if settings.advanced_text_normalization and normalization_options.normalize:
        if lang_code in ["a", "b", "en-us", "en-gb"]:
            text = CUSTOM_PHONEMES.sub(
                lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
            )
            text = normalize_text(text, normalization_options)
        else:
            logger.info(
                "Skipping text normalization as it is only supported for English"
            )

    # Process all sentences
    sentences = get_sentence_info(text, custom_phoneme_list, lang_code)

    current_chunk = []
    current_tokens = []
    current_count = 0
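    # Greedy packing in the loop below: grow each chunk toward target_max_tokens, and only
    # exceed that target (up to max_tokens) while the chunk is still under target_min_tokens.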

    for sentence, tokens, count in sentences:
        # Handle sentences that exceed max tokens
        if count > max_tokens:
            # Yield current chunk if any
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                chunk_count += 1
                logger.debug(
                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
                )
                yield chunk_text, current_tokens
                current_chunk = []
                current_tokens = []
                current_count = 0

            # Split long sentence on commas
            clauses = re.split(r"([,])", sentence)
            clause_chunk = []
            clause_tokens = []
            clause_count = 0

            for j in range(0, len(clauses), 2):
                clause = clauses[j].strip()
                comma = clauses[j + 1] if j + 1 < len(clauses) else ""

                if not clause:
                    continue

                full_clause = clause + comma

                tokens = process_text_chunk(full_clause, lang_code)
                count = len(tokens)

                # Keep growing the clause chunk while it stays within both max_tokens and the target max
                if (
                    clause_count + count <= max_tokens
                    and clause_count + count <= settings.target_max_tokens
                ):
                    clause_chunk.append(full_clause)
                    clause_tokens.extend(tokens)
                    clause_count += count
                else:
                    # Yield clause chunk if we have one
                    if clause_chunk:
                        chunk_text = " ".join(clause_chunk)
                        chunk_count += 1
                        logger.debug(
                            f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
                        )
                        yield chunk_text, clause_tokens
                    clause_chunk = [full_clause]
                    clause_tokens = tokens
                    clause_count = count

            # Don't forget last clause chunk
            if clause_chunk:
                chunk_text = " ".join(clause_chunk)
                chunk_count += 1
                logger.debug(
                    f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
                )
                yield chunk_text, clause_tokens

        # Regular sentence handling
        elif (
            current_count >= settings.target_min_tokens
            and current_count + count > settings.target_max_tokens
        ):
            # If we have a good sized chunk and adding next sentence exceeds target,
            # yield current chunk and start new one
            chunk_text = " ".join(current_chunk)
            chunk_count += 1
            logger.info(
                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
            )
            yield chunk_text, current_tokens
            current_chunk = [sentence]
            current_tokens = tokens
            current_count = count
        elif current_count + count <= settings.target_max_tokens:
            # Keep building chunk while under target max
            current_chunk.append(sentence)
            current_tokens.extend(tokens)
            current_count += count
        elif (
            current_count + count <= max_tokens
            and current_count < settings.target_min_tokens
        ):
            # Only exceed target max if we haven't reached minimum size yet
            current_chunk.append(sentence)
            current_tokens.extend(tokens)
            current_count += count
        else:
            # Yield current chunk and start new one
            if current_chunk:
                chunk_text = " ".join(current_chunk)
                chunk_count += 1
                logger.info(
                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
                )
                yield chunk_text, current_tokens
            current_chunk = [sentence]
            current_tokens = tokens
            current_count = count

    # Don't forget the last chunk
    if current_chunk:
        chunk_text = " ".join(current_chunk)
        chunk_count += 1
        logger.info(
            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
        )
        yield chunk_text, current_tokens

    total_time = time.time() - start_time
    logger.info(
        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
    )