# Hugging Face Space page residue (kept for provenance):
# Tomtom84 — "Update app.py", commit b1de4bc (verified), 7.74 kB.
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import struct
import sys
import os
import json
import asyncio
import logging
# ===== MODEL CONFIGURATION =====
# Simply switch between the German models here:
USE_KARTOFFEL_MODEL = True  # True = Kartoffel, False = Canopy-Deutsch

# Hub ids of the two supported German Orpheus checkpoints.
_KARTOFFEL_ID = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
_CANOPY_ID = "canopylabs/3b-de-ft-research_release"

if USE_KARTOFFEL_MODEL:
    # Model and tokenizer share the same repository id.
    MODEL_NAME = TOKENIZER_NAME = _KARTOFFEL_ID
    DEFAULT_VOICE = "Jakob"
    print("🥔 Using Kartoffel German Model")
else:
    MODEL_NAME = TOKENIZER_NAME = _CANOPY_ID
    DEFAULT_VOICE = "thomas"
    print("🇩🇪 Using Canopy German Model")
# Add the orpheus-tts module to the path
# (presumably vendored in an "orpheus-tts/" directory next to this file,
# so it is importable without being pip-installed — TODO confirm layout).
sys.path.append(os.path.join(os.path.dirname(__file__), 'orpheus-tts'))

try:
    from orpheus_tts.engine_class import OrpheusModel
except ImportError:
    # Fallback for a flat checkout where engine_class.py sits directly on
    # sys.path instead of inside the orpheus_tts package.
    from engine_class import OrpheusModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Orpheus TTS Server", version="1.0.0")

# Add CORS middleware for web clients
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive; consider restricting the origin list for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the Orpheus model
# None until the startup hook has loaded it; endpoints check this before use.
engine = None
@app.on_event("startup")
async def startup_event():
    """Load the Orpheus TTS model once at server startup.

    Stores the loaded engine in the module-level ``engine`` global so the
    endpoints can use it. Re-raises on failure so the server does not come
    up without a working model.
    """
    global engine
    try:
        engine = OrpheusModel(
            model_name=MODEL_NAME,
            tokenizer=TOKENIZER_NAME,
        )
    except Exception:
        # logger.exception records the full traceback; a bare `raise`
        # re-raises the original exception unchanged.
        logger.exception("Error loading Orpheus model")
        raise
    # Lazy %-formatting avoids building the string when INFO is disabled.
    logger.info("Orpheus model loaded successfully: %s", MODEL_NAME)
