|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query |
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import struct |
|
import sys |
|
import os |
|
import json |
|
import asyncio |
|
import logging |
|
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'orpheus-tts')) |
|
|
|
try: |
|
from orpheus_tts.engine_class import OrpheusModel |
|
except ImportError: |
|
from engine_class import OrpheusModel |
|
|
|
|
|
# Root logging config; format/handlers fall back to logging defaults.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

app = FastAPI(title="Orpheus TTS Server", version="1.0.0")

# Permissive CORS so browser clients on any origin can call the API.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True for credentialed requests — confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model handle; populated by the startup hook, None until loaded.
engine = None
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the Orpheus TTS model once at server startup.

    Stores the model in the module-level ``engine`` global so request
    handlers can check availability. Any load failure is logged with its
    traceback and re-raised so the server fails fast rather than serving
    requests with a broken model state.
    """
    global engine
    try:
        engine = OrpheusModel(
            model_name="SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1",
            tokenizer="SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
        )
        logger.info("Orpheus model loaded successfully")
    except Exception as e:
        # logger.exception records the full traceback; a bare `raise`
        # re-raises preserving the original traceback (unlike `raise e`).
        logger.error(f"Error loading Orpheus model: {e}")
        raise
|
|
|
def create_wav_header(sample_rate=24000, bits_per_sample=16, channels=1, data_size=0):
    """Build a 44-byte RIFF/WAVE header for a PCM audio stream.

    Args:
        sample_rate: Samples per second (Hz).
        bits_per_sample: Bit depth of each sample.
        channels: Number of interleaved audio channels.
        data_size: Size in bytes of the PCM payload that follows. Defaults
            to 0 because the total length is unknown when streaming; most
            players tolerate this and simply play until EOF.

    Returns:
        The packed 44-byte header as ``bytes``.
    """
    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8

    return struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        36 + data_size,   # RIFF chunk size = 36 header bytes after this field + payload
        b'WAVE',
        b'fmt ',
        16,               # fmt sub-chunk size for plain PCM
        1,                # audio format tag: 1 = uncompressed PCM
        channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b'data',
        data_size
    )
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: reports service health and model readiness."""
    model_ready = engine is not None
    return {"status": "healthy", "model_loaded": model_ready}
|
|
|
@app.get("/tts")
async def tts_stream(
    prompt: str = Query(..., description="Text to synthesize"),
    voice: str = Query("Jakob", description="Voice to use"),
    temperature: float = Query(0.4, description="Temperature for generation"),
    top_p: float = Query(0.9, description="Top-p for generation"),
    max_tokens: int = Query(2000, description="Maximum tokens"),
    repetition_penalty: float = Query(1.1, description="Repetition penalty")
):
    """HTTP endpoint that streams synthesized speech as a WAV response.

    Emits a WAV header immediately, then raw PCM chunks as the model
    produces them, so clients can start playback before synthesis ends.

    Raises:
        HTTPException: 503 when the model has not finished loading.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    def generate_audio_stream():
        # Header first so clients can begin decoding the stream right away.
        yield create_wav_header()
        try:
            syn_tokens = engine.generate_speech(
                prompt=prompt,
                voice=voice,
                repetition_penalty=repetition_penalty,
                stop_token_ids=[128258],  # presumably the model's end-of-speech token — TODO confirm
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p
            )
            for chunk in syn_tokens:
                yield chunk
        except Exception as e:
            # Once streaming has begun the response headers are already sent,
            # so raising HTTPException here cannot produce a 500 — it would
            # only crash the response mid-flight. Log and end the stream.
            logger.error(f"Error in TTS generation: {e}")
            return

    return StreamingResponse(
        generate_audio_stream(),
        media_type='audio/wav',
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
|
|
|
@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """WebSocket endpoint for real-time TTS streaming.

    Protocol: the client sends one JSON text frame per request
    (``{"prompt": ..., "voice": ..., ...}``); the server replies with
    JSON status frames interleaved with binary audio frames (a WAV
    header followed by PCM chunks). The connection stays open so a
    client can issue multiple synthesis requests.
    """
    await websocket.accept()

    # Refuse service immediately if the model never loaded.
    if engine is None:
        await websocket.send_json({"error": "Model not loaded"})
        await websocket.close()
        return

    try:
        while True:
            # Each loop iteration services one synthesis request frame.
            data = await websocket.receive_text()
            request = json.loads(data)

            # Defaults mirror the /tts HTTP endpoint's query defaults.
            prompt = request.get("prompt", "")
            voice = request.get("voice", "Jakob")
            temperature = request.get("temperature", 0.4)
            top_p = request.get("top_p", 0.9)
            max_tokens = request.get("max_tokens", 2000)
            repetition_penalty = request.get("repetition_penalty", 1.1)

            if not prompt:
                await websocket.send_json({"error": "No prompt provided"})
                continue

            # Acknowledge the request before audio frames start arriving.
            await websocket.send_json({"status": "generating", "prompt": prompt})

            try:
                # Binary frames: WAV header first so the client can decode
                # the PCM chunks that follow.
                wav_header = create_wav_header()
                await websocket.send_bytes(wav_header)

                syn_tokens = engine.generate_speech(
                    prompt=prompt,
                    voice=voice,
                    repetition_penalty=repetition_penalty,
                    stop_token_ids=[128258],  # presumably the model's end-of-speech token — TODO confirm
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p
                )

                chunk_count = 0
                for chunk in syn_tokens:
                    await websocket.send_bytes(chunk)
                    chunk_count += 1

                    # Periodic JSON progress frame every 10 audio chunks.
                    if chunk_count % 10 == 0:
                        await websocket.send_json({
                            "status": "streaming",
                            "chunks_sent": chunk_count
                        })

                # Final frame tells the client the stream is complete.
                await websocket.send_json({
                    "status": "completed",
                    "total_chunks": chunk_count
                })

            except Exception as e:
                # Per-request failure: report it to the client but keep the
                # connection open for further requests.
                logger.error(f"Error in WebSocket TTS generation: {e}")
                await websocket.send_json({"error": str(e)})

    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
    except Exception as e:
        # Unexpected connection-level failure: log and close the socket.
        logger.error(f"WebSocket error: {e}")
        await websocket.close()
|
|
|
@app.get("/voices")
async def get_available_voices():
    """Return the model's list of selectable voices.

    Raises:
        HTTPException: 503 when the model has not finished loading.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    voices = engine.available_voices
    return {"voices": voices}
|
|
|
@app.get("/")
async def root():
    """Root endpoint: API overview plus current model-load status."""
    endpoint_map = {
        "health": "/health",
        "tts_http": "/tts?prompt=your_text&voice=Jakob",
        "tts_websocket": "/ws/tts",
        "voices": "/voices"
    }
    return {
        "message": "Orpheus TTS Server",
        "endpoints": endpoint_map,
        "model_loaded": engine is not None
    }
|
|
|
if __name__ == "__main__":
    # Run the ASGI app directly with uvicorn when invoked as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
|
|