Spaces:
Paused
Paused
| from fastapi import FastAPI, Query, HTTPException, BackgroundTasks | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from typing import List, Dict, Optional, Tuple, Generator | |
| import torch | |
| import os | |
| import io | |
| import numpy as np | |
| from kokoro import KModel, KPipeline | |
| import spaces | |
| import time | |
app = FastAPI(
    title="Kokoro TTS API",
    description="API for Kokoro text-to-speech conversion",
)

# Constants
# A Space whose SPACE_ID is unset or not under the 'hexgrad/' namespace is
# treated as a duplicate.
IS_DUPLICATE = not os.environ.get('SPACE_ID', '').startswith('hexgrad/')
# Duplicates have no input cap; otherwise text is capped at 5000 characters.
CHAR_LIMIT = 5000 if not IS_DUPLICATE else None
# Whether torch reports that CUDA is available in this process.
CUDA_AVAILABLE = torch.cuda.is_available()
# Initialize models
# One KModel per device, keyed by "is GPU": the CPU model (False) is always
# created, the CUDA model (True) only when CUDA is available.
models = {
    gpu: KModel().to('cuda' if gpu else 'cpu').eval()
    for gpu in [False] + ([True] if CUDA_AVAILABLE else [])
}
# One pipeline per language code ('a' and 'b'), constructed with model=False.
pipelines = {
    lang_code: KPipeline(lang_code=lang_code, model=False)
    for lang_code in 'ab'
}
# Pin the phonemes used for the word "kokoro" in each pipeline's lexicon.
# The strings use U+02C8 (stress mark), U+0259 (schwa) and U+0279 (turned r);
# they must be stored as real Unicode, not as a mis-decoded byte sequence.
pipelines['a'].g2p.lexicon.golds['kokoro'] = 'kˈOkəɹO'
pipelines['b'].g2p.lexicon.golds['kokoro'] = 'kˈQkəɹQ'
# Voice choices
# Maps a display name to a voice ID. IDs follow "<lang><gender>_<name>":
# the first character selects the pipeline, the second is 'f' or 'm'.
# NOTE(review): the display-name keys look mis-encoded (probably flag/gender
# emoji decoded with the wrong codec) — confirm against the upstream file
# before changing them, since other code matches on these exact substrings.
CHOICES = {
    # af_* voices
    "๐บ๐ธ ๐บ Heart โค๏ธ": "af_heart",
    "๐บ๐ธ ๐บ Bella ๐ฅ": "af_bella",
    "๐บ๐ธ ๐บ Nicole ๐ง": "af_nicole",
    "๐บ๐ธ ๐บ Aoede": "af_aoede",
    "๐บ๐ธ ๐บ Kore": "af_kore",
    "๐บ๐ธ ๐บ Sarah": "af_sarah",
    "๐บ๐ธ ๐บ Nova": "af_nova",
    "๐บ๐ธ ๐บ Sky": "af_sky",
    "๐บ๐ธ ๐บ Alloy": "af_alloy",
    "๐บ๐ธ ๐บ Jessica": "af_jessica",
    "๐บ๐ธ ๐บ River": "af_river",
    # am_* voices
    "๐บ๐ธ ๐น Michael": "am_michael",
    "๐บ๐ธ ๐น Fenrir": "am_fenrir",
    "๐บ๐ธ ๐น Puck": "am_puck",
    "๐บ๐ธ ๐น Echo": "am_echo",
    "๐บ๐ธ ๐น Eric": "am_eric",
    "๐บ๐ธ ๐น Liam": "am_liam",
    "๐บ๐ธ ๐น Onyx": "am_onyx",
    "๐บ๐ธ ๐น Santa": "am_santa",
    "๐บ๐ธ ๐น Adam": "am_adam",
    # bf_* voices
    "๐ฌ๐ง ๐บ Emma": "bf_emma",
    "๐ฌ๐ง ๐บ Isabella": "bf_isabella",
    "๐ฌ๐ง ๐บ Alice": "bf_alice",
    "๐ฌ๐ง ๐บ Lily": "bf_lily",
    # bm_* voices
    "๐ฌ๐ง ๐น George": "bm_george",
    "๐ฌ๐ง ๐น Fable": "bm_fable",
    "๐ฌ๐ง ๐น Lewis": "bm_lewis",
    "๐ฌ๐ง ๐น Daniel": "bm_daniel",
}
# Load voices
# Preload every voice into the pipeline selected by the first character of
# its voice ID.
for v in CHOICES.values():
    lang_code = v[0]
    pipelines[lang_code].load_voice(v)
# Sample text files
# One entry per line of en.txt, with surrounding whitespace removed.
with open('en.txt', 'r') as r:
    RANDOM_QUOTES = list(map(str.strip, r))