def create_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
    """Build the 44-byte RIFF/WAVE header for a streamed PCM response.

    The RIFF and data size fields are written as 0 because the total
    stream length is unknown up front; clients read until the connection
    closes.

    Args:
        sample_rate: Samples per second (default 24 kHz).
        bits_per_sample: Bit depth of each sample (default 16).
        channels: Number of audio channels (default mono).

    Returns:
        bytes: A little-endian WAV header ready to prepend to PCM data.
    """
    frame_size = channels * bits_per_sample // 8          # block align
    bytes_per_second = sample_rate * channels * bits_per_sample // 8
    unknown_size = 0  # stream length is not known in advance

    riff_chunk = struct.pack('<4sI4s', b'RIFF', 36 + unknown_size, b'WAVE')
    fmt_chunk = struct.pack(
        '<4sIHHIIHH',
        b'fmt ',
        16,               # fmt chunk payload size for PCM
        1,                # audio format tag: 1 = uncompressed PCM
        channels,
        sample_rate,
        bytes_per_second,
        frame_size,
        bits_per_sample,
    )
    data_chunk = struct.pack('<4sI', b'data', unknown_size)
    return riff_chunk + fmt_chunk + data_chunk
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the server is up and the model loaded."""
    model_ready = engine is not None
    return {"status": "healthy", "model_loaded": model_ready}
@app.get("/tts")
async def tts_stream(
    prompt: str = Query(..., description="Text to synthesize"),
    voice: str = Query(DEFAULT_VOICE, description="Voice to use"),
    temperature: float = Query(0.4, description="Temperature for generation"),
    top_p: float = Query(0.9, description="Top-p for generation"),
    max_tokens: int = Query(2000, description="Maximum tokens"),
    repetition_penalty: float = Query(1.1, description="Repetition penalty")
):
    """HTTP endpoint for TTS streaming.

    Streams a 44-byte WAV header followed by PCM chunks from the Orpheus
    engine as a chunked ``audio/wav`` response.

    Raises:
        HTTPException: 503 when the model has not finished loading.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    def generate_audio_stream():
        # Header first so clients can begin decoding immediately.
        yield create_wav_header()
        try:
            syn_tokens = engine.generate_speech(
                prompt=prompt,
                voice=voice,
                repetition_penalty=repetition_penalty,
                stop_token_ids=[128258],  # end-of-speech token id
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p
            )
            for chunk in syn_tokens:
                yield chunk
        except Exception as e:
            # BUGFIX: the original raised HTTPException here, but once the
            # response body has started streaming no error response can be
            # produced — the exception just aborted the connection. Log and
            # end the stream gracefully instead.
            logger.error(f"Error in TTS generation: {e}")

    return StreamingResponse(
        generate_audio_stream(),
        media_type='audio/wav',
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """WebSocket endpoint for real-time TTS streaming.

    Protocol (per request, repeated until the client disconnects):
      1. Client sends a JSON text frame with "prompt" plus optional
         generation parameters.
      2. Server replies with a {"status": "generating"} JSON frame.
      3. Server sends the 44-byte WAV header as a binary frame, then audio
         chunks as binary frames, with a {"status": "streaming"} JSON frame
         every 10 chunks.
      4. Server finishes with a {"status": "completed"} JSON frame.
    Errors are reported as {"error": ...} JSON frames. Note that clients
    must handle interleaved text (JSON) and binary (audio) frames.
    """
    await websocket.accept()
    if engine is None:
        await websocket.send_json({"error": "Model not loaded"})
        await websocket.close()
        return
    try:
        while True:
            # Receive request from client
            data = await websocket.receive_text()
            request = json.loads(data)
            prompt = request.get("prompt", "")
            voice = request.get("voice", DEFAULT_VOICE)
            temperature = request.get("temperature", 0.4)
            top_p = request.get("top_p", 0.9)
            max_tokens = request.get("max_tokens", 2000)
            repetition_penalty = request.get("repetition_penalty", 1.1)
            if not prompt:
                # Bad request: report and keep the connection open for the
                # next request.
                await websocket.send_json({"error": "No prompt provided"})
                continue
            # Send status update
            await websocket.send_json({"status": "generating", "prompt": prompt})
            try:
                # Send WAV header
                wav_header = create_wav_header()
                await websocket.send_bytes(wav_header)
                # Generate and stream audio
                syn_tokens = engine.generate_speech(
                    prompt=prompt,
                    voice=voice,
                    repetition_penalty=repetition_penalty,
                    stop_token_ids=[128258],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p
                )
                chunk_count = 0
                for chunk in syn_tokens:
                    await websocket.send_bytes(chunk)
                    chunk_count += 1
                    # Send periodic status updates
                    if chunk_count % 10 == 0:
                        await websocket.send_json({
                            "status": "streaming",
                            "chunks_sent": chunk_count
                        })
                # Send completion status
                await websocket.send_json({
                    "status": "completed",
                    "total_chunks": chunk_count
                })
            except Exception as e:
                # Generation failed mid-request: report the error but keep
                # the connection alive for further requests.
                logger.error(f"Error in WebSocket TTS generation: {e}")
                await websocket.send_json({"error": str(e)})
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        await websocket.close()
@app.get("/voices")
async def get_available_voices():
    """Return the voices exposed by the loaded Orpheus engine.

    Raises:
        HTTPException: 503 when the model has not finished loading.
    """
    model = engine
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"voices": model.available_voices}
@app.get("/")
async def root():
    """Root endpoint: service banner, endpoint map, and model status."""
    endpoint_map = {
        "health": "/health",
        "tts_http": f"/tts?prompt=your_text&voice={DEFAULT_VOICE}",
        "tts_websocket": "/ws/tts",
        "voices": "/voices",
    }
    return {
        "message": "Orpheus TTS Server",
        "endpoints": endpoint_map,
        "model_loaded": engine is not None,
    }
if __name__ == "__main__":
    import uvicorn
    # Bind on all interfaces; port 7860 is the Hugging Face Spaces convention.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")