# app.py ──────────────────────────────────────────────────────────────
import os, json, torch, asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, DynamicCache  # DynamicCache is what generate() returns with return_legacy_cache=False
from snac import SNAC
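# Runtime dependencies (an assumption based on the imports above, not an official
# requirements list): fastapi, uvicorn, torch, transformers, huggingface_hub and the
# `snac` package for the SNAC vocoder.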
# 0) Login + Device ---------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)
device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.backends.cuda.enable_flash_sdp(False)  # PyTorch 2.2 bug workaround
# 1) Constants ---------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50          # tokens generated per generate() call
START_TOKEN = 128259
NEW_BLOCK = 128257
EOS_TOKEN = 128258
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7      # 28,672 audio codes (7 slots x 4096)
AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
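# Token layout (inferred from the constants above rather than from official docs):
# START_TOKEN opens the prompt, NEW_BLOCK announces the start of a new 7-token audio
# block, EOS_TOKEN ends generation, and audio codes occupy the contiguous range
# [AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN): 7 slots of 4096 codes each, so for a raw
# token t, (t - AUDIO_BASE) // 4096 is the slot index and (t - AUDIO_BASE) % 4096
# the SNAC code value.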
# 2) Logit mask (allow NEW_BLOCK + audio tokens; EOS only after the 1st block) --
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor):
        super().__init__()
        # Tokens that are always allowed: the NEW_BLOCK marker plus all audio codes
        self.allow = torch.cat([
            torch.tensor([NEW_BLOCK], device=audio_ids.device),
            audio_ids
        ])
        self.eos = torch.tensor([EOS_TOKEN], device=audio_ids.device)
        self.sent_blocks = 0  # set to 1 by the WebSocket loop once the first block was streamed

    def __call__(self, input_ids, scores):
        # EOS only becomes a legal token once at least one audio block has been sent
        allow = torch.cat([self.allow, self.eos]) if self.sent_blocks else self.allow
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allow] = 0
        return scores + mask
# 3) FastAPI skeleton --------------------------------------------------
app = FastAPI()

@app.get("/")
def hello():
    return {"status": "ok"}
@app.on_event("startup")
def load_models():
    global tok, model, snac, masker
    print("⏳ Loading models …", flush=True)
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    model.config.pad_token_id = model.config.eos_token_id
    masker = AudioMask(AUDIO_IDS.to(device))
    print("✅ Models loaded", flush=True)
# 4) Helper -----------------------------------------------------------
def build_prompt(text: str, voice: str):
    prompt_ids = tok(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)
    ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
                     prompt_ids,
                     torch.tensor([[128009, 128260]], device=device)], 1)  # trailing control tokens expected by the prompt format
    attn = torch.ones_like(ids)
    return ids, attn
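# decode_block() turns one 7-token block into 16-bit PCM. The slot-to-codebook mapping
# assumed here (and implemented by the offsets below) follows the common Orpheus/SNAC
# interleaving: slot 0 -> coarse codes (l1), slots 1 and 4 -> medium codes (l2),
# slots 2, 3, 5 and 6 -> fine codes (l3), each slot holding a value in [0, 4096).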
def decode_block(block7: list[int]) -> bytes:
    l1, l2, l3 = [], [], []
    l1.append(block7[0] - 0 * 4096)                      # slot 0 -> codebook 1
    l2.append(block7[1] - 1 * 4096)                      # slot 1 -> codebook 2
    l3 += [block7[2] - 2 * 4096, block7[3] - 3 * 4096]   # slots 2+3 -> codebook 3
    l2.append(block7[4] - 4 * 4096)                      # slot 4 -> codebook 2
    l3 += [block7[5] - 5 * 4096, block7[6] - 6 * 4096]   # slots 5+6 -> codebook 3
    with torch.no_grad():
        codes = [torch.tensor(x, device=device).unsqueeze(0)
                 for x in (l1, l2, l3)]
        audio = snac.decode(codes).squeeze().detach().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()
# 5) WebSocket endpoint ------------------------------------------------
@app.websocket("/ws/tts")
async def tts(ws: WebSocket):
    await ws.accept()
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        ids, attn = build_prompt(text, voice)
        past = None        # holds the DynamicCache object from past_key_values
        buf = []           # collects audio codes until a full 7-token block is ready
        last_tok = None    # last generated token, fed back as the next input
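        # Streaming loop, in brief: generate up to CHUNK_TOKENS tokens, slice off the
        # tokens that are new in this call, turn every complete 7-token block into PCM
        # and send it to the client, then continue from the KV cache with only the last
        # token as input until EOS is produced or the client disconnects.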
        while True:
            # Determine the inputs for this iteration
            if past is None:
                # First iteration: use the full prompt
                current_input_ids = ids
                current_attn_mask = attn
                current_cache_position = None   # do NOT pass cache_position on the first run
            else:
                # Subsequent iterations: feed only the last generated token
                if last_tok is None:
                    print("Error: last_tok is None before subsequent generate call.")
                    break  # should not happen if generation proceeded
                current_input_ids = torch.tensor([[last_tok]], device=device)
                current_attn_mask = None        # not needed when past_key_values is provided
                current_cache_position = None   # let DynamicCache track positions itself
            # --- Call model.generate ---
            try:
                gen = model.generate(
                    input_ids=current_input_ids,
                    attention_mask=current_attn_mask,
                    past_key_values=past,
                    cache_position=current_cache_position,  # always None; generate derives it
                    max_new_tokens=CHUNK_TOKENS,
                    logits_processor=[masker],
                    do_sample=True, temperature=0.7, top_p=0.95,
                    use_cache=True,
                    return_dict_in_generate=True,
                    return_legacy_cache=False,  # ensures past_key_values is a DynamicCache
                )
            except Exception as e:
                print(f"❌ Error during model.generate: {e}")
                import traceback
                traceback.print_exc()
                break  # exit the loop on generation errors
            # --- Process output ---
            # gen.sequences contains the tokens fed in *this* call plus the newly
            # generated ones (the cached prefix is not repeated), so everything
            # after the current input length is new.
            full_sequence_now = gen.sequences
            input_len = current_input_ids.shape[1]
            if full_sequence_now.shape[1] > input_len:
                new = full_sequence_now[0, input_len:].tolist()
            else:
                new = []  # no new tokens generated
            if not new:
                print("No new tokens generated, stopping.")
                break
            # Update the cache state for the *next* iteration
            past = gen.past_key_values
            # Remember the very last token of this call; it becomes the next input
            last_tok = new[-1]
            # ----- Token handling (process the 'new' list) -----
            eos_found = False
            for t in new:
                if t == EOS_TOKEN:
                    print("EOS token encountered.")
                    eos_found = True
                    break  # stop processing tokens in this chunk
                if t == NEW_BLOCK:
                    buf.clear()
                    continue
                # Collect only tokens inside the expected audio range
                if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                    buf.append(t - AUDIO_BASE)
                else:
                    # Ignore unexpected tokens (log here if needed)
                    pass
                if len(buf) == 7:
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    # Allow EOS only after the first full block has been sent
                    if not masker.sent_blocks:
                        masker.sent_blocks = 1
            if eos_found:
                # Discard any incomplete block left over at EOS
                if len(buf) > 0:
                    print(f"Warning: Incomplete audio block at EOS: {len(buf)} tokens. Discarding.")
                    buf.clear()
                break  # exit the while loop
    except (StopIteration, WebSocketDisconnect):
        pass
    except Exception as e:
        print("❌ WS error:", e, flush=True)
        import traceback
        traceback.print_exc()
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass
# 6) Dev start ---------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")