def get_gatsby():
    """Return the contents of gatsby5k.md with surrounding whitespace stripped.

    The file is re-read from the working directory on every call.
    """
    with open('gatsby5k.md', 'r') as source:
        contents = source.read()
    return contents.strip()
def get_frankenstein():
    """Return the contents of frankenstein5k.md with surrounding whitespace stripped.

    The file is re-read from the working directory on every call.
    """
    with open('frankenstein5k.md', 'r') as source:
        contents = source.read()
    return contents.strip()
# Pydantic models
# Request body for speech synthesis.
class TTSRequest(BaseModel):
    # Required input text.
    text: str = Field(..., description="Text to convert to speech")
    # Voice ID; defaults to "af_heart".
    voice: str = Field(default="af_heart", description="Voice ID to use for TTS")
    # Validated to lie within [0.5, 2.0].
    speed: float = Field(
        default=1.0,
        ge=0.5,
        le=2.0,
        description="Speech speed factor (0.5 to 2.0)",
    )
    # Defaults to whatever CUDA availability was detected at import time.
    use_gpu: bool = Field(
        default=CUDA_AVAILABLE,
        description="Whether to use GPU for inference",
    )
# Request body for tokenization.
class TextRequest(BaseModel):
    # Required input text.
    text: str = Field(..., description="Text to tokenize")
    # Voice ID; defaults to "af_heart".
    voice: str = Field(
        default="af_heart",
        description="Voice ID to use for tokenization",
    )
# One voice entry; all four fields are required strings.
class Voice(BaseModel):
    display_name: str = Field(...)  # key from CHOICES
    id: str = Field(...)  # voice ID (value from CHOICES)
    language: str = Field(...)
    gender: str = Field(...)
# Response wrapper holding a list of Voice entries.
class VoiceList(BaseModel):
    voices: List[Voice] = Field(...)
# GPU wrapper function
# NOTE(review): `spaces` is imported by this file but unused in the visible
# code — confirm whether this wrapper was meant to carry a GPU decorator.
def forward_gpu(ps, ref_s, speed):
    """Run inference with the CUDA model and return its output.

    Looks up ``models[True]``, which only exists when CUDA was available at
    import time; otherwise the lookup raises KeyError.
    """
    gpu_model = models[True]
    return gpu_model(ps, ref_s, speed)
# Helper functions
def generate_first(text: str, voice: str = 'af_heart', speed: float = 1.0, use_gpu: bool = CUDA_AVAILABLE):
    """Synthesize audio for only the first segment the pipeline yields.

    Args:
        text: Input text. When CHAR_LIMIT is set, the text is stripped and cut
            to that many characters.
        voice: Voice ID; its first character selects the pipeline.
        speed: Speed value passed to both the pipeline and the model.
        use_gpu: Request GPU inference; has no effect when CUDA is unavailable.

    Returns:
        ``((24000, samples), ps)`` for the first segment, where ``samples`` is
        ``audio.numpy()`` and ``ps`` is the second item of the pipeline
        result; ``(None, '')`` when the pipeline yields nothing.

    Raises:
        HTTPException: status 500 when inference fails on the CPU-only path.
    """
    if CHAR_LIMIT is not None:
        text = text.strip()[:CHAR_LIMIT]
    pipeline = pipelines[voice[0]]
    pack = pipeline.load_voice(voice)
    gpu_enabled = use_gpu and CUDA_AVAILABLE

    # Only the first pipeline result is ever consumed.
    nothing = object()
    first = next(iter(pipeline(text, voice, speed)), nothing)
    if first is nothing:
        return None, ''

    _, ps, _ = first
    # The voice pack is indexed by the length of `ps`.
    ref_s = pack[len(ps) - 1]
    if gpu_enabled:
        try:
            audio = forward_gpu(ps, ref_s, speed)
        except Exception:
            # Fallback to CPU; a failure here propagates unchanged.
            audio = models[False](ps, ref_s, speed)
    else:
        try:
            audio = models[False](ps, ref_s, speed)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
    return (24000, audio.numpy()), ps
def tokenize_first(text: str, voice: str = 'af_heart'):
    """Return the ``ps`` value of the first segment the pipeline yields.

    Args:
        text: Input text to run through the pipeline.
        voice: Voice ID; its first character selects the pipeline.

    Returns:
        The second item of the first pipeline result, or ``''`` when the
        pipeline yields nothing.
    """
    nothing = object()
    segments = pipelines[voice[0]](text, voice)
    first = next(iter(segments), nothing)
    if first is nothing:
        return ''
    _, ps, _ = first
    return ps
