Tomtom84 commited on
Commit
b17f5cd
·
verified ·
1 Parent(s): c417a58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -73,7 +73,7 @@ def build_prompt(text: str, voice: str):
73
  prompt_ids,
74
  torch.tensor([[128009, 128260]], device=device)], 1)
75
  attn = torch.ones_like(ids)
76
- return ids, attn
77
 
78
  def decode_block(block7: list[int]) -> bytes:
79
  l1,l2,l3=[],[],[]
@@ -130,20 +130,23 @@ async def tts(ws: WebSocket):
130
  break
131
  offset_len += len(new)
132
 
133
- # ----- Update ids and attn for next iteration (Cache Disabled) ---------
134
- ids = torch.tensor([seq], device=device)
135
- attn = torch.ones_like(ids)
 
 
 
136
 
137
  print("new tokens:", new[:25], flush=True)
138
 
139
  # ----- Token‑Handling ----------------------------------------
140
  for t in new:
141
- if t == EOS_TOKEN:
142
- raise StopIteration
143
  if t == NEW_BLOCK:
144
  buf.clear()
145
  continue
146
- buf.append(t - AUDIO_BASE)
147
  if len(buf) == 7:
148
  await ws.send_bytes(decode_block(buf))
149
  buf.clear()
 
73
  prompt_ids,
74
  torch.tensor([[128009, 128260]], device=device)], 1)
75
  attn = torch.ones_like(ids)
76
+ return ids, attn # Ensure attention mask is created
77
 
78
  def decode_block(block7: list[int]) -> bytes:
79
  l1,l2,l3=[],[],[]
 
130
  break
131
  offset_len += len(new)
132
 
133
+ # ----- Update past and last_tok (Cache Re-enabled) ---------
134
+ # ids = torch.tensor([seq], device=device) # Removed full sequence update
135
+ # attn = torch.ones_like(ids) # Removed full sequence update
136
+ past = gen.past_key_values # Re-enabled cache update
137
+ print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
138
+ last_tok = new[-1]
139
 
140
  print("new tokens:", new[:25], flush=True)
141
 
142
  # ----- Token‑Handling ----------------------------------------
143
  for t in new:
144
+ if t == EOS_TOKEN: # Re-enabled EOS check
145
+ raise StopIteration # Re-enabled EOS check
146
  if t == NEW_BLOCK:
147
  buf.clear()
148
  continue
149
+ buf.append(t - AUDIO_BASE) # Reverted to appending relative token
150
  if len(buf) == 7:
151
  await ws.send_bytes(decode_block(buf))
152
  buf.clear()