# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback  # For detailed error logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# BaseStreamer is the interface our custom streamer implements
from transformers.generation.streamers import BaseStreamer
from snac import SNAC  # Requires 'pip install snac'
# --- Globals (populated in load_models_startup) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
device = "cuda" if torch.cuda.is_available() else "cpu"
# 0) Login + device ----------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("🔑 Logging in to Hugging Face Hub...")
    login(HF_TOKEN)

# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment to work around a PyTorch 2.2 SDPA bug if needed
# 1) Constants ---------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# CHUNK_TOKENS = 50  # Not needed with the streamer approach
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28,672 codes (7 slots of 4096 each)
# AUDIO_IDS is moved to the correct device later, in load_models_startup
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
# 2) Logit mask --------------------------------------------------------
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        # Initially allow NEW_BLOCK and all valid audio tokens
        self.allow = torch.cat([
            torch.tensor([new_block_token_id], device=audio_ids.device),
            audio_ids
        ], dim=0)
        self.eos = torch.tensor([eos_token_id], device=audio_ids.device)
        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)  # Precomputed combined tensor
        self.sent_blocks = 0  # State: number of audio blocks sent so far

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # EOS only becomes legal after at least one audio block has been sent
        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
        # Start from -inf everywhere, then zero out the allowed positions
        mask = torch.full_like(scores, float("-inf"))
        mask[:, current_allow] = 0
        return scores + mask

    def reset(self):
        """Resets the state for a new generation request."""
        self.sent_blocks = 0
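
# Minimal illustration of the mask's effect (sketch only; `vocab_size` stands
# in for the model's real vocabulary size):
#
#   masker = AudioMask(AUDIO_IDS_CPU, NEW_BLOCK, EOS_TOKEN)
#   scores = torch.zeros(1, vocab_size)
#   masked = masker(None, scores)  # input_ids is not used by __call__
#   # While sent_blocks == 0, EOS is blocked but NEW_BLOCK stays available:
#   assert masked[0, EOS_TOKEN] == float("-inf")
#   assert masked[0, NEW_BLOCK] == 0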
# 3) StoppingCriteria for EOS ------------------------------------------
# generate() needs explicit stopping criteria when a streamer is used
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the *last* generated token is the EOS token
        if input_ids.shape[1] > 0 and (input_ids[:, -1] == self.eos_token_id).any():
            # print("StoppingCriteria: EOS detected.")
            return True
        return False
# 4) Custom AudioStreamer ----------------------------------------------
class AudioStreamer(BaseStreamer):
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask  # Reference to the mask so we can update sent_blocks
        self.loop = loop  # Main thread's event loop, needed for run_coroutine_threadsafe
        # nn.Module has no .device attribute; derive it from the decoder's parameters
        self.device = next(snac_decoder.parameters()).device
        self.buf: list[int] = []  # Buffer for audio token values (AUDIO_BASE already subtracted)
        self.tasks = set()  # Pending send tasks
    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decodes a block of 7 audio token values (AUDIO_BASE already subtracted)
        into raw audio bytes.
        NOTE: The de-interleaving below follows the layout used by Orpheus-style
        models (1 + 2 + 4 codes per frame, with each slot offset by a multiple
        of 4096). Verify against the model card if the audio sounds wrong.
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""  # Incomplete block
        # --- Orpheus-style de-interleaving ---
        # Slot i carries a code in [i*4096, (i+1)*4096); subtracting the slot
        # offset recovers the raw SNAC code. The 7 slots map to 3 codebooks:
        #   Codebook 1 (l1): slot 0
        #   Codebook 2 (l2): slots 1, 4
        #   Codebook 3 (l3): slots 2, 3, 5, 6
        l1 = [block7[0]]
        l2 = [block7[1] - 4096, block7[4] - 4 * 4096]
        l3 = [block7[2] - 2 * 4096, block7[3] - 3 * 4096,
              block7[5] - 5 * 4096, block7[6] - 6 * 4096]
        if any(not 0 <= c < 4096 for c in l1 + l2 + l3):
            print(f"Streamer Error: Code out of range after de-interleaving. Block: {block7}")
            return b""
        # Convert the lists to batched tensors on the correct device
        codes = [
            torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0),
            torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0),
            torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0),
        ]
        # Decode using SNAC
        with torch.no_grad():
            audio = self.snac.decode(codes)[0]  # decode() takes a list of tensors; the result keeps a batch dim
        # Squeeze, move to CPU, convert to 16-bit PCM bytes
        audio_np = audio.squeeze().detach().cpu().numpy()
        audio_bytes = (audio_np * 32767).astype("int16").tobytes()
        return audio_bytes
    async def _send_audio_bytes(self, data: bytes):
        """Coroutine that sends raw PCM bytes over the WebSocket."""
        if not data:  # Don't send empty payloads
            return
        try:
            await self.ws.send_bytes(data)
            # print(f"Streamer: Sent {len(data)} audio bytes.")
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as e:
            print(f"Streamer: Error sending bytes: {e}")
    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs from generate() (runs in a worker thread).
        Buffers audio tokens, decodes each full 7-token block, and schedules
        sending on the main event loop via run_coroutine_threadsafe.
        """
        if value.numel() == 0:
            return
        # Flatten to a list of ints (covers both the prompt and single tokens)
        new_token_ids = value.squeeze().tolist()
        if isinstance(new_token_ids, int):  # Handle the single-token case
            new_token_ids = [new_token_ids]
        for t in new_token_ids:
            if t == EOS_TOKEN:
                # EOS is handled by the StoppingCriteria; nothing to do here
                break
            if t == NEW_BLOCK:
                # NEW_BLOCK marks the start of audio; reset the buffer
                self.buf.clear()
                continue
            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                self.buf.append(t - AUDIO_BASE)  # Store value relative to base
            else:
                # Ignore tokens outside the audio range (e.g. START_TOKEN or stray IDs)
                # print(f"Streamer Warning: Ignoring unexpected token {t}")
                pass
            # Once the buffer holds 7 token values, decode and send the block
            if len(self.buf) == 7:
                audio_bytes = self._decode_block(self.buf)
                self.buf.clear()
                if audio_bytes:  # Only send if decoding succeeded
                    # Schedule the async send on the main event loop
                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                    self.tasks.add(future)
                    # Drop completed futures so the set doesn't grow unboundedly
                    future.add_done_callback(self.tasks.discard)
                # Allow EOS only after the first full block has been processed
                if self.masker.sent_blocks == 0:
                    self.masker.sent_blocks = 1
        # Note: put() must return quickly; we never block on the send tasks here.
    def end(self):
        """Called by generate() when generation finishes."""
        # Discard any incomplete trailing block
        if len(self.buf) > 0:
            print(f"Streamer: Generation ended with an incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()
        # end() is synchronous, so we cannot await outstanding send tasks here.
        # A robust solution would signal the WebSocket handler to wait before
        # closing; for simplicity we rely on FastAPI/Uvicorn's graceful shutdown.
        # print(f"Streamer: Generation finished. Pending send tasks: {len(self.tasks)}")
# 5) FastAPI app --------------------------------------------------------
app = FastAPI()

@app.on_event("startup")  # Register the startup hook so models load once at boot
async def load_models_startup():  # Async so future model loads could await
    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
    print(f"🚀 Starting up on device: {device}")
    print("⏳ Loading models …", flush=True)
    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")
    # Load SNAC first (usually smaller)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {device}.")
    # Load the main model
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,  # Pin to GPU 0 on CUDA
        torch_dtype=torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32,
        low_cpu_mem_usage=True,  # Good practice for large models
    )
    model.config.pad_token_id = model.config.eos_token_id  # Set pad token
    print(f"Model loaded to {model.device}.")
    # Put the model in evaluation mode
    model.eval()
    # Initialize AudioMask (needs AUDIO_IDS on the correct device)
    audio_ids_device = AUDIO_IDS_CPU.to(device)
    masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
    print("AudioMask initialized.")
    # Initialize the stopping criteria (must be wrapped in a StoppingCriteriaList)
    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
    print("StoppingCriteria initialized.")
    print("✅ Models loaded and ready!", flush=True)

@app.get("/")  # Simple health-check route (path assumed)
def hello():
    return {"status": "ok", "message": "TTS Service is running"}
# 6) Prompt-building helper ----------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Builds the input_ids and attention_mask for the model."""
    # Format: <START> <VOICE>: <TEXT> <NEW_BLOCK>
    prompt_text = f"{voice}: {text}"
    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
    # Assemble the full input_ids tensor
    input_ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device),  # Start token
        prompt_ids,                                    # Encoded prompt
        torch.tensor([[NEW_BLOCK]], device=device)     # NEW_BLOCK triggers audio generation
    ], dim=1)
    # Attention mask is all ones (no padding)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
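
# Illustration of the resulting layout (sketch; the middle token IDs depend
# on the tokenizer):
#
#   ids, attn = build_prompt("Hallo Welt", "Jakob")
#   ids  -> tensor([[128259, <tokens for "Jakob: Hallo Welt">, 128257]])
#   attn -> tensor of ones with the same shape as ids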
# 7) WebSocket endpoint (simplified via the streamer) ---------------------
@app.websocket("/ws/tts")  # Route path assumed; adjust to match your client
async def tts(ws: WebSocket):
    await ws.accept()
    print("Client connected")
    streamer = None  # Defined up front for the finally block
    main_loop = asyncio.get_running_loop()  # Event loop the send coroutines run on
    try:
        # Receive the request configuration
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?")  # Default text
        voice = req.get("voice", "Jakob")  # Default voice
        if not text:
            await ws.close(code=1003, reason="Text cannot be empty")
            return
        print(f"Generating audio for: '{text}' with voice '{voice}'")
        # Prepare the prompt
        ids, attn = build_prompt(text, voice)
        # --- Reset stateful components ---
        masker.reset()  # CRITICAL: reset the mask state for the new request
        # --- Create the streamer instance ---
        streamer = AudioStreamer(ws, snac, masker, main_loop)
        # --- Run model.generate in a worker thread ---
        # This keeps the FastAPI event loop free to send audio chunks
        print("Starting generation...")
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=1500,  # Generation length limit (adjust as needed)
            logits_processor=[masker],
            stopping_criteria=stopping_criteria,
            do_sample=False,  # Greedy decoding tends to give more stable audio
            # do_sample=True, temperature=0.7, top_p=0.95,  # Or use sampling
            use_cache=True,
            streamer=streamer,  # Our custom streamer receives tokens as they are generated
            # No need to manage past_key_values manually
        )
        print("Generation finished.")
    except WebSocketDisconnect:
        print("Client disconnected.")
    except json.JSONDecodeError:
        print("❌ Invalid JSON received.")
        await ws.close(code=1003, reason="Invalid JSON format")  # 1003 = unsupported data
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"❌ WS error: {e}\n{error_details}", flush=True)
        # Try to report the error to the client before closing
        error_payload = json.dumps({"error": str(e)})
        try:
            if ws.client_state.name == "CONNECTED":
                await ws.send_text(error_payload)  # Send error as text/JSON
        except Exception:
            pass  # Ignore errors during error reporting
        # Close with an internal-server-error code
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)  # 1011 = internal server error
    finally:
        # Make sure the streamer is finalized
        if streamer:
            try:
                streamer.end()
            except Exception as e_end:
                print(f"Error during streamer.end(): {e_end}")
        # Make sure the WebSocket is closed
        print("Closing connection.")
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close(code=1000)  # 1000 = normal closure
            except RuntimeError as e_close:
                # Can happen if the connection is already closing/closed
                print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                print(f"Error closing websocket: {e_close_final}")
        print("Connection closed.")
# 8) Dev start ------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    print("Starting Uvicorn server...")
    # Use reload=True only for development; remove for production
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")  # , reload=True
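
# Example client (sketch, not part of the service). It assumes the route path
# "/ws/tts" declared above and requires 'pip install websockets'. It sends one
# JSON request and writes the raw 24 kHz mono 16-bit PCM stream to a file:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               try:
#                   while True:
#                       chunk = await ws.recv()
#                       if isinstance(chunk, bytes):
#                           f.write(chunk)
#               except websockets.ConnectionClosed:
#                   pass
#
#   asyncio.run(main())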