Update app.py
app.py
CHANGED
@@ -4,6 +4,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
 from transformers.generation.utils import Cache # Added import
+from transformers.cache_utils import DynamicCache # Added import
 from snac import SNAC
 
 # 0) Login + Device ---------------------------------------------------
@@ -108,14 +109,14 @@ async def tts(ws: WebSocket):
 
     while True:
         print(f"DEBUG: Before generate - past is None: {past is None}", flush=True) # Added logging
-        print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging
-
-        # --- Mini-Generate (Cache Re-enabled) -------------------------------------------
+        print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging)
+
         if past is None:
+            # First generation step
             gen = model.generate(
                 input_ids = ids,
                 attention_mask = attn,
-                past_key_values = past,
+                past_key_values = past, # This will be None
                 max_new_tokens = 1,
                 logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
@@ -123,14 +124,13 @@ async def tts(ws: WebSocket):
                 return_dict_in_generate=True,
             )
         else:
-            #
+            # Subsequent generation steps
             current_input_ids = torch.tensor([[last_tok]], device=device)
             current_attention_mask = torch.ones_like(current_input_ids)
-
             gen = model.generate(
                 input_ids = current_input_ids,
                 attention_mask = current_attention_mask,
-                past_key_values = past,
+                past_key_values = past, # This will be a Cache object
                 max_new_tokens = 1,
                 logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
@@ -138,7 +138,17 @@ async def tts(ws: WebSocket):
                 return_dict_in_generate=True,
                 cache_position=torch.tensor([offset_len], device=device) # Explicitly pass cache_position
             )
-
+
+        print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging)
+
+        # Convert legacy tuple cache to DynamicCache if necessary (only after the first step)
+        if past is None and isinstance(gen.past_key_values, tuple):
+            past = DynamicCache.from_legacy_cache(gen.past_key_values)
+        else:
+            # For subsequent steps, just update past with the new cache object
+            past = gen.past_key_values
+
+        print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging)
 
         # ----- slice out the new tokens --------------------------
         seq = gen.sequences[0].tolist()
@@ -147,11 +157,7 @@ async def tts(ws: WebSocket):
             break
         offset_len += len(new)
 
-        # ----- Update
-        # ids = torch.tensor([seq], device=device) # Removed full sequence update
-        # attn = torch.ones_like(ids) # Removed full sequence update
-        past = gen.past_key_values # Re-enabled cache update
-        print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
+        # ----- Update last_tok ---------
         last_tok = new[-1]
 
         print("new tokens:", new[:25], flush=True)
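For reference, the cache-reuse pattern this commit is working toward (keeping one KV cache alive across repeated one-token generate() calls) can be sketched in isolation as below. This is not the app's code: the model id and prompt are placeholders, it assumes a recent transformers release in which generate() accepts a Cache object via past_key_values, and it avoids the manual cache_position bookkeeping by passing the full sequence back in at every step so that generate() only computes the un-cached positions.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

MODEL_ID = "HuggingFaceTB/SmolLM2-135M"  # placeholder; any small causal LM illustrates the pattern
device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device).eval()

ids = tok("Hello, streaming world", return_tensors="pt").input_ids.to(device)
past = None  # becomes a Cache object after the first step

for _ in range(20):
    gen = model.generate(
        input_ids=ids,                        # full sequence so far; cached positions are not recomputed
        attention_mask=torch.ones_like(ids),  # mask covers the whole sequence, not just the last token
        past_key_values=past,
        max_new_tokens=1,
        do_sample=True, temperature=0.7, top_p=0.95,
        return_dict_in_generate=True,
    )
    pkv = gen.past_key_values
    # Older models may still return the legacy tuple format; normalise it, as the commit does.
    past = DynamicCache.from_legacy_cache(pkv) if isinstance(pkv, tuple) else pkv
    ids = gen.sequences                       # prompt plus everything generated so far
    print("last token:", ids[0, -1].item())

Passing the whole sequence together with the persistent cache sidesteps the attention-mask/cache_position bookkeeping that the single-token variant in the diff manages by hand: generate() slices the input down to the un-cached tail internally, so no positions are recomputed.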