Update app.py
app.py CHANGED
@@ -3,6 +3,7 @@ import os, json, torch, asyncio
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
+from transformers.generation.utils import Cache  # Added import
 from snac import SNAC

 # 0) Login + Device ---------------------------------------------------
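The hunk above only adds the import; nothing in this diff shows how `Cache` is actually used. One plausible use, consistent with the `type(...)` debug logging further down, is distinguishing the newer `Cache` objects from the legacy tuple-of-tensors format returned by older transformers releases. A hypothetical sketch (the helper name and its use are assumptions, not part of the commit):

# Hypothetical helper, not part of the commit: newer transformers return a
# Cache object from generate(), older versions a tuple of per-layer tensors.
from transformers.generation.utils import Cache

def describe_past(past):
    if past is None:
        return "no cache yet"
    if isinstance(past, Cache):
        # Cache exposes get_seq_length() for the number of cached positions.
        return f"Cache holding {past.get_seq_length()} positions"
    return f"legacy tuple cache with {len(past)} layers"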
@@ -109,17 +110,33 @@ async def tts(ws: WebSocket):
     print(f"DEBUG: Before generate - past is None: {past is None}", flush=True)  # Added logging
     print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True)  # Added logging
     # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
-    … (11 removed lines, collapsed in the source diff view)
+    # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
+    if past is None:
+        gen = model.generate(
+            input_ids       = ids,
+            attention_mask  = attn,
+            past_key_values = past,
+            max_new_tokens  = 1,
+            logits_processor=[masker],
+            do_sample=True, temperature=0.7, top_p=0.95,
+            use_cache=True,
+            return_dict_in_generate=True,
+        )
+    else:
+        # Provide attention mask for the single new token
+        current_input_ids = torch.tensor([[last_tok]], device=device)
+        current_attention_mask = torch.ones_like(current_input_ids)
+
+        gen = model.generate(
+            input_ids       = current_input_ids,
+            attention_mask  = current_attention_mask,
+            past_key_values = past,
+            max_new_tokens  = 1,
+            logits_processor=[masker],
+            do_sample=True, temperature=0.7, top_p=0.95,
+            use_cache=True,
+            return_dict_in_generate=True,
+        )
     print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True)  # Added logging

     # ----- slice out the new tokens --------------------------
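For context, the re-enabled branch implements the standard KV-cache decoding pattern: run the whole prompt through the model once (prefill), then on every later step feed only the newest token together with `past_key_values`, so the keys and values for earlier tokens are reused instead of recomputed. A minimal, self-contained sketch of that mechanism using plain forward calls (the stand-in model `gpt2` and greedy picking are assumptions for brevity; none of these names come from app.py):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")                # stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

ids = tok("Hello there", return_tensors="pt").input_ids.to(device)

past = None
last_tok = None
with torch.no_grad():
    for _ in range(8):
        if past is None:
            # Prefill: process the whole prompt once, keep the KV cache.
            out = model(input_ids=ids, use_cache=True)
        else:
            # Decode step: only the last token; the cache supplies the history.
            step = torch.tensor([[last_tok]], device=device)
            out = model(input_ids=step, past_key_values=past, use_cache=True)
        past = out.past_key_values                         # reuse next step
        last_tok = out.logits[:, -1].argmax(dim=-1).item() # greedy pick

One version-dependent caveat worth checking against the commit: when resuming through `generate()` with `past_key_values`, some transformers releases expect `attention_mask` to span the cached positions as well, not just the single new token. The raw-forward loop above sidesteps this because the model derives positions from the cache length.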
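`masker` is passed via `logits_processor=[masker]`, but its definition lies outside this diff. A `LogitsProcessor` is a callable receiving `(input_ids, scores)` and returning adjusted scores before sampling. A hypothetical allow-list masker in that shape (assuming, without evidence from the diff, that app.py restricts sampling to SNAC audio-token ids):

import torch
from transformers import LogitsProcessor

class AllowedTokensMasker(LogitsProcessor):
    """Hypothetical stand-in for `masker`: force sampling into an allow-list."""
    def __init__(self, allowed_ids):
        self.allowed = torch.tensor(sorted(allowed_ids))

    def __call__(self, input_ids, scores):
        # Set every logit outside the allow-list to -inf before sampling.
        masked = torch.full_like(scores, float("-inf"))
        masked[:, self.allowed] = scores[:, self.allowed]
        return masked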