|
|
|
import os |
|
import json |
|
import torch |
|
import asyncio |
|
import traceback |
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
|
from huggingface_hub import login |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList |
|
|
|
from transformers.generation.streamers import BaseStreamer |
|
from snac import SNAC |
|
|
|
|
|
# Process-wide singletons. They stay None until the FastAPI startup hook has
# loaded the tokenizer, the language model and the SNAC decoder.
tok = model = snac = masker = stopping_criteria = None

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Optional Hugging Face authentication, driven by the HF_TOKEN env variable.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("π Logging in to Hugging Face Hub...")
    login(HF_TOKEN)

# Checkpoint used for both the tokenizer and the causal language model.
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

# Special token ids: START_TOKEN is prepended to every prompt, NEW_BLOCK is
# appended to it (and resets the streamer's block buffer), EOS_TOKEN ends
# generation.
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258

# Audio vocabulary: AUDIO_SPAN consecutive ids starting at AUDIO_BASE, i.e.
# seven ranges of CODEBOOK_SIZE ids each.
CODEBOOK_SIZE = 4096
AUDIO_BASE = 128266
AUDIO_SPAN = CODEBOOK_SIZE * 7

# Every audio token id, kept on the CPU; moved to `device` during startup.
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
|
|
|
|
|
|
|
class AudioMask(LogitsProcessor):
    """Logits processor that restricts sampling to the audio vocabulary.

    Only the new-block marker and the supplied audio token ids keep their
    scores; every other id is pushed to ``-inf``. The EOS id joins the allowed
    set once ``sent_blocks`` is greater than zero. That counter is not
    advanced by this class itself; it is written from outside and cleared
    again by :meth:`reset`.
    """

    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        """Precompute the allowed-id tensors on the device of ``audio_ids``."""
        super().__init__()
        target = audio_ids.device
        marker = torch.tensor([new_block_token_id], dtype=torch.long, device=target)
        self.eos = torch.tensor([eos_token_id], dtype=torch.long, device=target)
        # Allowed ids before EOS is permitted: the block marker, then audio ids.
        self.allow = torch.cat((marker, audio_ids), dim=0)
        # Same set with the EOS id appended at the end.
        self.allow_with_eos = torch.cat((self.allow, self.eos), dim=0)
        self.sent_blocks = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with every non-allowed column set to ``-inf``."""
        if self.sent_blocks > 0:
            permitted = self.allow_with_eos
        else:
            permitted = self.allow
        # Additive mask: 0 on permitted columns, -inf everywhere else.
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, permitted] = 0
        return scores + penalty

    def reset(self):
        """Forget previously sent blocks so EOS is disallowed again."""
        self.sent_blocks = 0
|
|
|
|
|
|
|
class EosStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the newest token is the EOS id.

    The criterion returns a single Python ``bool`` for the whole batch: it is
    ``True`` only when the last token of *every* row equals ``eos_token_id``.
    For a batch of one row this is simply "the last token is EOS".
    """

    def __init__(self, eos_token_id: int):
        """Remember the EOS id; ``None`` disables the criterion."""
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop.

        Args:
            input_ids: Token ids generated so far, indexed as (row, position).
            scores: Unused; accepted for interface compatibility.
            **kwargs: Ignored.
        """
        if self.eos_token_id is None or input_ids.numel() == 0:
            return False

        # Reduce over the rows explicitly. Truth-testing the raw comparison
        # tensor raises a RuntimeError as soon as there is more than one row.
        last_is_eos = input_ids[:, -1] == self.eos_token_id
        if bool(last_is_eos.all()):
            print(f"StoppingCriteria: EOS detected (ID: {self.eos_token_id}).")
            return True
        return False
|
|
|
|
|
class AudioStreamer(BaseStreamer):
    """Streamer that converts generated audio tokens into PCM WebSocket frames.

    ``put`` collects audio token ids into blocks of seven, decodes every full
    block with the SNAC decoder and schedules the resulting bytes for sending
    on the event loop passed to the constructor. ``put`` itself is synchronous
    and hands the send coroutine over with ``run_coroutine_threadsafe``.
    """

    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
        """Store the collaborators and start with an empty block buffer.

        Args:
            ws: WebSocket the audio bytes are sent to.
            snac_decoder: SNAC model used to decode code blocks.
            audio_mask: Logits mask whose ``sent_blocks`` flag is raised after
                the first block has been scheduled for sending.
            loop: Event loop on which the send coroutines are scheduled.
            target_device: Device on which the code tensors are created.
            eos_token_id: Stored for reference; not read by this class.
        """
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask
        self.loop = loop
        self.device = target_device
        self.eos_token_id = eos_token_id
        # Audio token values (AUDIO_BASE already subtracted) of the open block.
        self.buf: list[int] = []
        # Futures of send coroutines that have been scheduled but not finished.
        self.tasks = set()

    def _decode_block(self, block7: list[int]) -> bytes:
        """Decode one block of seven audio token values into int16 PCM bytes.

        Args:
            block7: Seven token values with ``AUDIO_BASE`` already subtracted.

        Returns:
            The decoded samples as raw int16 bytes, or ``b""`` when the block
            does not have exactly seven entries or any decoding stage fails.
            Failures are logged, never raised.
        """
        if len(block7) != 7:
            return b""

        try:
            # Each value is (slot offset + code value); the modulo recovers the
            # code value in [0, CODEBOOK_SIZE).
            c0, c1, c2, c3, c4, c5, c6 = (v % CODEBOOK_SIZE for v in block7)
        except Exception as e_map:
            print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
            return b""

        # Distribution of the seven codes over the three code lists passed to
        # the decoder (1 + 2 + 4 entries).
        # NOTE(review): layout assumed correct for Kartoffel_Orpheus - confirm.
        l1 = [c0]
        l2 = [c1, c4]
        l3 = [c2, c3, c5, c6]

        try:
            # One leading batch dimension per code list.
            codes = [
                torch.tensor(layer, dtype=torch.long, device=self.device).unsqueeze(0)
                for layer in (l1, l2, l3)
            ]
        except Exception as e_tensor:
            print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
            return b""

        try:
            with torch.no_grad():
                audio = self.snac.decode(codes)[0]
        except Exception as e_decode:
            print(f"Streamer Error: Exception during snac.decode: {e_decode}")
            return b""

        try:
            # Clamp to [-1, 1] before scaling: without it, samples outside that
            # range overflow int16 and wrap around to the opposite sign.
            audio_np = audio.squeeze().detach().clamp(-1.0, 1.0).cpu().numpy()
            return (audio_np * 32767).astype("int16").tobytes()
        except Exception as e_post:
            print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
            return b""

    async def _send_audio_bytes(self, data: bytes):
        """Send ``data`` over the WebSocket; ignore an already closed socket."""
        if not data:
            return
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            # The client went away; there is nobody left to send to.
            pass
        except Exception as e:
            message = str(e)
            closed_markers = (
                "Cannot call \"send\" once a close message has been sent",
                "Connection is closed",
            )
            # Errors that only report a closed connection are expected here.
            if not any(marker in message for marker in closed_markers):
                print(f"Streamer: Error sending bytes: {e}")

    def put(self, value: torch.LongTensor):
        """Consume new token ids, decode full blocks and schedule their sending.

        ``NEW_BLOCK`` discards the partially filled buffer. Ids inside the
        audio range are buffered with ``AUDIO_BASE`` subtracted. All other ids
        are ignored.
        """
        if value.numel() == 0:
            return

        # flatten() always yields a list, also for a single-element tensor.
        for t in value.flatten().cpu().tolist():
            if t == NEW_BLOCK:
                self.buf.clear()
                continue

            if not (AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN):
                continue

            self.buf.append(t - AUDIO_BASE)
            if len(self.buf) < 7:
                continue

            audio_bytes = self._decode_block(self.buf)
            self.buf.clear()
            if not audio_bytes:
                continue

            # Hand the send over to the event loop and track the future until
            # it has completed.
            future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
            self.tasks.add(future)
            future.add_done_callback(self.tasks.discard)

            # From the first scheduled block on, the mask may allow EOS.
            if self.masker.sent_blocks == 0:
                self.masker.sent_blocks = 1

    def end(self):
        """Drop a trailing incomplete block when generation has finished."""
        if len(self.buf) > 0:
            print(f"Streamer: End of generation with incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()
|
|
|
|
|
app = FastAPI()


@app.on_event("startup")
async def load_models_startup():
    """Load tokenizer, SNAC decoder and language model into module globals.

    Runs once on application startup. Besides loading the models it builds the
    shared ``AudioMask`` and ``StoppingCriteriaList`` and makes sure the model
    config has a pad token id (falling back to ``EOS_TOKEN``).
    """
    global tok, model, snac, masker, stopping_criteria

    print(f"π Starting up on device: {device}")
    print("β³ Lade Modelle β¦", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")

    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {device}.")

    # float32 on CPU; on CUDA prefer bfloat16 and fall back to float16.
    model_dtype = torch.float32
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
            print("Using bfloat16 for model.")
        else:
            model_dtype = torch.float16
            print("Using float16 for model.")

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,
    )
    print(f"Model loaded to {model.device} with dtype {model.dtype}.")
    model.eval()

    # Report how the hard-coded EOS id relates to what model and tokenizer say.
    conf_eos = model.config.eos_token_id
    tok_eos = tok.eos_token_id
    print(f"Model Config EOS ID: {conf_eos}")
    print(f"Tokenizer EOS ID: {tok_eos}")
    print(f"Using Constant EOS_TOKEN: {EOS_TOKEN}")
    # The config may hold a single id or a list of ids; accept both forms so a
    # list containing EOS_TOKEN does not trigger a spurious warning.
    if isinstance(conf_eos, (list, tuple)):
        conf_matches = EOS_TOKEN in conf_eos
    else:
        conf_matches = conf_eos == EOS_TOKEN
    if not conf_matches or tok_eos != EOS_TOKEN:
        print(f"β οΈ WARNING: Constant EOS_TOKEN {EOS_TOKEN} differs from model/tokenizer IDs ({conf_eos}/{tok_eos}).")

    if model.config.pad_token_id is None:
        print(f"Setting model.config.pad_token_id to Constant EOS token ID ({EOS_TOKEN})")
        model.config.pad_token_id = EOS_TOKEN

    # The mask indexes the score tensor, so its ids must live on `device`.
    audio_ids_device = AUDIO_IDS_CPU.to(device)
    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
    print("AudioMask initialized.")

    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
    print("StoppingCriteria initialized.")

    print("✅ Modelle geladen und bereit!", flush=True)
|
|
|
@app.get("/")
def hello():
    """Health-check endpoint: report that the service is up."""
    payload = {
        "status": "ok",
        "message": "TTS Service is running",
    }
    return payload
|
|
|
|
|
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Assemble the model input for one request.

    The prompt text is ``"<voice>: <text>"``. Its token ids are framed by
    ``START_TOKEN`` in front and ``NEW_BLOCK`` behind.

    Returns:
        ``(input_ids, attention_mask)`` on the global ``device``; the mask is
        all ones and has the same shape as ``input_ids``.
    """

    def _single(token_id: int) -> torch.Tensor:
        # 1x1 long tensor holding one special token id.
        return torch.tensor([[token_id]], device=device, dtype=torch.long)

    encoded = tok(f"{voice}: {text}", return_tensors="pt")
    body_ids = encoded.input_ids.to(device)

    input_ids = torch.cat((_single(START_TOKEN), body_ids, _single(NEW_BLOCK)), dim=1)
    return input_ids, torch.ones_like(input_ids)