def generate_all(text: str, voice: str = 'af_heart', speed: float = 1.0, use_gpu: bool = CUDA_AVAILABLE) -> Generator:
    """Lazily synthesize audio for every segment the pipeline yields.

    Args:
        text: Input text. When CHAR_LIMIT is set, the text is stripped and cut
            to that many characters.
        voice: Voice ID; its first character selects the pipeline.
        speed: Speed value passed to both the pipeline and the model.
        use_gpu: Request GPU inference; has no effect when CUDA is unavailable.

    Yields:
        ``audio.numpy()`` for each segment, in pipeline order.

    Raises:
        HTTPException: status 500, raised during iteration, when inference
            fails on the CPU-only path.
    """
    if CHAR_LIMIT is not None:
        text = text.strip()[:CHAR_LIMIT]
    pipeline = pipelines[voice[0]]
    pack = pipeline.load_voice(voice)
    gpu_enabled = use_gpu and CUDA_AVAILABLE

    for _, ps, _ in pipeline(text, voice, speed):
        # The voice pack is indexed by the length of `ps`.
        ref_s = pack[len(ps) - 1]
        if gpu_enabled:
            try:
                audio = forward_gpu(ps, ref_s, speed)
            except Exception:
                # Fallback to CPU; a failure here propagates unchanged.
                audio = models[False](ps, ref_s, speed)
        else:
            try:
                audio = models[False](ps, ref_s, speed)
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e))
        yield audio.numpy()
