import time

from fish_speech.utils.schema import (
    ServeStreamDelta,
    ServeStreamResponse,
    ServeTextPart,
    ServeVQPart,
)

def initialize_decode_buffers(num_samples):
    """Initialise the per-sample decode state.

    Returns three parallel lists indexed by sample id: a buffer of text
    tokens awaiting detokenisation, the accumulated response parts, and a
    flag marking samples that have finished generating.
    """
    decode_buffer = [[] for _ in range(num_samples)]
    parts = [[] for _ in range(num_samples)]
    finished = [False for _ in range(num_samples)]
    return decode_buffer, parts, finished
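
# Illustration (not part of the module): initialize_decode_buffers(2) returns
#   ([[], []],        # decode_buffer: per-sample text tokens to detokenise
#    [[], []],        # parts: accumulated ServeTextPart / ServeVQPart objects
#    [False, False])  # finished: per-sample completion flags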

def send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request):
    """Flush the remaining text buffer for a sample as a single text part."""
    if len(decode_buffer[sample_id]) == 0:
        return []

    decoded = tokenizer.decode(decode_buffer[sample_id])
    part = ServeTextPart(text=decoded)

    responses = []
    if request.streaming:
        responses.append(ServeStreamResponse(delta=ServeStreamDelta(part=part)))
    else:
        parts[sample_id].append(part)

    # The buffer is cleared once its contents have been emitted.
    decode_buffer[sample_id] = []
    return responses
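
# Behaviour sketch (illustrative; the names are stand-ins, not module state):
# with decode_buffer == [[t1, t2]] and a non-streaming request,
# send_reset_buffer(0, decode_buffer, tokenizer, parts, request) returns []
# and appends ServeTextPart(text=tokenizer.decode([t1, t2])) to parts[0]; for
# a streaming request it instead returns that part wrapped in a
# ServeStreamResponse. In both cases decode_buffer[0] is reset to [].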

def handle_semantic_tokens(tokens, config, sample_id, parts, request):
    """Handle the semantic (VQ) tokens returned by the model."""
    responses = []
    # Drop the text-token row; the remaining rows hold one code per codebook.
    _tokens = tokens[1:].clone()

    if not config.share_codebook_embeddings:
        # Each codebook occupies its own id range, so strip the per-row
        # offset to recover the raw code ids.
        for i in range(len(_tokens)):
            _tokens[i] -= config.codebook_size * i

    if request.streaming:
        responses.append(
            ServeStreamResponse(
                sample_id=sample_id,
                delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
            )
        )
    else:
        # Start a new VQ part if the last accumulated part is not one...
        if not parts[sample_id] or not isinstance(parts[sample_id][-1], ServeVQPart):
            parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
        else:
            # ...otherwise extend it in place, one code per codebook.
            for codebook_id, value in enumerate(_tokens):
                parts[sample_id][-1].codes[codebook_id].append(value.item())

    return responses
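
# Worked example of the offset removal above, assuming a hypothetical
# codebook_size of 1024 with share_codebook_embeddings == False: if the rows
# of tokens[1:] are [[17], [1041], [2090]], subtracting codebook_size * i from
# row i yields [[17], [17], [42]], the raw per-codebook code ids.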

def process_response_tokens(
    response,
    tokenizer,
    config,
    request,
    decode_buffer,
    parts,
    finished,
    im_end_id,
    stats,
    start,
    is_first_token,
):
    """Process one decoding step's tokens for every sample in the batch."""
    responses = []
    for sample_id, tokens in enumerate(response):
        if finished[sample_id]:
            continue

        # End of generation for this sample.
        if tokens[0] == im_end_id:
            finished[sample_id] = True
            # Flush whatever text is still buffered before finishing.
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            if request.streaming:
                responses.append(
                    ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                )
            continue

        # Tokens in the semantic id range carry VQ codes rather than text.
        is_semantic = (
            tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
        )

        if is_semantic:
            # Text and VQ parts must not interleave, so flush the text
            # buffer before handling the semantic tokens.
            responses.extend(
                send_reset_buffer(sample_id, decode_buffer, tokenizer, parts, request)
            )
            responses.extend(
                handle_semantic_tokens(tokens, config, sample_id, parts, request)
            )
        else:
            # Plain text token: buffer it for later detokenisation.
            decode_buffer[sample_id].append(tokens[0, 0])

    if is_first_token:
        stats["time_to_first_token"] = (time.time() - start) * 1000

    return responses
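
# Illustrative only, not part of the original module: a minimal sketch of how
# the helpers above might compose into a decode loop. `step_iterator` (yielding
# one batch of per-sample token columns per decoding step), the
# `request.num_samples` field, and `im_end_id` are assumptions made for this
# example, not guarantees about the surrounding server code.
def _example_decode_loop(step_iterator, tokenizer, config, request, im_end_id):
    start = time.time()
    stats = {}
    decode_buffer, parts, finished = initialize_decode_buffers(request.num_samples)

    for step, response in enumerate(step_iterator):
        # Each step flushes buffered text, streams or accumulates VQ codes,
        # and marks samples that emitted im_end_id as finished.
        yield from process_response_tokens(
            response,
            tokenizer,
            config,
            request,
            decode_buffer,
            parts,
            finished,
            im_end_id,
            stats,
            start,
            is_first_token=(step == 0),
        )

    # For non-streaming requests, `parts` now holds the complete per-sample
    # results; the caller decides how to package them.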