Update app.py
Browse files
app.py
CHANGED
@@ -73,7 +73,7 @@ def build_prompt(text: str, voice: str):
|
|
73 |
prompt_ids,
|
74 |
torch.tensor([[128009, 128260]], device=device)], 1)
|
75 |
attn = torch.ones_like(ids)
|
76 |
-
return ids, attn
|
77 |
|
78 |
def decode_block(block7: list[int]) -> bytes:
|
79 |
l1,l2,l3=[],[],[]
|
@@ -130,20 +130,23 @@ async def tts(ws: WebSocket):
|
|
130 |
break
|
131 |
offset_len += len(new)
|
132 |
|
133 |
-
# ----- Update
|
134 |
-
ids = torch.tensor([seq], device=device)
|
135 |
-
attn = torch.ones_like(ids)
|
|
|
|
|
|
|
136 |
|
137 |
print("new tokens:", new[:25], flush=True)
|
138 |
|
139 |
# ----- Token‑Handling ----------------------------------------
|
140 |
for t in new:
|
141 |
-
if t == EOS_TOKEN:
|
142 |
-
raise StopIteration
|
143 |
if t == NEW_BLOCK:
|
144 |
buf.clear()
|
145 |
continue
|
146 |
-
buf.append(t - AUDIO_BASE)
|
147 |
if len(buf) == 7:
|
148 |
await ws.send_bytes(decode_block(buf))
|
149 |
buf.clear()
|
|
|
73 |
prompt_ids,
|
74 |
torch.tensor([[128009, 128260]], device=device)], 1)
|
75 |
attn = torch.ones_like(ids)
|
76 |
+
return ids, attn # Ensure attention mask is created
|
77 |
|
78 |
def decode_block(block7: list[int]) -> bytes:
|
79 |
l1,l2,l3=[],[],[]
|
|
|
130 |
break
|
131 |
offset_len += len(new)
|
132 |
|
133 |
+
# ----- Update past and last_tok (Cache Re-enabled) ---------
|
134 |
+
# ids = torch.tensor([seq], device=device) # Removed full sequence update
|
135 |
+
# attn = torch.ones_like(ids) # Removed full sequence update
|
136 |
+
past = gen.past_key_values # Re-enabled cache update
|
137 |
+
print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
|
138 |
+
last_tok = new[-1]
|
139 |
|
140 |
print("new tokens:", new[:25], flush=True)
|
141 |
|
142 |
# ----- Token‑Handling ----------------------------------------
|
143 |
for t in new:
|
144 |
+
if t == EOS_TOKEN: # Re-enabled EOS check
|
145 |
+
raise StopIteration # Re-enabled EOS check
|
146 |
if t == NEW_BLOCK:
|
147 |
buf.clear()
|
148 |
continue
|
149 |
+
buf.append(t - AUDIO_BASE) # Reverted to appending relative token
|
150 |
if len(buf) == 7:
|
151 |
await ws.send_bytes(decode_block(buf))
|
152 |
buf.clear()
|