|
|
|
import os |
|
import json |
|
import torch |
|
import asyncio |
|
import traceback |
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
|
from huggingface_hub import login |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList |
|
|
|
from transformers.generation.streamers import BaseStreamer |
|
from snac import SNAC |
|
|
|
|
|
# Runtime singletons: all start as None at import time and are presumably
# populated elsewhere (e.g. during application startup) -- not visible here.
tok = model = snac = masker = stopping_criteria = None

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
|
|
|
|
|
# Optional Hugging Face authentication: read the token from the environment
# and log in only when it is set to a non-empty value.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("π Logging in to Hugging Face Hub...")
    login(token=HF_TOKEN)
|
|
|
|
|
|
|
|
|
# Hugging Face repo id of the model this server targets (loading happens
# outside the visible code).
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"

# Special control-token ids.
NEW_BLOCK = 128257  # checked for by AudioStreamer.put
EOS_TOKEN = 128258  # AudioStreamer.put stops consuming tokens at this id
START_TOKEN = 128259  # not referenced in the visible code; presumably used to build prompts

# Audio tokens occupy one contiguous id range starting at AUDIO_BASE:
# 7 slots of 4096 ids each (the streamer decodes frames of 7 tokens).
AUDIO_BASE = 128266
AUDIO_SPAN = 7 * 4096

