Update app.py
app.py CHANGED
@@ -107,73 +107,87 @@ async def tts(ws: WebSocket):
         last_tok = None
         buf = []

+        # Initial generation step using model.generate
+        with torch.no_grad():
+            gen = model.generate(
+                input_ids = ids,
+                attention_mask = attn,
+                past_key_values = None, # Initial call, no past cache
+                max_new_tokens = 1,
+                logits_processor=[masker],
+                do_sample=True, temperature=0.7, top_p=0.95,
+                use_cache=True,
+                return_dict_in_generate=True,
+            )
+
+        # Get the initial cache and last token
+        past = gen.past_key_values
+        if isinstance(past, tuple):
+            past = DynamicCache.from_legacy_cache(past) # Convert legacy tuple cache
+        last_tok = gen.sequences[0].tolist()[-1]
+        offset_len += 1 # Increment offset for the first generated token
+
+        print(f"DEBUG: After initial generate - type of past: {type(past)}", flush=True) # Added logging
+        print("new tokens:", [last_tok], flush=True) # Log the first token
+
+        # Handle the first generated token
+        if last_tok == EOS_TOKEN:
+            raise StopIteration
+        if last_tok == NEW_BLOCK:
+            buf.clear()
+        else:
+            buf.append(last_tok - AUDIO_BASE)
+            if len(buf) == 7:
+                await ws.send_bytes(decode_block(buf))
+                buf.clear()
+                masker.sent_blocks = 1
+
+        # Manual generation loop for subsequent tokens
         while True:
-            print(f"DEBUG: Before
-
-
-
-
-
-
-
-
-
-
-
+            print(f"DEBUG: Before forward - type of past: {type(past)}", flush=True) # Added logging
+
+            # Prepare inputs for the next token
+            current_input_ids = torch.tensor([[last_tok]], device=device)
+            current_attention_mask = torch.ones_like(current_input_ids)
+            current_cache_position = torch.tensor([offset_len], device=device)
+
+            # Perform forward pass
+            with torch.no_grad():
+                outputs = model(
+                    input_ids=current_input_ids,
+                    attention_mask=current_attention_mask,
+                    past_key_values=past,
+                    cache_position=current_cache_position,
                     use_cache=True,
-
-                )
-            else:
-                # Subsequent generation steps
-                current_input_ids = torch.tensor([[last_tok]], device=device)
-                current_attention_mask = torch.ones_like(current_input_ids)
-                gen = model.generate(
-                    input_ids = current_input_ids,
-                    attention_mask = current_attention_mask,
-                    past_key_values = past, # This will be a Cache object
-                    max_new_tokens = 1,
-                    logits_processor=[masker],
-                    do_sample=True, temperature=0.7, top_p=0.95,
-                    use_cache=True,
-                    return_dict_in_generate=True,
-                    cache_position=torch.tensor([offset_len], device=device) # Explicitly pass cache_position
-                )
-
-            print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging)
+                )

-            #
-
-
-
-
-            past = gen.past_key_values
+            # Sample the next token (greedy sampling)
+            next_token_logits = outputs.logits[:, -1, :]
+            # Apply logits processor manually
+            processed_logits = masker(current_input_ids, next_token_logits.unsqueeze(0))[0]
+            next_token_id = torch.argmax(processed_logits).item()

-            print(f"DEBUG: After
+            print(f"DEBUG: After forward - type of outputs.past_key_values: {type(outputs.past_key_values)}", flush=True) # Added logging

-            #
-
-
-
-                break
-            offset_len += len(new)
+            # Update cache and last token
+            past = outputs.past_key_values
+            last_tok = next_token_id
+            offset_len += 1 # Increment offset for the new token

-            #
-            last_tok = new[-1]
-
-            print("new tokens:", new[:25], flush=True)
+            print("new tokens:", [last_tok], flush=True) # Log the new token

             # ----- Token‑Handling ----------------------------------------
-
-
-
-
-
-
-
-
-
-
-
+            if last_tok == EOS_TOKEN:
+                raise StopIteration
+            if last_tok == NEW_BLOCK:
+                buf.clear()
+                continue # Continue loop to generate the next token
+            buf.append(last_tok - AUDIO_BASE)
+            if len(buf) == 7:
+                await ws.send_bytes(decode_block(buf))
+                buf.clear()
+                masker.sent_blocks = 1 # EOS allowed from now on
+

     except (StopIteration, WebSocketDisconnect):
         pass
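The change replaces the per-token `model.generate` call inside the loop with one warm-up `generate` followed by plain `model(...)` forward passes that reuse the KV cache and pass an explicit `cache_position`. A minimal sketch of that decoding pattern, not the app's code, assuming a recent transformers release (with `DynamicCache` and `cache_position` support) and a llama-style checkpoint; the repo id below is a placeholder:

# Minimal sketch of single-step decoding with an explicit KV cache.
# Assumes a recent transformers release and a llama-style model; "my-org/my-model" is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("my-org/my-model")
model = AutoModelForCausalLM.from_pretrained("my-org/my-model").to(device).eval()

ids = tok("Hello", return_tensors="pt").input_ids.to(device)
past = DynamicCache()                    # empty cache, filled during the prompt pass
offset = ids.shape[1]

with torch.no_grad():
    # Prompt pass: encode the whole prompt once and keep the cache.
    out = model(input_ids=ids, past_key_values=past, use_cache=True)
    last_tok = int(out.logits[:, -1, :].argmax(dim=-1))

    for _ in range(20):
        # Incremental pass: feed only the newest token together with the cache.
        step_ids = torch.tensor([[last_tok]], device=device)
        out = model(
            input_ids=step_ids,
            past_key_values=out.past_key_values,
            cache_position=torch.tensor([offset], device=device),  # position of this token
            use_cache=True,
        )
        last_tok = int(out.logits[:, -1, :].argmax(dim=-1))
        offset += 1
        if last_tok == tok.eos_token_id:
            break

Feeding only the newest token plus the cached keys and values keeps each step cheap instead of re-encoding the whole prefix on every iteration, which is the point of the refactor.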
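Note that the initial `generate` call still samples (`do_sample=True, temperature=0.7, top_p=0.95`) while the manual loop picks tokens with `torch.argmax`, i.e. greedy decoding, and applies the logits processor by hand. The documented `LogitsProcessor.__call__` contract is `(input_ids, scores) -> scores` with `scores` of shape `(batch_size, vocab_size)`, so whether the extra `unsqueeze(0)` in the committed loop is needed depends on how the app's `masker` is written. A sketch of manual processing with optional top-p sampling, using a hypothetical `BlockTokenMasker` as a stand-in for `masker`:

# Sketch of applying a logits processor by hand and choosing a token (batch size 1).
# BlockTokenMasker is a hypothetical stand-in for the app's `masker`.
import torch
from transformers import LogitsProcessor

class BlockTokenMasker(LogitsProcessor):
    def __init__(self, banned_ids):
        self.banned_ids = list(banned_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores has shape (batch_size, vocab_size); forbid the banned ids.
        scores[:, self.banned_ids] = float("-inf")
        return scores

def pick_next_token(masker, input_ids, logits, do_sample=True, temperature=0.7, top_p=0.95):
    """Apply `masker` to the last-position logits and pick one token id."""
    scores = masker(input_ids, logits[:, -1, :])      # (1, vocab_size)
    if not do_sample:
        return int(scores.argmax(dim=-1))             # greedy, as in the committed loop
    probs = torch.softmax(scores / temperature, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    keep = sorted_probs.cumsum(dim=-1) <= top_p       # simple nucleus (top-p) filter
    keep[..., 0] = True                               # always keep the most likely token
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    choice = torch.multinomial(filtered / filtered.sum(dim=-1, keepdim=True), 1)
    return int(sorted_idx.gather(-1, choice))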
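The token handling after the warm-up call and inside the loop are nearly identical, so a follow-up could route both through one helper; a sketch using the app's existing `EOS_TOKEN`, `NEW_BLOCK`, `AUDIO_BASE`, `decode_block`, and `masker` names, assumed unchanged:

# Possible follow-up refactor: one helper for both token-handling sites.
# EOS_TOKEN, NEW_BLOCK, AUDIO_BASE, decode_block and masker are the app's own names.
async def emit_token(ws, buf, masker, tok):
    """Route one generated token: stop on EOS, reset on NEW_BLOCK, else buffer an audio code."""
    if tok == EOS_TOKEN:
        raise StopIteration
    if tok == NEW_BLOCK:
        buf.clear()
        return
    buf.append(tok - AUDIO_BASE)
    if len(buf) == 7:                          # one complete block of 7 audio codes
        await ws.send_bytes(decode_block(buf))
        buf.clear()
        masker.sent_blocks = 1                 # EOS becomes legal once audio has been sent

Called as the last statement of each step, `await emit_token(ws, buf, masker, last_tok)`, the early return on `NEW_BLOCK` has the same effect as the `continue` in the loop above.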