Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
import os, json, torch, asyncio
|
3 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
4 |
from huggingface_hub import login
|
5 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
|
6 |
from snac import SNAC
|
7 |
|
8 |
# 0) Login + Device ---------------------------------------------------
|
@@ -100,22 +100,28 @@ async def tts(ws: WebSocket):
|
|
100 |
voice = req.get("voice", "Jakob")
|
101 |
|
102 |
ids, attn = build_prompt(text, voice)
|
|
|
|
|
103 |
offset_len = ids.size(1) # wie viele Tokens existieren schon
|
|
|
104 |
buf = []
|
105 |
|
106 |
while True:
|
107 |
-
|
|
|
|
|
108 |
gen = model.generate(
|
109 |
-
input_ids = ids,
|
110 |
-
attention_mask = attn,
|
111 |
-
past_key_values =
|
112 |
-
max_new_tokens = 1,
|
113 |
logits_processor=[masker],
|
114 |
-
do_sample=
|
115 |
-
use_cache=
|
116 |
return_dict_in_generate=True,
|
117 |
return_legacy_cache=True
|
118 |
)
|
|
|
119 |
|
120 |
# ----- neue Tokens heraus schneiden --------------------------
|
121 |
seq = gen.sequences[0].tolist()
|
|
|
2 |
import os, json, torch, asyncio
|
3 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
4 |
from huggingface_hub import login
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StaticCache # Added StaticCache
|
6 |
from snac import SNAC
|
7 |
|
8 |
# 0) Login + Device ---------------------------------------------------
|
|
|
100 |
voice = req.get("voice", "Jakob")
|
101 |
|
102 |
ids, attn = build_prompt(text, voice)
|
103 |
+
# Initialized StaticCache
|
104 |
+
past = StaticCache(config=model.config, max_batch_size=1, max_cache_len=ids.size(1) + CHUNK_TOKENS, device=device, dtype=model.dtype)
|
105 |
offset_len = ids.size(1) # wie viele Tokens existieren schon
|
106 |
+
last_tok = None # Initialized last_tok
|
107 |
buf = []
|
108 |
|
109 |
while True:
|
110 |
+
print(f"DEBUG: Before generate - past is None: {past is None}", flush=True) # Added logging
|
111 |
+
print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging
|
112 |
+
# --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
|
113 |
gen = model.generate(
|
114 |
+
input_ids = ids if past.get_seq_length() == 0 else torch.tensor([[last_tok]], device=device), # Use cache seq length
|
115 |
+
attention_mask = attn if past.get_seq_length() == 0 else None, # Use cache seq length
|
116 |
+
past_key_values = past, # Use StaticCache instance
|
117 |
+
max_new_tokens = 1, # Set max_new_tokens to 1 for debugging cache
|
118 |
logits_processor=[masker],
|
119 |
+
do_sample=True, temperature=0.7, top_p=0.95,
|
120 |
+
use_cache=True, # Re-enabled cache
|
121 |
return_dict_in_generate=True,
|
122 |
return_legacy_cache=True
|
123 |
)
|
124 |
+
print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging
|
125 |
|
126 |
# ----- neue Tokens heraus schneiden --------------------------
|
127 |
seq = gen.sequences[0].tolist()
|