def create_wav(audio_data, sample_rate=24000):
    """Encode float audio samples as an in-memory mono 16-bit PCM WAV file.

    Args:
        audio_data: NumPy array (or array-like) of float samples, nominally in
            [-1.0, 1.0]. Values outside that range are clipped to it.
        sample_rate: Frame rate written to the WAV header.

    Returns:
        The complete WAV file (header followed by frames) as bytes.
    """
    import wave

    # Clip before scaling: an out-of-range sample would otherwise overflow the
    # int16 conversion and wrap around instead of saturating.
    clipped = np.clip(np.asarray(audio_data), -1.0, 1.0)
    # WAV PCM data is little-endian, so request that byte order explicitly.
    pcm = (clipped * 32767).astype('<i2')

    wav_io = io.BytesIO()
    with wave.open(wav_io, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    # Read the buffer only after the writer has closed and finalized the header.
    return wav_io.getvalue()
def stream_wav_chunks(audio_chunks, sample_rate=24000):
    """Yield a WAV header, then 16-bit PCM bytes for each audio chunk.

    Args:
        audio_chunks: Iterable of NumPy float arrays, nominally in
            [-1.0, 1.0]. Values outside that range are clipped to it.
        sample_rate: Frame rate written to the WAV header.

    Yields:
        First the 44-byte WAV header, then one ``bytes`` object per chunk.
    """
    # `wave` is not imported at module level, so bring it into scope here;
    # without this the generator fails with NameError on first iteration.
    import wave

    # Write WAV header first
    header_io = io.BytesIO()
    with wave.open(header_io, 'wb') as wav_file:
        wav_file.setnchannels(1)  # Mono
        wav_file.setsampwidth(2)  # 16-bit
        wav_file.setframerate(sample_rate)
        # The total frame count is unknown up front, so no frames are written
        # and the header declares an empty data section.
        wav_file.writeframes(b'')
    yield header_io.getvalue()[:44]  # WAV header is 44 bytes

    # Stream audio chunks
    for chunk in audio_chunks:
        # Clip before scaling so out-of-range samples saturate instead of
        # wrapping around in the int16 conversion; WAV PCM is little-endian.
        pcm = (np.clip(np.asarray(chunk), -1.0, 1.0) * 32767).astype('<i2')
        yield pcm.tobytes()
        time.sleep(0.1)  # Small delay to avoid overwhelming the client
# API Routes
# NOTE(review): no route decorator is visible on this handler — confirm it is
# registered on `app`.
async def root():
    """Describe the API: its name, its purpose, and the endpoints it offers."""
    endpoints = {
        "GET /voices": "List available voices",
        "POST /tts": "Convert text to speech",
        "POST /tokenize": "Tokenize text",
        "GET /stream": "Stream audio from text",
        "GET /samples": "Get sample texts",
    }
    info = {
        "message": "Kokoro TTS API",
        "description": "Convert text to speech using Kokoro TTS model",
        "endpoints": endpoints,
    }
    return info
# NOTE(review): no route decorator is visible on this handler — confirm it is
# registered on `app`.
async def list_voices():
    """List all available voices.

    Language and gender are derived from the voice ID prefix (for example
    ``af_`` or ``bm_``) rather than from marker characters inside the display
    name. Substring checks on the display name are fragile: one marker can be
    a prefix of another, and the result depends on how that text is encoded.

    Returns:
        VoiceList: one Voice per entry in CHOICES, in CHOICES order.
    """
    voices = [
        Voice(
            display_name=display_name,
            id=voice_id,
            # First ID character: 'a' voices are US English, the rest UK English.
            language="US English" if voice_id[:1] == "a" else "UK English",
            # Second ID character: 'f' voices are female, the rest male.
            gender="Female" if voice_id[1:2] == "f" else "Male",
        )
        for display_name, voice_id in CHOICES.items()
    ]
    return VoiceList(voices=voices)
# NOTE(review): no route decorator is visible on this handler — confirm it is
# registered on `app`.
async def text_to_speech(request: TTSRequest):
    """Synthesize the first segment of the request text and return it as a WAV.

    Raises:
        HTTPException: 400 when the voice is not one of the CHOICES values;
            500 when no audio was produced.
    """
    known_voices = CHOICES.values()
    if request.voice not in known_voices:
        raise HTTPException(status_code=400, detail=f"Voice '{request.voice}' not found. Use /voices to see available options.")

    generated, _ = generate_first(
        request.text,
        request.voice,
        request.speed,
        request.use_gpu,
    )
    if generated is None:
        raise HTTPException(status_code=500, detail="Failed to generate audio")

    sample_rate, samples = generated
    payload = io.BytesIO(create_wav(samples, sample_rate))
    response_headers = {"Content-Disposition": f"attachment; filename=tts_{request.voice}.wav"}
    return StreamingResponse(payload, media_type="audio/wav", headers=response_headers)
# NOTE(review): no route decorator is visible on this handler — confirm it is
# registered on `app`.
async def tokenize_text(request: TextRequest):
    """Return the request text together with the tokens of its first segment.

    Raises:
        HTTPException: 400 when the voice is not one of the CHOICES values.
    """
    known_voices = CHOICES.values()
    if request.voice not in known_voices:
        raise HTTPException(status_code=400, detail=f"Voice '{request.voice}' not found. Use /voices to see available options.")

    response = {"text": request.text}
    response["tokens"] = tokenize_first(request.text, request.voice)
    return response
# NOTE(review): no route decorator is visible on this handler — confirm it is
# registered on `app`.
async def stream_tts(
    text: str = Query(..., description="Text to convert to speech"),
    voice: str = Query("af_heart", description="Voice ID"),
    speed: float = Query(1.0, description="Speech speed", ge=0.5, le=2.0),
    use_gpu: bool = Query(CUDA_AVAILABLE, description="Use GPU for inference")
):
    """Stream a WAV response whose audio is generated segment by segment.

    Raises:
        HTTPException: 400 when the voice is not one of the CHOICES values.
    """
    known_voices = CHOICES.values()
    if voice not in known_voices:
        raise HTTPException(status_code=400, detail=f"Voice '{voice}' not found. Use /voices to see available options.")

    # Apply the character cap (if any) before handing the text to the generator.
    text = text if CHAR_LIMIT is None else text.strip()[:CHAR_LIMIT]

    # Nothing is synthesized here: both generators run lazily while the
    # response body is being sent.
    wav_stream = stream_wav_chunks(generate_all(text, voice, speed, use_gpu))
    response_headers = {"Content-Disposition": f"attachment; filename=stream_{voice}.wav"}
    return StreamingResponse(wav_stream, media_type="audio/wav", headers=response_headers)
# NOTE(review): no route decorator is visible on this handler — confirm it is
# registered on `app`.
async def get_samples():
    """Return one random quote plus short previews of the two book excerpts.

    Each preview is the first 200 characters of the excerpt followed by "...".
    """
    import random

    quote = random.choice(RANDOM_QUOTES)
    gatsby_preview = get_gatsby()[:200] + "..."
    frankenstein_preview = get_frankenstein()[:200] + "..."
    return {
        "random_quote": quote,
        "gatsby_excerpt": gatsby_preview,
        "frankenstein_excerpt": frankenstein_preview,
    }
# NOTE(review): no route decorator is visible on this handler — confirm it is
# registered on `app`.
async def get_sample(sample_type: str):
    """Return the full text of one named sample.

    Args:
        sample_type: "random", "gatsby" or "frankenstein".

    Raises:
        HTTPException: 404 for any other sample type.
    """
    import random

    # Loaders are callables so that only the requested sample is produced.
    loaders = {
        "random": lambda: random.choice(RANDOM_QUOTES),
        "gatsby": get_gatsby,
        "frankenstein": get_frankenstein,
    }
    loader = loaders.get(sample_type)
    if loader is None:
        raise HTTPException(status_code=404, detail=f"Sample type '{sample_type}' not found")
    return {"text": loader()}
if __name__ == "__main__":
    import uvicorn

    # Serve the `app` object of module `app` on all interfaces, port 8000,
    # with auto-reload turned on.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)