Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
import os, json, torch, asyncio
|
| 3 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 4 |
from huggingface_hub import login
|
| 5 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
|
| 6 |
from snac import SNAC
|
| 7 |
|
| 8 |
# 0) Login + Device ---------------------------------------------------
|
|
@@ -100,22 +100,28 @@ async def tts(ws: WebSocket):
|
|
| 100 |
voice = req.get("voice", "Jakob")
|
| 101 |
|
| 102 |
ids, attn = build_prompt(text, voice)
|
|
|
|
|
|
|
| 103 |
offset_len = ids.size(1) # wie viele Tokens existieren schon
|
|
|
|
| 104 |
buf = []
|
| 105 |
|
| 106 |
while True:
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
gen = model.generate(
|
| 109 |
-
input_ids = ids,
|
| 110 |
-
attention_mask = attn,
|
| 111 |
-
past_key_values =
|
| 112 |
-
max_new_tokens = 1,
|
| 113 |
logits_processor=[masker],
|
| 114 |
-
do_sample=
|
| 115 |
-
use_cache=
|
| 116 |
return_dict_in_generate=True,
|
| 117 |
return_legacy_cache=True
|
| 118 |
)
|
|
|
|
| 119 |
|
| 120 |
# ----- neue Tokens heraus schneiden --------------------------
|
| 121 |
seq = gen.sequences[0].tolist()
|
|
|
|
| 2 |
import os, json, torch, asyncio
|
| 3 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 4 |
from huggingface_hub import login
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StaticCache # Added StaticCache
|
| 6 |
from snac import SNAC
|
| 7 |
|
| 8 |
# 0) Login + Device ---------------------------------------------------
|
|
|
|
| 100 |
voice = req.get("voice", "Jakob")
|
| 101 |
|
| 102 |
ids, attn = build_prompt(text, voice)
|
| 103 |
+
# Initialized StaticCache
|
| 104 |
+
past = StaticCache(config=model.config, max_batch_size=1, max_cache_len=ids.size(1) + CHUNK_TOKENS, device=device, dtype=model.dtype)
|
| 105 |
offset_len = ids.size(1) # wie viele Tokens existieren schon
|
| 106 |
+
last_tok = None # Initialized last_tok
|
| 107 |
buf = []
|
| 108 |
|
| 109 |
while True:
|
| 110 |
+
print(f"DEBUG: Before generate - past is None: {past is None}", flush=True) # Added logging
|
| 111 |
+
print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging
|
| 112 |
+
# --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
|
| 113 |
gen = model.generate(
|
| 114 |
+
input_ids = ids if past.get_seq_length() == 0 else torch.tensor([[last_tok]], device=device), # Use cache seq length
|
| 115 |
+
attention_mask = attn if past.get_seq_length() == 0 else None, # Use cache seq length
|
| 116 |
+
past_key_values = past, # Use StaticCache instance
|
| 117 |
+
max_new_tokens = 1, # Set max_new_tokens to 1 for debugging cache
|
| 118 |
logits_processor=[masker],
|
| 119 |
+
do_sample=True, temperature=0.7, top_p=0.95,
|
| 120 |
+
use_cache=True, # Re-enabled cache
|
| 121 |
return_dict_in_generate=True,
|
| 122 |
return_legacy_cache=True
|
| 123 |
)
|
| 124 |
+
print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging
|
| 125 |
|
| 126 |
# ----- neue Tokens heraus schneiden --------------------------
|
| 127 |
seq = gen.sequences[0].tolist()
|