Spaces:
Paused
Paused
# app.py ────────────────────────────────────────────────────────────── | |
import os, json, torch, asyncio | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
from huggingface_hub import login | |
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, DynamicCache # Added StaticCache | |
from snac import SNAC | |
# 0) Hub login + device selection --------------------------------------
# Authenticate against the Hugging Face Hub only when a token is supplied
# through the HF_TOKEN environment variable.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Disabled workaround kept for reference (noted as a PyTorch 2.2 bug):
# torch.backends.cuda.enable_flash_sdp(False)
# 1) Constants ---------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50  # max tokens requested per generate() call

# Special token ids used by the prompt / response format.
START_TOKEN = 128259  # prepended to every prompt
NEW_BLOCK = 128257    # resets the partial audio frame buffer
EOS_TOKEN = 128258    # terminates the audio stream

# Audio codes occupy one contiguous id range: 7 frame positions x 4096 codes.
AUDIO_BASE = 128266
AUDIO_SPAN = 7 * 4096  # 28 672 ids
AUDIO_IDS = AUDIO_BASE + torch.arange(AUDIO_SPAN)  # every valid audio id
# 2) Logit mask (NEW_BLOCK + audio codes; EOS only after the first block) ---
class AudioMask(LogitsProcessor):
    """Restrict sampling to audio-related token ids.

    NEW_BLOCK and every id in ``audio_ids`` are always allowed. EOS_TOKEN is
    allowed only once ``sent_blocks`` is truthy. The caller sets it after the
    first complete audio frame has been emitted, so the stream cannot end
    before any audio was produced. All other ids get a score of ``-inf``.
    """

    def __init__(self, audio_ids: torch.Tensor):
        super().__init__()
        dev = audio_ids.device
        # Ids that may always be sampled.
        self.allow = torch.cat([torch.tensor([NEW_BLOCK], device=dev), audio_ids])
        self.eos = torch.tensor([EOS_TOKEN], device=dev)
        # Precomputed once instead of on every decoding step.
        self.allow_with_eos = torch.cat([self.allow, self.eos])
        # Set to a truthy value by the caller once a full frame has been sent.
        self.sent_blocks = 0
        self.buffer_pos = 0  # not read anywhere in this file; kept for compatibility

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        """Return ``scores`` with every disallowed column set to ``-inf``."""
        allowed = self.allow_with_eos if self.sent_blocks else self.allow
        mask = torch.full_like(scores, float("-inf"))
        # .to() is a no-op when the ids already live on the scores' device.
        mask[:, allowed.to(scores.device)] = 0
        return scores + mask
# 3) FastAPI app ---------------------------------------------------------
app = FastAPI()


# Without the decorator this function was never exposed as an endpoint.
@app.get("/")
def hello():
    """Liveness endpoint: report that the service is up."""
    return {"status": "ok"}
