File size: 7,736 Bytes
b45ed35 4715aa2 b45ed35 dbb4a9f b45ed35 dbb4a9f 1a347c6 89afaf3 1a347c6 b45ed35 4715aa2 b45ed35 4715aa2 b45ed35 55145d2 b45ed35 1d792aa b45ed35 4715aa2 b45ed35 4715aa2 b45ed35 1d792aa b45ed35 4715aa2 b45ed35 1a347c6 b45ed35 1a347c6 4715aa2 b45ed35 4715aa2 b45ed35 4715aa2 b45ed35 1a347c6 b45ed35 4715aa2 b45ed35 4715aa2 b45ed35 1a347c6 b45ed35 55145d2 b45ed35 4715aa2 b45ed35 4715aa2 b45ed35 1a347c6 b45ed35 55145d2 b45ed35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Query
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import struct
import sys
import os
import json
import asyncio
import logging
# ===== MODEL CONFIGURATION =====
# Switch between the two supported German models here:
# True selects the Kartoffel model, False selects the Canopy German model.
USE_KARTOFFEL_MODEL = False

# Both models ship their tokenizer in the same repository as the weights,
# so MODEL_NAME and TOKENIZER_NAME always point at the same repo id.
if not USE_KARTOFFEL_MODEL:
    MODEL_NAME = TOKENIZER_NAME = "canopylabs/3b-de-ft-research_release"
    DEFAULT_VOICE = "thomas"
    print("🇩🇪 Using Canopy German Model")
else:
    MODEL_NAME = TOKENIZER_NAME = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    DEFAULT_VOICE = "Jakob"
    print("🥔 Using Kartoffel German Model")
# Make the bundled ``orpheus-tts`` directory (next to this file) importable.
_orpheus_dir = os.path.join(os.path.dirname(__file__), 'orpheus-tts')
sys.path.append(_orpheus_dir)

# Prefer the ``orpheus_tts`` package layout; fall back to a top-level
# ``engine_class`` module if the package import is unavailable.
try:
    from orpheus_tts.engine_class import OrpheusModel
except ImportError:
    from engine_class import OrpheusModel
# Configure logging: INFO and above, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Orpheus TTS Server", version="1.0.0")

# Wide-open CORS so browser clients served from any origin can call the API.
# NOTE(review): Starlette documents that wildcard origins/methods/headers are
# not compatible with credentialed requests; if cookies or auth headers are
# ever required, list explicit origins (or drop allow_credentials) — confirm.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# The Orpheus model instance; stays None until the startup hook has loaded it.
engine = None


@app.on_event("startup")
async def startup_event():
    """Load the Orpheus TTS model once, when the application starts.

    Populates the module-level ``engine`` used by every endpoint.

    Raises:
        Exception: whatever ``OrpheusModel`` raises while loading; it is logged
            with its full traceback and re-raised unchanged so the failure
            propagates out of the startup hook.
    """
    global engine
    try:
        engine = OrpheusModel(
            model_name=MODEL_NAME,
            tokenizer=TOKENIZER_NAME,
        )
    except Exception as e:
        # logger.exception keeps the traceback (missing weights, OOM, ...),
        # and a bare ``raise`` re-raises the original exception as-is.
        logger.exception("Error loading Orpheus model: %s", e)
        raise
    logger.info("Orpheus model loaded successfully: %s", MODEL_NAME)
def create_wav_header(sample_rate=24000, bits_per_sample=16, channels=1, data_size=0):
    """Build a 44-byte PCM WAV (RIFF) header for audio streaming.

    Args:
        sample_rate: Samples per second per channel.
        bits_per_sample: Bit depth of each sample (16 for 16-bit PCM).
        channels: Number of interleaved channels.
        data_size: Length in bytes of the PCM payload that follows the header.
            The default of 0 is for open-ended streaming, where the total
            length is unknown when the header is sent; pass the real size
            when it is known so the header describes the payload correctly.

    Returns:
        The header as ``bytes``; the caller appends the raw PCM samples.
    """
    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8
    return struct.pack(
        '<4sI4s4sIHHIIHH4sI',
        b'RIFF',
        36 + data_size,  # RIFF chunk size: header bytes after this field + data
        b'WAVE',
        b'fmt ',
        16,  # size of the fmt chunk body for plain PCM
        1,  # audio format 1 = uncompressed PCM
        channels,
        sample_rate,
        byte_rate,
        block_align,
        bits_per_sample,
        b'data',
        data_size,
    )
@app.get("/health")
async def health_check():
    """Report liveness and whether the TTS model has finished loading."""
    model_ready = engine is not None
    return {"status": "healthy", "model_loaded": model_ready}
