# app.py ──────────────────────────────────────────────────────────────
import os, json, torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from snac import SNAC
# 0) Login + Device ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.backends.cuda.enable_flash_sdp(False)  # workaround for a PyTorch 2.2 SDPA bug
# 1) Constants --------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28,672 audio codes in total
AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
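# Each audio block is 7 consecutive tokens; the token at position p (0-6)
# carries SNAC code (t - AUDIO_BASE - p * 4096), i.e. every position owns
# its own 4096-code sub-range of AUDIO_SPAN (see decode_block below).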
# 2) Logits mask (NEW_BLOCK + audio; EOS only after the first block) --
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor):
        super().__init__()
        self.allow = torch.cat([
            torch.tensor([NEW_BLOCK], device=audio_ids.device),
            audio_ids,
        ])
        self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
        self.sent_blocks = 0  # set to 1 by the WS handler after the first audio block

    def __call__(self, input_ids, scores):
        # Allow EOS only once at least one full audio block has been streamed.
        allow = torch.cat([self.allow, self.eos]) if self.sent_blocks else self.allow
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allow] = 0
        return scores + mask
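# The mask adds -inf to every logit outside `allow`, so sampling can only
# pick NEW_BLOCK, an audio code, or (once a block has been streamed) EOS.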
# 3) FastAPI skeleton -------------------------------------------------
app = FastAPI()
@app.get("/")
def hello():
    return {"status": "ok"}
@app.on_event("startup")
def load_models():
    global tok, model, snac, masker
    print("⏳ Loading models …", flush=True)

    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    model.config.pad_token_id = model.config.eos_token_id

    masker = AudioMask(AUDIO_IDS.to(device))
    print("✅ Models loaded", flush=True)
# 4) Helper -----------------------------------------------------------
def build_prompt(text: str, voice: str):
    prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
                     prompt_ids,
                     torch.tensor([[128009, 128260]], device=device)], 1)
    attn = torch.ones_like(ids)
    return ids, attn
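# Prompt layout: [START_TOKEN] "<voice>: <text>" [128009, 128260].
# 128009 is the Llama-3 <|eot_id|> token; 128260 is presumably the model's
# start-of-audio marker (an assumption based on the Orpheus token layout).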
def decode_block(block7: list[int]) -> bytes:
    # Undo each token's per-position offset of p * 4096 and split the 7 codes
    # across the three SNAC codebook levels (1 coarse, 2 medium, 4 fine).
    l1 = [block7[0]]
    l2 = [block7[1] - 1 * 4096, block7[4] - 4 * 4096]
    l3 = [block7[2] - 2 * 4096, block7[3] - 3 * 4096,
          block7[5] - 5 * 4096, block7[6] - 6 * 4096]

    with torch.no_grad():
        codes = [torch.tensor(x, device=device).unsqueeze(0)
                 for x in (l1, l2, l3)]
        audio = snac.decode(codes).squeeze().detach().cpu().numpy()

    return (audio * 32767).astype("int16").tobytes()
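# decode_block returns raw 16-bit signed mono PCM at 24 kHz (SNAC's native
# rate), so a client can simply concatenate the chunks it receives.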
# 5) WebSocket endpoint -----------------------------------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_prompt(text, voice)
        masker.sent_blocks = 0  # reset per request (masker is a module-global)
        past = None      # KV cache (past_key_values) carried across generate calls
        buf = []         # audio codes of the current, not-yet-complete 7-token block
        last_tok = None  # last generated token, fed back in on the next call
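        # Chunked streaming: each generate() call produces at most
        # CHUNK_TOKENS new tokens; thanks to the KV cache, every call after
        # the first feeds only the single most recent token instead of
        # re-encoding the whole sequence.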
        while True:
            if past is None:
                # First iteration: feed the full prompt.
                current_input_ids = ids
                current_attn_mask = attn
            else:
                # Subsequent iterations: feed only the last generated token.
                if last_tok is None:
                    print("Error: last_tok is None before subsequent generate call.")
                    break
                current_input_ids = torch.tensor([[last_tok]], device=device)
                current_attn_mask = None  # generate() infers it from the cache

            # --- Call model.generate ---
            try:
                gen = model.generate(
                    input_ids=current_input_ids,
                    attention_mask=current_attn_mask,
                    past_key_values=past,
                    max_new_tokens=CHUNK_TOKENS,
                    logits_processor=[masker],
                    do_sample=True, temperature=0.7, top_p=0.95,
                    use_cache=True,
                    return_dict_in_generate=True,
                )
            except Exception as e:
                print(f"❌ Error during model.generate: {e}")
                import traceback
                traceback.print_exc()
                break
            # --- Process output ---
            # gen.sequences holds the tokens we fed in plus everything newly
            # generated; slice off the input length to get only the new tokens.
            new = gen.sequences[0, current_input_ids.shape[1]:].tolist()
            if not new:
                print("No new tokens generated, stopping.")
                break

            # Carry the cache and the last token over to the next iteration.
            past = gen.past_key_values
            last_tok = new[-1]
            # ----- Token handling -----
            eos_found = False
            for t in new:
                if t == EOS_TOKEN:
                    print("EOS token encountered.")
                    eos_found = True
                    break
                if t == NEW_BLOCK:
                    buf.clear()
                    continue
                if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                    buf.append(t - AUDIO_BASE)
                else:
                    # Out-of-range tokens should not occur under the logits
                    # mask; ignore them defensively.
                    continue
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    # After the first full block, the mask may allow EOS.
                    if not masker.sent_blocks:
                        masker.sent_blocks = 1

            if eos_found:
                if len(buf) > 0:
                    print(f"Warning: discarding incomplete audio block at EOS ({len(buf)} tokens).")
                    buf.clear()
                break
    except (StopIteration, WebSocketDisconnect):
        pass
    except Exception as e:
        print("❌ WS error:", e, flush=True)
        import traceback
        traceback.print_exc()
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass
# 6) Dev start --------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
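
# Example client (a sketch, not part of the app). Assumes the third-party
# `websockets` package; the server sends raw 16-bit mono PCM at 24 kHz,
# one binary message per decoded audio block:
#
#   import asyncio, json, websockets
#
#   async def demo():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               async for chunk in ws:   # ends when the server closes the socket
#                   f.write(chunk)
#
#   asyncio.run(demo())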