File size: 9,013 Bytes
0b5b901 dbb4a9f 4189fe1 9bf14d0 dbb4a9f 0316ec3 e3958ab 479f253 dbb4a9f 641d199 479f253 2008a3f dbb4a9f e3958ab dbb4a9f 641d199 dbb4a9f e3958ab 641d199 dbb4a9f 479f253 dbb4a9f 641d199 e3958ab dbb4a9f 641d199 dbb4a9f e3958ab dbb4a9f 641d199 dbb4a9f a0cc672 dbb4a9f a0cc672 e3958ab dbb4a9f 0dfc310 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 dbb4a9f 641d199 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
# app.py ──────────────────────────────────────────────────────────────
import os
import json
import torch
import asyncio
import traceback # Import traceback for better error logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC # Ensure you have 'pip install snac'
# --- Module-level state (populated in load_models) ---
# tok, model, snac, masker and stopping_criteria all start out unset.
tok = model = snac = masker = stopping_criteria = None

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# 0) Login + Device ---------------------------------------------------
# Authenticate against the Hugging Face Hub only when a token is configured
# in the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("π Logging in to Hugging Face Hub...")
    # NOTE: a failing login is not caught here; consider handling it if needed.
    login(HF_TOKEN)
# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment if needed for the PyTorch 2.2 bug
# 1) Constants ---------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

# Special token ids
START_TOKEN = 128259
NEW_BLOCK = 128257  # token indicating the start of audio generation
EOS_TOKEN = 128258  # end-of-speech token

# Audio token ids occupy AUDIO_SPAN consecutive ids starting at AUDIO_BASE.
AUDIO_BASE = 128266
AUDIO_SPAN = 7 * 4096  # 7 codebooks * 4096 codes per book = 28672 audio tokens

# CPU tensor holding every audio token id; the copy on the target device is
# meant to be created later in load_models.
AUDIO_IDS_CPU = torch.arange(start=AUDIO_BASE, end=AUDIO_BASE + AUDIO_SPAN)
# 2) Logit mask -------------------------------------------------------
class AudioMask(LogitsProcessor):
    """
    Logits processor that restricts sampling to a whitelist of token ids.

    While ``sent_blocks`` is 0, only the NEW_BLOCK token and the audio token
    ids may be sampled. Once ``sent_blocks`` is greater than 0, the EOS token
    is whitelisted as well. ``sent_blocks`` is plain instance state that other
    code is expected to update; ``reset()`` sets it back to 0.
    """

    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        # Build every helper tensor on the same device as the audio ids.
        target = audio_ids.device
        block_start = torch.tensor([new_block_token_id], device=target)
        self.eos = torch.tensor([eos_token_id], device=target)
        # Whitelist used before any block was sent: NEW_BLOCK + audio ids.
        self.allow_initial = torch.cat((block_start, audio_ids), dim=0)
        # Whitelist used afterwards: the initial whitelist plus EOS.
        self.allow_with_eos = torch.cat((self.allow_initial, self.eos), dim=0)
        # Number of audio blocks sent so far.
        self.sent_blocks = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with every non-whitelisted token pushed to -inf."""
        if self.sent_blocks > 0:
            permitted = self.allow_with_eos
        else:
            permitted = self.allow_initial
        # Additive mask: -inf everywhere, 0 on the permitted columns.
        penalty = torch.full_like(scores, float("-inf"))
        penalty[:, permitted] = 0
        return scores + penalty

    def reset(self):
        """Clear the sent-block counter ahead of a new generation request."""
        self.sent_blocks = 0
# 3) StoppingCriteria for EOS -----------------------------------------
# generate() needs explicit stopping criteria when using a streamer
class EosStoppingCriteria(StoppingCriteria):
    """
    Stopping criterion that ends generation once the EOS token was generated.

    Args:
        eos_token_id: Token id that signals the end of generation.
    """

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Report whether generation should stop.

        Args:
            input_ids: Token ids generated so far, indexed as (batch, sequence).
            scores: Unused; accepted to satisfy the StoppingCriteria interface.

        Returns:
            True when the most recent token of every row equals the EOS id,
            False otherwise (including when nothing has been generated yet).
        """
        # Nothing to inspect yet: avoids indexing an empty sequence dimension.
        if input_ids.numel() == 0:
            return False
        # Reduce the per-row comparison to one Python bool. Using the raw
        # comparison tensor in an `if` raises for batch sizes greater than 1.
        # `.all()` only stops once no row is still generating, so no row is
        # truncated; for a batch of one this matches the previous behaviour.
        if bool((input_ids[:, -1] == self.eos_token_id).all()):
            print("StoppingCriteria: EOS detected.")
            return True
        return False