@app.get("/tts")
async def tts_stream(
    prompt: str = Query(..., min_length=1, description="Text to synthesize"),
    voice: str = Query(DEFAULT_VOICE, description="Voice to use"),
    temperature: float = Query(0.4, ge=0.0, description="Temperature for generation"),
    top_p: float = Query(0.9, gt=0.0, le=1.0, description="Top-p for generation"),
    max_tokens: int = Query(2000, ge=1, description="Maximum tokens"),
    repetition_penalty: float = Query(1.1, gt=0.0, description="Repetition penalty"),
):
    """HTTP endpoint for TTS streaming.

    Streams a WAV header followed by the audio chunks produced by
    ``engine.generate_speech``. Parameters are range-checked by FastAPI before
    any audio is sent, so invalid input (including an empty prompt, which the
    WebSocket endpoint also rejects) gets a proper 422 response instead of a
    truncated 200 stream.

    Raises:
        HTTPException: 503 if the model has not been loaded yet.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    def generate_audio_stream():
        # Kept as a plain (sync) generator: Starlette iterates sync bodies in
        # its threadpool, so the blocking model calls don't stall the loop.
        yield create_wav_header()
        try:
            syn_tokens = engine.generate_speech(
                prompt=prompt,
                voice=voice,
                repetition_penalty=repetition_penalty,
                # NOTE(review): 128258 is presumably the model's end-of-speech
                # token id — confirm against the tokenizer.
                stop_token_ids=[128258],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            yield from syn_tokens
        except Exception as e:
            # The 200 status and headers are already sent at this point, so an
            # HTTPException can no longer change the response. Log the full
            # traceback and re-raise so the connection is aborted and the
            # client sees a truncated stream rather than a clean end of audio.
            logger.exception("Error in TTS generation: %s", e)
            raise

    return StreamingResponse(
        generate_audio_stream(),
        media_type='audio/wav',
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
# Sentinel returned by ``next`` once the synchronous audio iterator is exhausted.
_STREAM_END = object()


async def _next_chunk_in_thread(iterator):
    """Fetch the next item of a blocking iterator on the default executor.

    Returns ``_STREAM_END`` when the iterator is exhausted, so the event loop
    keeps serving other requests while the model produces audio.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, next, iterator, _STREAM_END)


@app.websocket("/ws/tts")
async def websocket_tts(websocket: WebSocket):
    """WebSocket endpoint for real-time TTS streaming.

    For each text frame, the client sends a JSON object with ``prompt`` and
    optional ``voice``, ``temperature``, ``top_p``, ``max_tokens`` and
    ``repetition_penalty``.

    The server then sends, in order:

    * a ``{"status": "generating"}`` frame;
    * a binary WAV header;
    * the binary audio chunks, with a ``{"status": "streaming"}`` frame after
      every 10 chunks;
    * a final ``{"status": "completed"}`` frame.

    Per-request problems are reported as ``{"error": ...}`` frames, and the
    connection stays open for the next request.
    """
    await websocket.accept()
    if engine is None:
        await websocket.send_json({"error": "Model not loaded"})
        await websocket.close()
        return

    try:
        while True:
            # Receive request from client
            data = await websocket.receive_text()
            try:
                request = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
                continue
            if not isinstance(request, dict):
                await websocket.send_json({"error": "Request must be a JSON object"})
                continue

            prompt = request.get("prompt", "")
            voice = request.get("voice", DEFAULT_VOICE)
            temperature = request.get("temperature", 0.4)
            top_p = request.get("top_p", 0.9)
            max_tokens = request.get("max_tokens", 2000)
            repetition_penalty = request.get("repetition_penalty", 1.1)

            if not prompt:
                await websocket.send_json({"error": "No prompt provided"})
                continue

            await websocket.send_json({"status": "generating", "prompt": prompt})

            try:
                await websocket.send_bytes(create_wav_header())
                syn_tokens = iter(engine.generate_speech(
                    prompt=prompt,
                    voice=voice,
                    repetition_penalty=repetition_penalty,
                    # NOTE(review): 128258 is presumably the model's
                    # end-of-speech token id — confirm against the tokenizer.
                    stop_token_ids=[128258],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                ))

                chunk_count = 0
                while True:
                    # The model iterator blocks while generating; pull each
                    # chunk in a worker thread instead of on the event loop.
                    chunk = await _next_chunk_in_thread(syn_tokens)
                    if chunk is _STREAM_END:
                        break
                    await websocket.send_bytes(chunk)
                    chunk_count += 1
                    if chunk_count % 10 == 0:
                        await websocket.send_json({
                            "status": "streaming",
                            "chunks_sent": chunk_count,
                        })

                await websocket.send_json({
                    "status": "completed",
                    "total_chunks": chunk_count,
                })
            except WebSocketDisconnect:
                # The client left mid-stream; there is no one to report to.
                raise
            except Exception as e:
                logger.exception("Error in WebSocket TTS generation: %s", e)
                await websocket.send_json({"error": str(e)})
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
        # The socket may already be closed; don't let close() mask the error.
        try:
            await websocket.close()
        except Exception:
            logger.debug("WebSocket close failed (already closed?)", exc_info=True)
@app.get("/voices")
async def get_available_voices():
    """Return the voices exposed by the loaded engine.

    Raises:
        HTTPException: 503 while the model has not been loaded.
    """
    if engine is not None:
        return {"voices": engine.available_voices}
    raise HTTPException(status_code=503, detail="Model not loaded")
@app.get("/")
async def root():
    """Describe the server, list its endpoints and report model readiness."""
    endpoint_map = {
        "health": "/health",
        "tts_http": f"/tts?prompt=your_text&voice={DEFAULT_VOICE}",
        "tts_websocket": "/ws/tts",
        "voices": "/voices",
    }
    info = {"message": "Orpheus TTS Server", "endpoints": endpoint_map}
    info["model_loaded"] = engine is not None
    return info
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run(app, **server_options)
|