|
|
|
|
|
def _ws_is_open(ws: WebSocket) -> bool:
    """Return True while neither side has closed the WebSocket.

    ``client_state`` alone is not enough: it stays CONNECTED after this server
    has sent its own close frame, which is tracked in ``application_state``.
    """
    return (
        ws.client_state.name == "CONNECTED"
        and ws.application_state.name == "CONNECTED"
    )


@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Handle one TTS request per WebSocket connection.

    Protocol: the client sends one JSON text message with optional ``text``
    and ``voice`` keys; the server streams back binary audio frames and then
    closes the socket. Close codes: 1000 on success, 1003 for empty text or
    invalid JSON, 1011 after an unexpected error (preceded by a JSON
    ``{"error": ...}`` text message when the socket is still open).
    """
    await ws.accept()
    print("π Client connected")
    streamer = None
    main_loop = asyncio.get_running_loop()

    try:
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?")
        voice = req.get("voice", "Jakob")

        if not text:
            print("β οΈ Request text is empty.")
            await ws.close(code=1003, reason="Text cannot be empty")
            return

        print(f"Generating audio for: '{text}' with voice '{voice}'")
        ids, attn = build_prompt(text, voice)
        masker.reset()

        streamer = AudioStreamer(ws, snac, masker, main_loop, device, EOS_TOKEN)

        print("Starting generation in background thread...")

        # generate() blocks, so it runs in a worker thread; the streamer
        # schedules its sends back onto this event loop.
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=2500,
            logits_processor=[masker],
            stopping_criteria=stopping_criteria,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=4,
            use_cache=True,
            streamer=streamer,
            eos_token_id=EOS_TOKEN,
        )
        print("Generation thread finished.")

        # Wait for audio sends that are still in flight, so the socket is not
        # closed underneath the last chunks.
        pending_sends = [asyncio.wrap_future(f) for f in list(streamer.tasks)]
        if pending_sends:
            await asyncio.gather(*pending_sends, return_exceptions=True)

    except WebSocketDisconnect:
        print("π Client disconnected.")
    except json.JSONDecodeError:
        print("β Invalid JSON received.")
        if _ws_is_open(ws):
            await ws.close(code=1003, reason="Invalid JSON format")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"β WSβError: {e}\n{error_details}", flush=True)
        error_payload = json.dumps({"error": str(e)})
        try:
            if _ws_is_open(ws):
                await ws.send_text(error_payload)
        except Exception:
            # Best effort only: the close below is attempted regardless.
            pass
        if _ws_is_open(ws):
            await ws.close(code=1011)
    finally:
        if streamer:
            try:
                streamer.end()
            except Exception as e_end:
                print(f"Error during streamer.end(): {e_end}")

        print("Closing connection.")
        if _ws_is_open(ws):
            try:
                await ws.close(code=1000)
            except RuntimeError as e_close:
                # A close that raced with another close is not worth reporting.
                if "Cannot call \"send\"" not in str(e_close) and "Connection is closed" not in str(e_close):
                    print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                print(f"Error closing websocket: {e_close_final}")
        elif ws.client_state.name != "DISCONNECTED" and ws.application_state.name != "DISCONNECTED":
            print(f"WebSocket final state: {ws.client_state.name}")
        print("Connection closed.")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
    }
    print("Starting Uvicorn server...")
    uvicorn.run("app:app", **server_options)