# 4) Custom AudioStreamer ---------------------------------------------
class AudioStreamer(BaseStreamer):
"""
Custom streamer to process audio tokens, decode them using SNAC,
and send audio bytes over a WebSocket.
"""
# Added target_device parameter
def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
    """
    Store the collaborators and set up empty per-stream state.

    Args:
        ws: WebSocket the audio bytes are sent over.
        snac_decoder: SNAC decoder used to turn codes into audio.
        audio_mask: Logits mask whose ``sent_blocks`` counter this streamer
            is meant to update.
        loop: Event loop of the main thread, kept for
            ``run_coroutine_threadsafe`` scheduling.
        target_device: Device on which code tensors are created.
    """
    # Mutable per-stream state
    self.buf: list[int] = list()  # audio token values (AUDIO_BASE subtracted)
    self.tasks = set()  # pending send tasks
    # Collaborators supplied by the caller
    self.device = target_device
    self.loop = loop
    self.masker = audio_mask
    self.snac = snac_decoder
    self.ws = ws
def _decode_block(self, block7: list[int]) -> bytes:
    """
    Decode one block of 7 audio token values into 16-bit PCM bytes.

    The values are expected to have AUDIO_BASE already subtracted. They are
    regrouped into three SNAC codebook layers and passed to
    ``self.snac.decode``. The decoded waveform is clamped to [-1, 1] before
    scaling, so out-of-range samples saturate instead of wrapping around in
    the int16 cast.

    NOTE: The mapping from the 7 positions to the 3 SNAC codebooks is
    critical and model-specific. The mapping used here is
    l1 = [0], l2 = [1, 4], l3 = [2, 3, 5, 6]. If audio is distorted, the
    alternative hypothesis is l1 = [0, 3, 6], l2 = [1, 4], l3 = [2, 5].

    Args:
        block7: Exactly 7 audio token values.

    Returns:
        Raw int16 PCM bytes, or ``b""`` when the block does not have 7
        entries or when tensor conversion / SNAC decoding fails (the error
        is printed).
    """
    if len(block7) != 7:
        print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
        return b""  # Incomplete block: nothing to decode

    # Regroup the 7 positions into the three codebook layers. The length
    # check above guarantees that every index exists.
    l1 = [block7[0]]
    l2 = [block7[1], block7[4]]
    l3 = [block7[2], block7[3], block7[5], block7[6]]

    # One (1, n) long tensor per layer, created on the target device.
    try:
        codes = [
            torch.tensor(layer, dtype=torch.long, device=self.device).unsqueeze(0)
            for layer in (l1, l2, l3)
        ]
    except Exception as e:
        print(f"Streamer Error: Failed converting lists to tensors. Error: {e}")
        return b""

    # Decode using SNAC; the decoder is expected to live on the same device.
    try:
        with torch.no_grad():
            audio = self.snac.decode(codes)[0]
    except Exception as e:
        print(f"Streamer Error: snac.decode failed. Input shapes: {[c.shape for c in codes]}. Error: {e}")
        return b""

    # Clamp to the valid range so the int16 cast saturates rather than wraps.
    audio_np = audio.squeeze().detach().clamp(-1.0, 1.0).cpu().numpy()
    # Convert to 16-bit PCM bytes
    return (audio_np * 32767).astype("int16").tobytes()
async def _send_audio_bytes(self, data: bytes):
    """
    Send ``data`` over the WebSocket as a binary message.

    Empty payloads are skipped. A disconnect or any other send failure is
    printed and swallowed, so the caller never sees an exception from here.
    """
    if data:
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as err:
            print(f"Streamer: Error sending bytes: {err}")
def put(self, value: torch.LongTensor):
"""
Receives new token IDs (Tensor) from generate() (runs in worker thread).
Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe.
"""
# Ensure value is on CPU and flatten to a list of ints
if value.numel() == 0:
return
# Handle potential shape issues, ensure it's iterable
try:
new_token_ids = value.view(-1).tolist()
except Exception as e:
print(f"Streamer Error: Could not process incoming tensor: {value}, Error: {e}")
return
for t in new_token_ids:
if t == EOS_TOKEN:
# print("Streamer: EOS token encountered.")
# EOS is handled by StoppingCriteria, no action needed here except maybe logging.
break # Stop processing this batch if EOS is found
if t == NEW_BLOCK: |