# Registered as a startup hook: without it nothing ever calls this function
# and the WebSocket handler would hit undefined globals (tok, model, snac).
@app.on_event("startup")
def load_models():
    """Load the tokenizer, SNAC decoder and Orpheus LM into module globals.

    Sets the globals ``tok``, ``snac``, ``model`` and ``masker``
    (a module-level AudioMask over AUDIO_IDS).
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …", flush=True)
    tok = AutoTokenizer.from_pretrained(REPO)
    # Inference only: put the decoder in eval mode explicitly.
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval()
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        # Whole model on GPU 0 when CUDA is available; default placement otherwise.
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    # Reuse the EOS id as padding id (the config may not define one).
    model.config.pad_token_id = model.config.eos_token_id
    masker = AudioMask(AUDIO_IDS.to(device))
    print("✅ Modelle geladen", flush=True)
# 4) Helpers -------------------------------------------------------------
def build_prompt(text: str, voice: str):
    """Build the model input for ``text`` spoken by ``voice``.

    Token layout: ``[START_TOKEN] + tok("<voice>: <text>") + [128009, 128260]``,
    concatenated along dim 1 with a batch dimension of 1 and placed on ``device``.

    Returns:
        ``(input_ids, attention_mask)`` where the mask is all ones with the
        same shape as ``input_ids``.
    """
    text_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    prefix = torch.tensor([[START_TOKEN]], device=device)
    suffix = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat((prefix, text_ids, suffix), dim=1)
    return input_ids, torch.ones_like(input_ids)
def decode_block(block7: list[int]) -> bytes:
    """Decode one 7-code audio frame into 16-bit PCM bytes with SNAC.

    ``block7`` holds codes that were already shifted down by AUDIO_BASE.
    Position ``i`` additionally carries an offset of ``i * 4096``, which is
    removed here. The positions are then distributed over SNAC's three
    code levels:

    * level 1 <- position 0
    * level 2 <- positions 1, 4
    * level 3 <- positions 2, 3, 5, 6

    Returns:
        The decoded waveform as little-endian int16 PCM bytes.

    Raises:
        ValueError: if ``block7`` does not contain exactly 7 codes.
    """
    if len(block7) != 7:
        raise ValueError(f"expected 7 codes per frame, got {len(block7)}")
    codes = [c - i * 4096 for i, c in enumerate(block7)]
    levels = (
        [codes[0]],
        [codes[1], codes[4]],
        [codes[2], codes[3], codes[5], codes[6]],
    )
    with torch.no_grad():
        tensors = [torch.tensor(level, device=device).unsqueeze(0) for level in levels]
        audio = snac.decode(tensors).squeeze().detach().cpu().numpy()
    # NumPy's float->int16 cast does not saturate, so clip first. This way an
    # out-of-range sample clips to full scale instead of producing garbage.
    pcm = (audio.clip(-1.0, 1.0) * 32767).astype("<i2")
    return pcm.tobytes()
# 5) WebSocket endpoint --------------------------------------------------
# NOTE(review): generate() and decode_block() run synchronously inside the
# coroutine below, so the event loop is blocked while a chunk is produced.


def _generate_chunk(seq: torch.Tensor, past, logits_masker: "AudioMask"):
    """Extend ``seq`` by up to CHUNK_TOKENS tokens with one generate() call.

    ``seq`` must be the FULL sequence so far (prompt plus everything generated
    earlier). When a cache is reused, transformers' generate() expects the
    whole sequence and skips the already-cached prefix itself. Passing only
    the last token desynchronises the input from ``past``.
    """
    return model.generate(
        input_ids=seq,
        attention_mask=torch.ones_like(seq),
        past_key_values=past,
        max_new_tokens=CHUNK_TOKENS,
        logits_processor=[logits_masker],
        do_sample=True, temperature=0.7, top_p=0.95,
        eos_token_id=EOS_TOKEN,  # stop the chunk as soon as audio ends
        use_cache=True,
        return_dict_in_generate=True,
        return_legacy_cache=False,  # keep past_key_values as a cache object
    )


async def _stream_tokens(ws: WebSocket, tokens: list[int], buf: list[int],
                         logits_masker: "AudioMask") -> bool:
    """Collect audio ids into 7-code frames and send each finished frame.

    Every complete frame is decoded with decode_block() and sent as one
    binary message. ``buf`` carries a partial frame across calls.
    Returns True as soon as EOS_TOKEN is seen; the remaining ids are ignored.
    """
    for t in tokens:
        if t == EOS_TOKEN:
            print("EOS token encountered.")
            return True
        if t == NEW_BLOCK:
            buf.clear()  # start a fresh frame
            continue
        if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
            buf.append(t - AUDIO_BASE)
        # Other ids are ignored (the logits mask should already exclude them).
        if len(buf) == 7:
            await ws.send_bytes(decode_block(buf))
            buf.clear()
            logits_masker.sent_blocks = 1  # EOS is allowed from now on
    return False


@app.websocket("/ws/tts")  # NOTE(review): route path reconstructed; confirm against clients
async def tts(ws: WebSocket):
    """Stream synthesized speech over a WebSocket.

    The client sends one JSON text message ``{"text": ..., "voice": ...}``
    (``voice`` defaults to "Jakob"). The server replies with binary messages,
    each holding the PCM bytes of one decoded 7-code frame. It closes the
    socket after EOS, using code 1011 if anything failed.
    """
    import traceback

    await ws.accept()
    close_code = 1000  # normal closure unless something fails
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        seq, _ = build_prompt(text, voice)
        past = None  # DynamicCache returned by the previous generate() call
        buf: list[int] = []
        # Per-connection mask: its EOS gate must not leak between requests
        # or concurrent connections, as it would with one shared instance.
        req_masker = AudioMask(AUDIO_IDS.to(device))

        while True:
            try:
                gen = _generate_chunk(seq, past, req_masker)
            except Exception as e:
                print(f"❌ Error during model.generate: {e}")
                traceback.print_exc()
                close_code = 1011
                break

            # New tokens are those past the previous sequence length, taken
            # along the sequence dimension of the single batch row.
            new = gen.sequences[0, seq.shape[1]:].tolist()
            if not new:
                print("No new tokens generated, stopping.")
                break
            seq = gen.sequences
            past = gen.past_key_values

            if await _stream_tokens(ws, new, buf, req_masker):
                if buf:
                    print(f"Warning: Incomplete audio block at EOS: {len(buf)} tokens. Discarding.")
                    buf.clear()
                break
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("❌ WS‑Error:", e, flush=True)
        traceback.print_exc()
        close_code = 1011
    finally:
        # Close exactly once, and only if neither side has closed already.
        if (ws.client_state.name != "DISCONNECTED"
                and ws.application_state.name != "DISCONNECTED"):
            try:
                await ws.close(code=close_code)
            except RuntimeError:
                pass
# 6) Dev start -----------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. With the
    # string, uvicorn imports this file a second time as module "app", which
    # re-runs all module-level code, and it breaks if the file is renamed.
    # PORT overrides the default port 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")), log_level="info")