# Every audio token id, as an int64 tensor on the default (CPU) device.
AUDIO_IDS_CPU = torch.arange(start=AUDIO_BASE, end=AUDIO_BASE + AUDIO_SPAN)
|
|
|
|
|
class AudioMask(LogitsProcessor):
    """
    Logits processor that restricts sampling to audio-related tokens.

    - While ``sent_blocks == 0`` only the NEW_BLOCK id and the audio ids may
      be sampled.
    - Once ``sent_blocks > 0`` the EOS id is allowed as well.

    ``sent_blocks`` is never incremented inside this class; whoever owns the
    instance is expected to advance it (presumably the audio streamer --
    confirm) and to call :meth:`reset` before each new request.
    """

    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        """
        Args:
            audio_ids: 1-D tensor of audio token ids (concatenated along dim 0).
            new_block_token_id: id that is always allowed.
            eos_token_id: id that is allowed only once ``sent_blocks > 0``.
        """
        super().__init__()

        ids_device = audio_ids.device
        self.allow_initial = torch.cat([
            torch.tensor([new_block_token_id], device=ids_device),
            audio_ids,
        ], dim=0)
        self.eos = torch.tensor([eos_token_id], device=ids_device)
        self.allow_with_eos = torch.cat([self.allow_initial, self.eos], dim=0)
        self.sent_blocks = 0

    def _current_allow(self) -> torch.Tensor:
        """Return the allow-list that applies to the current state."""
        return self.allow_with_eos if self.sent_blocks > 0 else self.allow_initial

    def _move_to(self, device: torch.device) -> None:
        """Relocate the allow-lists to ``device`` once, so later steps index without a copy."""
        self.allow_initial = self.allow_initial.to(device)
        self.allow_with_eos = self.allow_with_eos.to(device)
        self.eos = self.eos.to(device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        Return ``scores`` with every disallowed token pushed to ``-inf``.

        Args:
            input_ids: tokens generated so far (unused).
            scores: logits indexed as ``scores[:, token_id]`` -- assumes
                (batch, vocab).

        Returns:
            A new tensor; allowed entries are unchanged, others are ``-inf``.
        """
        allowed = self._current_allow()
        if allowed.device != scores.device:
            # The allow-lists may have been built on CPU (e.g. from
            # AUDIO_IDS_CPU) while logits live on the GPU; indexing across
            # devices would copy ~28k indices host->device on every step.
            self._move_to(scores.device)
            allowed = self._current_allow()

        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0
        return scores + mask

    def reset(self):
        """Resets the state for a new generation request."""
        self.sent_blocks = 0
|
|
|
|
|
|
|
class EosStoppingCriteria(StoppingCriteria):
    """Stop generation once the most recently generated token is ``eos_token_id``."""

    def __init__(self, eos_token_id: int):
        super().__init__()
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """
        Args:
            input_ids: ids generated so far; the last column is the newest
                token (assumes (batch, seq)).
            scores: unused.

        Returns:
            True if any sequence in the batch currently ends with EOS,
            otherwise False.
        """
        if input_ids.shape[1] == 0:
            return False

        # Reduce the per-row comparison with .any(): using the comparison
        # tensor directly in `if` raises "Boolean value of Tensor with more
        # than one value is ambiguous" for batch sizes > 1.
        if bool((input_ids[:, -1] == self.eos_token_id).any()):
            print("StoppingCriteria: EOS detected.")
            return True
        return False
|
|
|
|
|
class AudioStreamer(BaseStreamer): |
|
""" |
|
Custom streamer to process audio tokens, decode them using SNAC, |
|
and send audio bytes over a WebSocket. |
|
""" |
|
|
|
def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str):
    """
    Store the streamer's collaborators and start with empty state.

    Args:
        ws: WebSocket that decoded audio bytes are sent over.
        snac_decoder: SNAC model whose ``decode`` turns codes into audio.
        audio_mask: the AudioMask used during generation; kept so its state
            can be consulted/updated (presumably by ``put`` -- confirm).
        loop: event loop on which send coroutines are scheduled.
        target_device: device for the code tensors fed to SNAC.
    """
    self.ws, self.snac, self.masker, self.loop = ws, snac_decoder, audio_mask, loop
    self.device = target_device

    # Pending token values not yet decoded -- presumably filled by put()
    # until a full 7-token frame is available (put's body is not fully
    # visible here; confirm).
    self.buf: list[int] = []
    # Presumably holds in-flight send futures/tasks; confirm against put().
    self.tasks = set()
|
|
|
def _decode_block(self, block7: list[int]) -> bytes:
    """
    Decode one 7-token audio frame into raw 16-bit PCM bytes.

    Token layout within the frame (mapping kept from the original code):

        position:  0   1   2   3   4   5   6
        codebook:  l1  l2  l3  l3  l2  l3  l3

    Values must already have AUDIO_BASE subtracted. Each position ``i`` may
    arrive in either of two forms:

    * already normalised to a codebook index in ``[0, 4096)``, or
    * still carrying its per-position offset, i.e. in
      ``[i * 4096, (i + 1) * 4096)`` (the audio id range is 7 * 4096 wide).

    Both forms are converted to the SNAC codebook index. A value that fits
    neither form makes the whole frame invalid.

    Args:
        block7: exactly seven integer token values.

    Returns:
        Native-endian int16 PCM bytes, or ``b""`` if the frame is malformed
        or SNAC decoding fails.
    """
    codebook_size = 4096
    frame_len = 7

    if len(block7) != frame_len:
        print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
        return b""

    # Strip the per-position offset when present. Without this, positions
    # 1-6 would reach SNAC as values >= 4096 and decoding would fail.
    local: list[int] = []
    for pos, value in enumerate(block7):
        offset = pos * codebook_size
        if 0 <= value < codebook_size:
            local.append(value)
        elif offset <= value < offset + codebook_size:
            local.append(value - offset)
        else:
            print(f"Streamer Error: token value {value} out of range for position {pos}. Block: {block7}. Skipping.")
            return b""

    layers = (
        [local[0]],
        [local[1], local[4]],
        [local[2], local[3], local[5], local[6]],
    )

    try:
        codes = [
            torch.tensor(layer, dtype=torch.long, device=self.device).unsqueeze(0)
            for layer in layers
        ]
    except Exception as e:
        print(f"Streamer Error: Failed converting lists to tensors. Error: {e}")
        return b""

    try:
        with torch.no_grad():
            # Presumably decode() returns a batch; [0] keeps the single item.
            audio = self.snac.decode(codes)[0]
    except Exception as e:
        print(f"Streamer Error: snac.decode failed. Input shapes: {[c.shape for c in codes]}. Error: {e}")
        return b""

    # Clamp before scaling: astype("int16") wraps around on overflow, so any
    # sample slightly outside [-1, 1] would otherwise become a loud click.
    audio_np = audio.squeeze().clamp(-1.0, 1.0).detach().cpu().numpy()
    return (audio_np * 32767).astype("int16").tobytes()
|
|
|
async def _send_audio_bytes(self, data: bytes):
    """
    Send ``data`` as a single binary WebSocket message.

    Empty payloads are skipped entirely. Errors raised by the send
    (``WebSocketDisconnect`` or any other ``Exception``) are logged and
    swallowed, so they never propagate out of this coroutine.
    """
    if data:
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as exc:
            print(f"Streamer: Error sending bytes: {exc}")
|
|
|
def put(self, value: torch.LongTensor): |
|
""" |
|
Receives new token IDs (Tensor) from generate() (runs in worker thread). |
|
Processes tokens, decodes full blocks, and schedules sending via run_coroutine_threadsafe. |
|
""" |
|
|
|
if value.numel() == 0: |
|
return |
|
|
|
try: |
|
new_token_ids = value.view(-1).tolist() |
|
except Exception as e: |
|
print(f"Streamer Error: Could not process incoming tensor: {value}, Error: {e}") |
|
return |
|
|
|
for t in new_token_ids: |
|
if t == EOS_TOKEN: |
|
|
|
|
|
break |
|
|
|
if t == NEW_BLOCK: |