Tomtom84 commited on
Commit
93bffbb
·
verified ·
1 Parent(s): 14f1558

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -104,22 +104,23 @@ async def tts(ws: WebSocket):
104
  offset_len = ids.size(1) # wie viele Tokens existieren schon
105
  last_tok = None
106
  buf = []
107
- masker.buffer_pos = 0 # Removed initialization here
108
- # Removed buffer_pos update before generation
109
 
110
  while True:
 
 
111
  # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
112
  gen = model.generate(
113
  input_ids = ids if past is None else torch.tensor([[last_tok]], device=device), # Re-enabled cache input
114
  attention_mask = attn if past is None else None, # Re-enabled cache attention
115
  past_key_values = past, # Re-enabled cache
116
- max_new_tokens = CHUNK_TOKENS,
117
  logits_processor=[masker],
118
  do_sample=True, temperature=0.7, top_p=0.95,
119
  use_cache=True, # Re-enabled cache
120
  return_dict_in_generate=True,
121
  return_legacy_cache=True
122
  )
 
123
 
124
  # ----- neue Tokens heraus schneiden --------------------------
125
  seq = gen.sequences[0].tolist()
@@ -132,6 +133,7 @@ async def tts(ws: WebSocket):
132
  # ids = torch.tensor([seq], device=device) # Removed full sequence update
133
  # attn = torch.ones_like(ids) # Removed full sequence update
134
  past = gen.past_key_values # Re-enabled cache update
 
135
  last_tok = new[-1]
136
 
137
  print("new tokens:", new[:25], flush=True)
@@ -143,16 +145,11 @@ async def tts(ws: WebSocket):
143
  if t == NEW_BLOCK:
144
  buf.clear()
145
  continue
146
- # Only append if it's an audio token
147
- # Only append if it's an audio token
148
  buf.append(t - AUDIO_BASE) # Reverted to appending relative token
149
- masker.buffer_pos += 1 # Removed increment here
150
  if len(buf) == 7:
151
  await ws.send_bytes(decode_block(buf))
152
  buf.clear()
153
  masker.sent_blocks = 1 # ab jetzt EOS zulässig
154
- masker.buffer_pos = 0 # Removed reset here
155
- # Removed else block for skipping non-audio tokens
156
 
157
  except (StopIteration, WebSocketDisconnect):
158
  pass
 
104
  offset_len = ids.size(1) # wie viele Tokens existieren schon
105
  last_tok = None
106
  buf = []
 
 
107
 
108
  while True:
109
+ print(f"DEBUG: Before generate - past is None: {past is None}", flush=True) # Added logging
110
+ print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging
111
  # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
112
  gen = model.generate(
113
  input_ids = ids if past is None else torch.tensor([[last_tok]], device=device), # Re-enabled cache input
114
  attention_mask = attn if past is None else None, # Re-enabled cache attention
115
  past_key_values = past, # Re-enabled cache
116
+ max_new_tokens = 1, # Set max_new_tokens to 1 for debugging cache
117
  logits_processor=[masker],
118
  do_sample=True, temperature=0.7, top_p=0.95,
119
  use_cache=True, # Re-enabled cache
120
  return_dict_in_generate=True,
121
  return_legacy_cache=True
122
  )
123
+ print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging
124
 
125
  # ----- neue Tokens heraus schneiden --------------------------
126
  seq = gen.sequences[0].tolist()
 
133
  # ids = torch.tensor([seq], device=device) # Removed full sequence update
134
  # attn = torch.ones_like(ids) # Removed full sequence update
135
  past = gen.past_key_values # Re-enabled cache update
136
+ print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
137
  last_tok = new[-1]
138
 
139
  print("new tokens:", new[:25], flush=True)
 
145
  if t == NEW_BLOCK:
146
  buf.clear()
147
  continue
 
 
148
  buf.append(t - AUDIO_BASE) # Reverted to appending relative token
 
149
  if len(buf) == 7:
150
  await ws.send_bytes(decode_block(buf))
151
  buf.clear()
152
  masker.sent_blocks = 1 # ab jetzt EOS zulässig
 
 
153
 
154
  except (StopIteration, WebSocketDisconnect):
155
  pass