# app.py ──────────────────────────────────────────────────────────────
import os, json, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from snac import SNAC

# 0) Login + device ----------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.enable_flash_sdp(False)  # workaround for a PyTorch 2.2 SDPA bug

# 1) Constants ---------------------------------------------------------
REPO         = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50          # tokens per mini-generate call
START_TOKEN  = 128259
NEW_BLOCK    = 128257
EOS_TOKEN    = 128258
AUDIO_BASE   = 128266
AUDIO_SPAN   = 4096 * 7    # 28,672 audio codes (7 slots x 4,096 codes each)
AUDIO_IDS    = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
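
# Illustrative helper (not used at runtime): each 7-token block carries one
# SNAC frame, and slot k (0-6) occupies its own 4,096-id range, so a flat
# token id splits back into a (slot, code) pair. For example,
# AUDIO_BASE + 2*4096 + 123 = 136581 decodes to slot 2, code 123.
def split_audio_token(t: int) -> tuple[int, int]:
    off = t - AUDIO_BASE
    return off // 4096, off % 4096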

# 2) Logit mask (audio slots + NEW_BLOCK; EOS only after the 1st block) -
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor):
        super().__init__()
        self.allow = torch.cat([
            torch.tensor([NEW_BLOCK], device=audio_ids.device),
            audio_ids
        ])
        self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
        self.sent_blocks = 0  # set once a block was streamed -> EOS becomes legal
        self.buffer_pos = 0   # position (0-6) inside the current 7-token block
    def __call__(self, input_ids, logits):
        # generate() calls this once per new token. Track the block position
        # from the previously sampled token instead of relying on externally
        # set state, which would go stale within a multi-token chunk.
        last = input_ids[0, -1].item()
        if last == NEW_BLOCK:
            self.buffer_pos = 0
        elif AUDIO_BASE <= last < AUDIO_BASE + AUDIO_SPAN:
            self.buffer_pos = (self.buffer_pos + 1) % 7
        # Slot k of a block may only draw from its own 4,096-code range.
        start = AUDIO_BASE + self.buffer_pos * 4096
        allowed = torch.arange(start, start + 4096, device=self.allow.device)
        # NEW_BLOCK is only legal at a block boundary.
        if self.buffer_pos == 0:
            allowed = torch.cat([
                torch.tensor([NEW_BLOCK], device=self.allow.device),
                allowed
            ])
        if self.sent_blocks:  # allow EOS once the first block was sent
            allowed = torch.cat([allowed, self.eos])
        mask = logits.new_full(logits.shape, float("-inf"))
        mask[:, allowed] = 0
        return logits + mask
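
# Quick smoke test (illustrative; the logits width is an assumption):
#   m = AudioMask(AUDIO_IDS)
#   out = m(torch.tensor([[NEW_BLOCK]]), torch.zeros(1, AUDIO_BASE + AUDIO_SPAN))
#   assert torch.isfinite(out[0, AUDIO_BASE]) and not torch.isfinite(out[0, 0])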

# 3) FastAPI scaffolding ------------------------------------------------
app = FastAPI()

@app.get("/")
def hello():
    return {"status": "ok"}

@app.on_event("startup")
def load_models():
    global tok, model, snac, masker
    print("⏳ Loading models …", flush=True)
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    model.config.pad_token_id = model.config.eos_token_id
    masker = AudioMask(AUDIO_IDS.to(device))
    print("✅ Models loaded", flush=True)

# 4) Helpers ------------------------------------------------------------
def build_prompt(text: str, voice: str):
    prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
                     prompt_ids,
                     torch.tensor([[128009, 128260]], device=device)], 1)
    attn = torch.ones_like(ids)
    return ids, attn
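
# Resulting prompt layout (ids of the text part depend on the tokenizer):
#   [START_TOKEN, <tokens of "Jakob: Hallo Welt">..., 128009, 128260]
# i.e. the start marker, the voice-prefixed text, then the two end markers.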

def decode_block(block7: list[int]) -> bytes:
    # Undo AUDIO_BASE plus the per-slot offset (slot k adds k*4096), then
    # route the 7 codes to SNAC's three codebook levels.
    l1 = [block7[0] - AUDIO_BASE]
    l2 = [block7[1] - AUDIO_BASE - 4096,
          block7[4] - AUDIO_BASE - 16384]
    l3 = [block7[2] - AUDIO_BASE - 8192,
          block7[3] - AUDIO_BASE - 12288,
          block7[5] - AUDIO_BASE - 20480,
          block7[6] - AUDIO_BASE - 24576]
    with torch.no_grad():
        codes = [torch.tensor(x, device=device).unsqueeze(0)
                 for x in (l1, l2, l3)]
        audio = snac.decode(codes).squeeze().detach().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
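
# Slot-to-level mapping of one 7-token frame, as consumed above:
#   slot:   0   1   2   3   4   5   6
#   level:  l1  l2  l3  l3  l2  l3  l3
# matching SNAC's hierarchical 1/2/4 codes per frame.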

# 5) WebSocket endpoint -------------------------------------------------
@app.websocket("/ws/tts")  # route path assumed; match it in your client
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_prompt(text, voice)
        offset_len = ids.size(1)   # how many tokens already exist
        buf = []
        masker.sent_blocks = 0     # masker is shared -> reset state per request
        masker.buffer_pos = 0
        while True:
            # --- mini-generate; the KV cache stays disabled, so the full
            # sequence is re-fed on every chunk (simpler, but slower) -----
            gen = model.generate(
                input_ids=ids,
                attention_mask=attn,
                max_new_tokens=CHUNK_TOKENS,
                logits_processor=[masker],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=False,
                return_dict_in_generate=True,
            )
            # ----- slice out the newly generated tokens ------------------
            seq = gen.sequences[0].tolist()
            new = seq[offset_len:]
            if not new:            # nothing generated -> done
                break
            offset_len += len(new)

            # Feed the full sequence back in on the next iteration.
            ids = torch.tensor([seq], device=device)
            attn = torch.ones_like(ids)
            print("new tokens:", new[:25], flush=True)
            # ----- token handling ----------------------------------------
            for t in new:
                if t == EOS_TOKEN:
                    raise StopIteration          # caught below -> end of stream
                if t == NEW_BLOCK:
                    buf.clear()                  # block boundary: restart frame
                    continue
                if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                    buf.append(t)
                    if len(buf) == 7:            # one full SNAC frame
                        await ws.send_bytes(decode_block(buf))
                        buf.clear()
                        masker.sent_blocks = 1   # EOS is legal from now on
                else:
                    print(f"DEBUG: skipping non-audio token: {t}", flush=True)
    except (StopIteration, WebSocketDisconnect):
        pass
    except Exception as e:
        print("❌ WS error:", e, flush=True)
        import traceback
        traceback.print_exc()
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass

# 6) Dev entry point ----------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
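
# ── Example client (sketch; not part of the app) ────────────────────────
# Assumes the route above (/ws/tts) and a server on localhost:7860; binary
# messages are raw 16-bit mono PCM at 24 kHz:
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               try:
#                   while True:
#                       f.write(await ws.recv())
#               except websockets.exceptions.ConnectionClosed:
#                   pass
#
#   asyncio.run(main())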