from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import struct
import sys
import os
import json
import asyncio
import logging

# ===== MODEL CONFIGURATION =====
# Simple toggle between the two German models:
USE_KARTOFFEL_MODEL = False  # True = Kartoffel, False = Canopy-Deutsch

if USE_KARTOFFEL_MODEL:
    MODEL_NAME = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    TOKENIZER_NAME = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    DEFAULT_VOICE = "Jakob"
    print("🥔 Using Kartoffel German Model")
else:
    MODEL_NAME = "canopylabs/3b-de-ft-research_release"
    TOKENIZER_NAME = "canopylabs/3b-de-ft-research_release"
    DEFAULT_VOICE = "thomas"
    print("🇩🇪 Using Canopy German Model")

# Make the vendored orpheus-tts package importable next to this file.
sys.path.append(os.path.join(os.path.dirname(__file__), 'orpheus-tts'))
try:
    from orpheus_tts.engine_class import OrpheusModel
except ImportError:
    from engine_class import OrpheusModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Orpheus TTS Server", version="1.0.0")

# Add CORS middleware for web clients
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model handle; stays None until startup_event has loaded it successfully.
engine = None


@app.on_event("startup")
async def startup_event():
    """Load the Orpheus TTS model once when the server starts.

    Raises:
        Exception: re-raised if the model fails to load, so the server
        refuses to start instead of silently serving 503s forever.
    """
    global engine
    try:
        engine = OrpheusModel(
            model_name=MODEL_NAME,
            tokenizer=TOKENIZER_NAME
        )
        logger.info(f"Orpheus model loaded successfully: {MODEL_NAME}")
    except Exception as e:
        logger.error(f"Error loading Orpheus model: {e}")
        # FIX: bare `raise` preserves the original traceback (was `raise e`).
        raise


def create_wav_header(sample_rate=24000, bits_per_sample=16, channels=1, data_size=0):
    """Create a 44-byte RIFF/WAVE header for PCM audio streaming.

    Args:
        sample_rate: samples per second (default 24 kHz, Orpheus output rate).
        bits_per_sample: PCM sample width in bits.
        channels: number of audio channels.
        data_size: declared size of the data chunk in bytes. Defaults to 0,
            which many players accept for live streams of unknown length.

    Returns:
        bytes: the packed WAV header.
    """
    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8
    header = struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF', 36 + data_size, b'WAVE',
        b'fmt ', 16,
        1,  # audio format 1 = uncompressed PCM
        channels, sample_rate, byte_rate, block_align, bits_per_sample,
        b'data', data_size
    )
    return header


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": engine is not None}


@app.get("/tts")
async def tts_stream(
    prompt: str = Query(..., description="Text to synthesize"),
    voice: str = Query(DEFAULT_VOICE, description="Voice to use"),
    temperature: float = Query(0.4, description="Temperature for generation"),
    top_p: float = Query(0.9, description="Top-p for generation"),
    max_tokens: int = Query(2000, description="Maximum tokens"),
    repetition_penalty: float = Query(1.1, description="Repetition penalty")
):
    """HTTP endpoint for TTS streaming.

    Streams a WAV header followed by raw PCM chunks as they are generated.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    def generate_audio_stream():
        try:
            # Send WAV header first
            yield create_wav_header()

            # Generate speech tokens
            syn_tokens = engine.generate_speech(
                prompt=prompt,
                voice=voice,
                repetition_penalty=repetition_penalty,
                stop_token_ids=[128258],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p
            )

            # Stream audio chunks
            for chunk in syn_tokens:
                yield chunk
        except Exception as e:
            # FIX: raising HTTPException here is ineffective — the 200 status
            # and WAV header were already sent, so no error response can be
            # delivered. Log and terminate the stream cleanly instead.
            logger.error(f"Error in TTS generation: {e}")
            return

    return StreamingResponse(
        generate_audio_stream(),
        media_type='audio/wav',
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )


@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """WebSocket endpoint for real-time TTS streaming.

    Protocol: client sends a JSON request ({"prompt", "voice", ...});
    server replies with JSON status frames interleaved with binary audio
    frames (WAV header first, then PCM chunks).
    """
    await websocket.accept()

    if engine is None:
        await websocket.send_json({"error": "Model not loaded"})
        await websocket.close()
        return

    try:
        while True:
            # Receive request from client
            data = await websocket.receive_text()
            request = json.loads(data)

            prompt = request.get("prompt", "")
            voice = request.get("voice", DEFAULT_VOICE)
            temperature = request.get("temperature", 0.4)
            top_p = request.get("top_p", 0.9)
            max_tokens = request.get("max_tokens", 2000)
            repetition_penalty = request.get("repetition_penalty", 1.1)

            if not prompt:
                await websocket.send_json({"error": "No prompt provided"})
                continue

            # Send status update
            await websocket.send_json({"status": "generating", "prompt": prompt})

            try:
                # Send WAV header
                wav_header = create_wav_header()
                await websocket.send_bytes(wav_header)

                # Generate and stream audio.
                # NOTE(review): generate_speech appears to be a synchronous
                # generator; iterating it here blocks the event loop between
                # chunks — confirm whether it yields promptly under load.
                syn_tokens = engine.generate_speech(
                    prompt=prompt,
                    voice=voice,
                    repetition_penalty=repetition_penalty,
                    stop_token_ids=[128258],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p
                )

                chunk_count = 0
                for chunk in syn_tokens:
                    await websocket.send_bytes(chunk)
                    chunk_count += 1

                    # Send periodic status updates
                    if chunk_count % 10 == 0:
                        await websocket.send_json({
                            "status": "streaming",
                            "chunks_sent": chunk_count
                        })

                # Send completion status
                await websocket.send_json({
                    "status": "completed",
                    "total_chunks": chunk_count
                })

            except Exception as e:
                logger.error(f"Error in WebSocket TTS generation: {e}")
                await websocket.send_json({"error": str(e)})

    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        # FIX: close() raises if the socket is already closed (e.g. the
        # error above WAS a broken connection) — swallow that case only.
        try:
            await websocket.close()
        except Exception:
            pass


@app.get("/voices")
async def get_available_voices():
    """Get list of available voices"""
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"voices": engine.available_voices}


@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Orpheus TTS Server",
        "endpoints": {
            "health": "/health",
            "tts_http": f"/tts?prompt=your_text&voice={DEFAULT_VOICE}",
            "tts_websocket": "/ws/tts",
            "voices": "/voices"
        },
        "model_loaded": engine is not None
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")