Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Form | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from kokoro import KPipeline | |
| import soundfile as sf | |
| import torch | |
| import os | |
| import tempfile | |
| import uuid | |
| import logging | |
| from typing import Optional | |
# Module-wide logging: INFO and above goes to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; title/description/version are passed to FastAPI as metadata.
app = FastAPI(
    title="Kokoro TTS API",
    description="Text-to-Speech API using Kokoro",
    version="1.0.0",
)
# Request body accepted by the JSON text-to-speech handler.
# (Documented with comments rather than a class docstring so the generated
# model schema stays exactly as it was.)
class TTSRequest(BaseModel):
    # Text to synthesize; required.
    text: str
    # Voice identifier handed to the Kokoro pipeline.
    voice: str = "af_heart"
    # Language code handed to KPipeline.
    lang_code: str = "a"
class KokoroTTSService:
    """Wrapper around a Kokoro ``KPipeline`` that renders text to WAV files.

    Attributes:
        device: ``"cuda"`` when ``torch.cuda.is_available()`` is true, else
            ``"cpu"``. It is only logged/reported; it is not passed to
            ``KPipeline`` by this class.
        pipeline: The currently loaded ``KPipeline``. It is replaced whenever
            ``generate_speech`` is asked for a different ``lang_code``.
    """

    # Sample rate (Hz) written into every WAV file produced by this service.
    SAMPLE_RATE = 24000

    def __init__(self):
        """Detect the compute device and load the default (``'a'``) pipeline.

        Raises:
            Exception: Whatever ``KPipeline`` raised while loading; the error
                is logged and re-raised with its original traceback.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        try:
            # Initialize Kokoro pipeline with default language
            self.pipeline = KPipeline(lang_code='a')
        except Exception as e:
            logger.error(f"Failed to load Kokoro TTS pipeline: {e}")
            raise
        logger.info("Kokoro TTS pipeline loaded successfully")

    @staticmethod
    def _merge_segments(segments):
        """Concatenate per-segment audio, in order, into one 1-D NumPy array."""
        flattened = [
            torch.as_tensor(segment).detach().flatten().cpu()
            for segment in segments
        ]
        return torch.cat(flattened).numpy()

    def generate_speech(self, text: str, voice: str = "af_heart", lang_code: str = "a") -> str:
        """Synthesize ``text`` and return the path of the resulting WAV file.

        Every segment yielded by the pipeline is written, in order, to a
        single file. Stopping after the first segment would silently truncate
        any input the pipeline splits into several segments.

        Args:
            text: Text to synthesize.
            voice: Voice identifier forwarded to the pipeline.
            lang_code: Language code; a new ``KPipeline`` is built when it
                differs from the current pipeline's ``lang_code``.

        Returns:
            Path to a uniquely named WAV file in the system temp directory.
            The caller is responsible for deleting it.

        Raises:
            HTTPException: Status 500 for any failure, including the pipeline
                producing no audio at all.
        """
        # Create a unique filename for the output
        output_filename = f"kokoro_output_{uuid.uuid4().hex}.wav"
        output_path = os.path.join(tempfile.gettempdir(), output_filename)
        try:
            # Update pipeline language if different
            if self.pipeline.lang_code != lang_code:
                self.pipeline = KPipeline(lang_code=lang_code)

            segments = []
            for i, (gs, ps, audio) in enumerate(self.pipeline(text, voice=voice)):
                logger.info(f"Generated audio segment {i}: gs={gs}, ps={ps}")
                if audio is not None:
                    segments.append(audio)

            if not segments:
                # Without this guard the caller would get a path to a file
                # that was never written.
                raise RuntimeError("Kokoro produced no audio for the given text")

            sf.write(output_path, self._merge_segments(segments), self.SAMPLE_RATE)
            return output_path
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            # Best-effort removal of a partially written file.
            try:
                if os.path.exists(output_path):
                    os.remove(output_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove {output_path}: {cleanup_error}")
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {str(e)}"
            ) from e

    def get_available_voices(self):
        """Return the list of voice identifiers this service accepts."""
        # NOTE(review): this list is hard-coded and is not checked against the
        # voices shipped with the installed Kokoro model. Confirm each name
        # (in particular "am_edward" and "am_lewis") against Kokoro's
        # published voice list.
        return [
            "af_heart", "af_sky", "af_bella", "af_sarah", "af_nicole",
            "am_adam", "am_michael", "am_edward", "am_lewis"
        ]
# Single shared service instance, created at import time. Constructing it
# loads the Kokoro pipeline, and a load failure propagates out of the import.
tts_service = KokoroTTSService()
# Returns a fixed "service is up" payload.
# NOTE(review): no @app route decorator is visible on this handler in the
# reviewed text; confirm it is registered on `app`.
async def root():
    payload = {"message": "Kokoro TTS API is running", "status": "healthy"}
    return payload
# Reports a fixed "healthy" status together with the device string recorded
# by the shared TTS service.
# NOTE(review): no @app route decorator is visible on this handler in the
# reviewed text; confirm it is registered on `app`.
async def health_check():
    device = tts_service.device
    return {"status": "healthy", "device": device}
# NOTE(review): no @app route decorator is visible on this handler in the
# reviewed text; confirm it is registered on `app`.
async def get_voices():
    """Get list of available voices"""
    # Wrap the service's voice list under the "voices" key.
    voices = tts_service.get_available_voices()
    return {"voices": voices}
# NOTE(review): no @app route decorator is visible on this handler in the
# reviewed text; confirm it is registered on `app`.
async def text_to_speech(
    text: str = Form(...),
    voice: str = Form("af_heart"),
    lang_code: str = Form("a")
):
    """
    Convert text to speech using Kokoro TTS

    - **text**: The text to convert to speech
    - **voice**: Voice to use (default: "af_heart")
    - **lang_code**: Kokoro language code (default: "a", American English)
    """
    # Imported locally so this handler does not depend on a change to the
    # file's top-level import block.
    from fastapi import BackgroundTasks

    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Validate voice
    available_voices = tts_service.get_available_voices()
    if voice not in available_voices:
        raise HTTPException(
            status_code=400,
            detail=f"Voice '{voice}' not available. Available voices: {available_voices}"
        )

    try:
        # Generate speech
        output_path = tts_service.generate_speech(text, voice, lang_code)
    except HTTPException:
        # The service already raised a well-formed HTTP error; re-wrapping it
        # would replace its detail with str(HTTPException).
        raise
    except Exception as e:
        logger.error(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the temporary WAV once the response has been sent; otherwise
    # every request leaves a file behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # FileResponse builds `attachment; filename="..."` itself when `filename`
    # is given. An explicit Content-Disposition header would take precedence
    # and drop the filename, so none is passed.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"kokoro_tts_{voice}_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
# NOTE(review): no @app route decorator is visible on this handler in the
# reviewed text; confirm it is registered on `app`.
async def text_to_speech_json(request: TTSRequest):
    """
    Convert text to speech using JSON request body

    - **request**: TTSRequest containing text, voice, and lang_code
    """
    # Imported locally so this handler does not depend on a change to the
    # file's top-level import block.
    from fastapi import BackgroundTasks

    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Validate voice
    available_voices = tts_service.get_available_voices()
    if request.voice not in available_voices:
        raise HTTPException(
            status_code=400,
            detail=f"Voice '{request.voice}' not available. Available voices: {available_voices}"
        )

    try:
        # Generate speech
        output_path = tts_service.generate_speech(request.text, request.voice, request.lang_code)
    except HTTPException:
        # The service already raised a well-formed HTTP error; re-wrapping it
        # would replace its detail with str(HTTPException).
        raise
    except Exception as e:
        logger.error(f"Error in TTS JSON endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Delete the temporary WAV once the response has been sent; otherwise
    # every request leaves a file behind in the temp directory.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)

    # FileResponse builds `attachment; filename="..."` itself when `filename`
    # is given. An explicit Content-Disposition header would take precedence
    # and drop the filename, so none is passed.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"kokoro_tts_{request.voice}_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )