File size: 7,917 Bytes
0b5b901 87012a8 4189fe1 9bf14d0 fee868f d9ea17d 0316ec3 e3958ab 479f253 2008a3f 1ab029d e3958ab 83532d0 f4406f3 e3958ab 479f253 e3958ab 3d65908 e3958ab a0cc672 e3958ab 9bf14d0 0dfc310 9bf14d0 e3958ab 9bf14d0 e3958ab 5031731 e3958ab 0b5b901 9bf14d0 5031731 e3958ab bca75ea d44e840 f63f843 e3958ab b17f5cd e3958ab 0b5b901 7bb84b7 e3958ab 9e2fbd8 e3958ab 9e2fbd8 e3958ab 0b5b901 e3958ab a8606ac d44e840 a09ea48 4189fe1 d44e840 e3958ab 29123b3 e3958ab c417a58 a0cc672 fee868f fd51bc6 c417a58 29123b3 b87ae72 29123b3 fee868f c417a58 b87ae72 c417a58 b87ae72 fee868f b87ae72 c417a58 b87ae72 fd51bc6 e3958ab b17f5cd fee868f b17f5cd fee868f b17f5cd b87ae72 fd51bc6 b17f5cd fd51bc6 b17f5cd fd51bc6 bca75ea 5031731 479f253 a09ea48 e3958ab 83532d0 5031731 479f253 5031731 e3958ab 5031731 e3958ab a4cfefc e3958ab 83532d0 |
|
# app.py ──────────────────────────────────────────────────────────────
import asyncio
import json
import os
import threading

import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, LogitsProcessor
# 0) Hugging Face login + compute device ------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:  # authenticate only when a token is supplied via the environment
    login(HF_TOKEN)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Keep the flash scaled-dot-product attention kernel disabled
# (original note: works around a PyTorch 2.2 bug).
torch.backends.cuda.enable_flash_sdp(False)
# 1) Constants ---------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50

# Control token ids used when building prompts and parsing model output.
START_TOKEN, NEW_BLOCK, EOS_TOKEN = 128259, 128257, 128258

# Audio codes form one contiguous id range: 7 slots x 4096 codes = 28 672 ids.
AUDIO_BASE = 128266
AUDIO_SPAN = 7 * 4096
AUDIO_IDS = torch.arange(AUDIO_SPAN) + AUDIO_BASE
# 2) Logit mask (NEW_BLOCK + audio; EOS only after the first block) ----
class AudioMask(LogitsProcessor):
    """Restrict sampling to NEW_BLOCK, the audio-code ids and, later, EOS.

    Every score outside the allowed id set is replaced by ``-inf``. EOS is
    only admitted once ``sent_blocks`` is positive; the caller sets it after
    it has emitted a complete audio frame, so the model cannot end the stream
    before producing any audio.
    """

    def __init__(self, audio_ids: torch.Tensor):
        """Build the allowed-id tensors on the same device as ``audio_ids``."""
        super().__init__()
        dev = audio_ids.device
        self.allow = torch.cat([
            torch.tensor([NEW_BLOCK], device=dev),
            audio_ids,
        ])
        self.eos = torch.tensor([EOS_TOKEN], device=dev)
        # Precomputed once instead of concatenating on every decoding step.
        self._allow_with_eos = torch.cat([self.allow, self.eos])
        self.sent_blocks = 0  # complete frames emitted so far (set by the caller)
        self.buffer_pos = 0  # NOTE(review): not read anywhere in this file

    def __call__(self, input_ids, scores):
        """Return ``scores`` with every disallowed token id set to ``-inf``."""
        allow = self._allow_with_eos if self.sent_blocks > 0 else self.allow
        mask = torch.full_like(scores, float("-inf"))
        # Index on the scores' device in case the mask was built elsewhere.
        mask[:, allow.to(scores.device)] = 0
        return scores + mask
# 3) FastAPI skeleton --------------------------------------------------
app = FastAPI()


@app.get("/")
def hello():
    """Health-check endpoint that always returns a static status payload."""
    payload = {"status": "ok"}
    return payload
@app.on_event("startup")
def load_models():
    """Populate the module-level model globals once, at application startup.

    Sets ``tok`` (tokenizer for REPO), ``snac`` (the SNAC decoder moved to
    ``device``), ``model`` (the causal LM; placed on GPU 0 in bfloat16 when
    ``device`` is "cuda") and ``masker`` (an ``AudioMask`` over AUDIO_IDS).
    """
    global tok, model, snac, masker
    print("⏳ Lade Modelle …", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    use_gpu = device == "cuda"
    placement = {"": 0} if use_gpu else None
    dtype = torch.bfloat16 if use_gpu else None
    model = AutoModelForCausalLM.from_pretrained(
        REPO, device_map=placement, torch_dtype=dtype, low_cpu_mem_usage=True
    )
    # Use the EOS id as the pad token id.
    model.config.pad_token_id = model.config.eos_token_id

    masker = AudioMask(AUDIO_IDS.to(device))
    print("✅ Modelle geladen", flush=True)
# 4) Helpers -----------------------------------------------------------
def build_prompt(text: str, voice: str):
    """Tokenise ``"<voice>: <text>"`` and wrap it in the prompt marker tokens.

    Returns ``(input_ids, attention_mask)``: the ids are START_TOKEN, then the
    tokenised text, then ids 128009 and 128260, concatenated along dim 1 on
    ``device``; the mask is all ones with the same shape.
    """
    body = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    head = torch.tensor([[START_TOKEN]], device=device)
    # NOTE(review): presumably end-of-text / end-of-human markers — confirm.
    tail = torch.tensor([[128009, 128260]], device=device)
    input_ids = torch.cat((head, body, tail), dim=1)
    return input_ids, torch.ones_like(input_ids)
def decode_block(block7: list[int]) -> bytes:
l1,l2,l3=[],[],[]
l1.append(block7[0] - 0 * 4096) # Subtract position 0 offset
l2.append(block7[1] - 1 * 4096) # Subtract position 1 offset
l3 += [block7[2] - 2 * 4096, block7[3] - 3 * 4096] # Subtract position offsets
l2.append(block7[4] - 4 * 4096) # Subtract position 4 offset
l3 += [block7[5] - 5 * 4096, block7[6] - 6 * 4096] # Subtract position offsets
with torch.no_grad():
codes = [torch.tensor(x, device=device).unsqueeze(0)
for x in (l1,l2,l3)]
audio = snac.decode(codes).squeeze().detach().cpu().numpy()
return (audio*32767).astype("int16").tobytes()
# 5) WebSocket endpoint ------------------------------------------------
# Upper bound on generated tokens per request, so a model that never samples
# EOS cannot keep a connection (and the GPU) busy forever.
MAX_AUDIO_TOKENS = 8192
# Generation runs in a worker thread so the event loop stays responsive; this
# lock keeps the single model instance driven by one request at a time.
_GENERATE_LOCK = threading.Lock()


def _generate_chunk(ids: torch.Tensor, cache: DynamicCache,
                    mask: LogitsProcessor, max_new: int) -> torch.Tensor:
    """Extend ``ids`` by up to ``max_new`` sampled tokens, reusing ``cache``.

    ``ids`` is the full sequence so far (prompt + everything generated);
    ``generate`` only runs the model on the part not yet stored in ``cache``.
    Stops early at EOS_TOKEN. Returns the full extended sequence.
    """
    with _GENERATE_LOCK:
        return model.generate(
            input_ids=ids,
            attention_mask=torch.ones_like(ids),
            past_key_values=cache,
            max_new_tokens=max_new,
            logits_processor=[mask],
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            eos_token_id=EOS_TOKEN,
            use_cache=True,
        )


@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    """Stream synthesized speech for one request as binary PCM messages.

    The client sends one JSON text message with ``text`` and optional
    ``voice`` (default "Jakob"). Tokens are generated in chunks of
    CHUNK_TOKENS; every complete 7-code frame is decoded with
    ``decode_block`` and sent as one binary message. The socket is closed with
    1000 when generation ends and 1011 on an unexpected error.
    """
    import traceback

    await ws.accept()
    close_code = 1000
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, _ = build_prompt(text, voice)
        prompt_len = ids.size(1)
        cache = DynamicCache()
        # Per-connection mask: its `sent_blocks` state must not leak between requests.
        mask = AudioMask(AUDIO_IDS.to(device))
        frame: list[int] = []
        done = False

        while not done:
            budget = MAX_AUDIO_TOKENS - (ids.size(1) - prompt_len)
            if budget <= 0:
                break
            step = min(CHUNK_TOKENS, budget)
            seq = await asyncio.to_thread(_generate_chunk, ids, cache, mask, step)
            new = seq[0, ids.size(1):].tolist()
            ids = seq
            # Fewer tokens than requested means generate() stopped (EOS).
            done = len(new) < step

            for t in new:
                if t == EOS_TOKEN:
                    done = True
                    break
                if t == NEW_BLOCK:
                    frame.clear()
                    continue
                frame.append(t - AUDIO_BASE)
                if len(frame) == 7:
                    try:
                        pcm = decode_block(frame)
                    except ValueError as err:
                        print("⚠️ skipping malformed frame:", err, flush=True)
                    else:
                        await ws.send_bytes(pcm)
                        mask.sent_blocks = 1  # EOS is admissible from now on
                    frame.clear()
    except WebSocketDisconnect:
        pass  # client closed the connection; nothing left to send
    except Exception as e:
        print("❌ WS‑Error:", e, flush=True)
        traceback.print_exc()
        close_code = 1011
    finally:
        # Close exactly once, and only if neither side has closed already.
        if (ws.client_state.name != "DISCONNECTED"
                and ws.application_state.name != "DISCONNECTED"):
            try:
                await ws.close(code=close_code)
            except RuntimeError:
                pass
# 6) Dev start: run the app with uvicorn on port 7860 -------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")