# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback # Import traceback for better error logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC # Ensure you have 'pip install snac'
# --- Globals (populated in load_models) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# 0) Login + Device ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("🔑 Logging in to Hugging Face Hub...")
    # Consider adding error handling for login failure if necessary
    login(HF_TOKEN)
# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment to work around a PyTorch 2.2 bug if needed

# 1) Constants --------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
START_TOKEN = 128259
NEW_BLOCK = 128257   # Token indicating the start of a new audio block
EOS_TOKEN = 128258   # End Of Speech token
AUDIO_BASE = 128266  # Base ID for audio tokens
AUDIO_SPAN = 4096 * 7  # 7 audio token slots per frame * 4096 codes each = 28672 possible audio token IDs
# Create AUDIO_IDS on the correct device later in load_models
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
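# NOTE (assumption, for orientation only): given the offsets above, an audio
# token id t in [AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN) would encode one of the
# 7 per-frame slots, each with its own 4096-code range, i.e. roughly
#   slot = (t - AUDIO_BASE) // 4096   and   code = (t - AUDIO_BASE) % 4096
# The unpacking this app actually relies on happens in AudioStreamer below.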

# 2) Logit mask --------------------------------------------------------
class AudioMask(LogitsProcessor):
    """
    Manages allowed tokens during generation.
    - Initially allows NEW_BLOCK and AUDIO tokens.
    - Allows EOS_TOKEN only after at least one audio block has been sent.
    """
    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        # Allow NEW_BLOCK and all valid audio tokens initially
        self.allow_initial = torch.cat([
            torch.tensor([new_block_token_id], device=audio_ids.device),
            audio_ids
        ], dim=0)
        self.eos = torch.tensor([eos_token_id], device=audio_ids.device)
        # Precompute combined tensor for allowing audio, NEW_BLOCK, and EOS
        self.allow_with_eos = torch.cat([self.allow_initial, self.eos], dim=0)
        self.sent_blocks = 0  # State: number of audio blocks sent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Determine which tokens are allowed based on whether blocks have been sent
        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow_initial
        # Create a mask initialized to negative infinity
        mask = torch.full_like(scores, float("-inf"))
        # Set allowed token scores to 0 (effectively allowing them)
        mask[:, current_allow] = 0
        # Apply the mask to the scores
        return scores + mask

    def reset(self):
        """Resets the state for a new generation request."""
        self.sent_blocks = 0

# 3) StoppingCriteria for EOS -----------------------------------------
# generate() needs explicit stopping criteria when using a streamer
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Check if the *last* generated token is the EOS token.
        # Guard against an empty sequence to avoid an index error on the first call.
        if input_ids.shape[1] > 0 and (input_ids[:, -1] == self.eos_token_id).all():
            print("StoppingCriteria: EOS detected.")
            return True
        return False
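
# Usage sketch (assumption -- the actual wiring happens further down in this file):
# the mask and stopping criteria are meant to be passed to generate(), roughly:
#   from transformers import LogitsProcessorList
#   masker.reset()  # fresh state per request
#   model.generate(
#       input_ids,
#       logits_processor=LogitsProcessorList([masker]),
#       stopping_criteria=stopping_criteria,  # e.g. StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
#       streamer=streamer,                    # the AudioStreamer defined below
#   )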

# 4) Custom AudioStreamer ----------------------------------------------
class AudioStreamer(BaseStreamer):
    """
    Custom streamer to process audio tokens, decode them using SNAC,
    and send audio bytes over a WebSocket.
    """
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask  # Reference to the mask so sent_blocks can be updated
        self.loop = loop  # Event loop of the main thread, for run_coroutine_threadsafe
        self.device = target_device  # Device the SNAC decoder lives on
        self.buf: list[int] = []  # Buffer for audio token values (AUDIO_BASE subtracted)
        self.tasks = set()  # Keep track of pending send tasks

    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decodes a block of 7 audio token values (AUDIO_BASE subtracted) into audio bytes.
        NOTE: The mapping from the 7 tokens to the 3 SNAC codebooks (l1, l2, l3)
        is CRITICAL and depends on the structure used by the specific model.
        This implementation uses the mapping derived from an earlier version of this code.
        If audio is distorted, try the alternative mapping commented out below.
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""  # Return empty bytes if the block is incomplete

        # --- Mapping based on the earlier version of this code ---
        try:
            l1 = [block7[0]]                                   # Index 0
            l2 = [block7[1], block7[4]]                        # Indices 1, 4
            l3 = [block7[2], block7[3], block7[5], block7[6]]  # Indices 2, 3, 5, 6
            # --- Alternative hypothesis mapping (try if the above fails) ---
            # l1 = [block7[0], block7[3], block7[6]]  # Indices 0, 3, 6
            # l2 = [block7[1], block7[4]]             # Indices 1, 4
            # l3 = [block7[2], block7[5]]             # Indices 2, 5
        except IndexError as e:
            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}, Error: {e}")
            return b""

        # Convert lists to tensors on the correct device
        try:
            codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
            codes = [codes_l1, codes_l2, codes_l3]  # List of tensors for SNAC
        except Exception as e:
            print(f"Streamer Error: Failed converting lists to tensors. Error: {e}")
            return b""

        # Decode using SNAC
        try:
            with torch.no_grad():
                # snac_decoder is assumed to already be on self.device (moved via .to(device))
                audio = self.snac.decode(codes)[0]  # decode() expects a list of tensors; result may keep a batch dim
        except Exception as e:
            print(f"Streamer Error: snac.decode failed. Input shapes: {[c.shape for c in codes]}. Error: {e}")
            return b""

        # Squeeze, clamp to [-1, 1] to avoid int16 wrap-around, move to CPU, convert to numpy
        audio_np = audio.squeeze().clamp(-1.0, 1.0).detach().cpu().numpy()
        # Convert to 16-bit PCM bytes
        audio_bytes = (audio_np * 32767).astype("int16").tobytes()
        return audio_bytes
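
    # Shape note (assumption, based on the mapping above): per 7-token frame the
    # decoder receives codes shaped [(1, 1), (1, 2), (1, 4)] for the three SNAC
    # levels; snac.decode() then returns a waveform tensor with a leading batch
    # dimension (hence the [0] above), at the SNAC model's native sample rate
    # (24 kHz if the snac_24khz checkpoint is the one loaded).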

    async def _send_audio_bytes(self, data: bytes):
        """Coroutine to send bytes over WebSocket."""
        if not data:  # Don't send empty bytes
            return
        try:
            await self.ws.send_bytes(data)
            # print(f"Streamer: Sent {len(data)} audio bytes.")  # Optional: debug log
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as e:
            print(f"Streamer: Error sending bytes: {e}")

    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs (tensor) from generate(), which runs in a worker thread.
        Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe.
        """
        # Flatten the incoming tensor to a list of ints (tolist() copies to CPU as needed)
        if value.numel() == 0:
            return
        # Handle potential shape issues; ensure the tensor is iterable
        try:
            new_token_ids = value.view(-1).tolist()
        except Exception as e:
            print(f"Streamer Error: Could not process incoming tensor: {value}, Error: {e}")
            return

        for t in new_token_ids:
            if t == EOS_TOKEN:
                # print("Streamer: EOS token encountered.")
                # EOS is handled by the StoppingCriteria; no action needed here except maybe logging.
                break  # Stop processing this batch if EOS is found
            if t == NEW_BLOCK: