Tomtom84 committed
Commit 2a41e43 · verified · 1 Parent(s): e271f39

Update app.py

Files changed (1)
  1. app.py +15 -15
app.py CHANGED
@@ -102,44 +102,44 @@ async def tts(ws: WebSocket):
         ids, attn = build_prompt(text, voice)
         past = None
         offset_len = ids.size(1)
-        cache_pos = offset_len                            # 0-based
+        past = None
         last_tok = None
         buf = []
 
         while True:
+            next_cache_pos = torch.tensor([offset_len], device=device) if past is not None else None
+
             gen = model.generate(
                 input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),
                 attention_mask  = attn if past is None else None,
                 past_key_values = past,
-                cache_position  = None if past is None else torch.tensor([cache_pos], device=device),  # ← **NEW**
+                cache_position  = next_cache_pos,         # **only from the 2nd pass onward**
                 max_new_tokens  = CHUNK_TOKENS,
                 logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
                 use_cache=True, return_dict_in_generate=True,
-                return_legacy_cache=False
             )
 
-            new = gen.sequences[0, offset_len:].tolist()
-            if not new:
+            # newly generated tokens past the previous end of the sequence
+            new_tokens = gen.sequences[0, offset_len:].tolist()
+            if not new_tokens:
                 break
 
-            offset_len += len(new)
-            cache_pos = offset_len - 1                    # ← **NEW**
-            past = gen.past_key_values
-            last_tok = new[-1]
+            offset_len += len(new_tokens)                 # cache is now larger
+            past = gen.past_key_values                    # keep the cache for the next round
+            last_tok = new_tokens[-1]
 
-            # ----- token handling -----------------------------------------
-            for t in new:
-                if t == EOS_TOKEN:                        # re-enabled EOS check
-                    raise StopIteration                   # re-enabled EOS check
+            for t in new_tokens:
+                if t == EOS_TOKEN:
+                    raise StopIteration
                 if t == NEW_BLOCK:
                     buf.clear()
                     continue
-                buf.append(t - AUDIO_BASE)                # reverted to appending relative token
+                buf.append(t - AUDIO_BASE)
                 if len(buf) == 7:
                     await ws.send_bytes(decode_block(buf))
                     buf.clear()
-                    masker.sent_blocks = 1                # EOS permitted from here on
+                    masker.sent_blocks = 1                # from now on EOS is allowed
 
     except (StopIteration, WebSocketDisconnect):
         pass
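
For readers unfamiliar with the cache mechanics this commit adjusts: when a decoder-only model is driven incrementally, every call after the prompt pass receives only the newest token plus the cached keys/values, and cache_position tells the model which slot of the cache that token will occupy (the current cache length, 0-based). The following is a minimal, hedged sketch of that pattern using a plain forward-pass loop with greedy decoding rather than the app's model.generate() with sampling; the checkpoint name is a placeholder, and passing cache_position assumes a recent transformers release with a Llama-style decoder, so this is an illustration of the idea, not the actual app.py setup.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Llama-3.2-1B"      # placeholder checkpoint, not the one used in app.py

tok   = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).eval()

ids        = tok("Hello", return_tensors="pt").input_ids
offset_len = ids.size(1)                    # number of tokens already in the cache

with torch.no_grad():
    # Prompt pass: fills the KV cache with all prompt positions at once.
    out      = model(input_ids=ids, use_cache=True)
    past     = out.past_key_values
    last_tok = out.logits[0, -1].argmax().item()

    for _ in range(16):
        # Incremental pass: only the newest token goes in; cache_position marks
        # the slot it will occupy, which equals the current cache length.
        step = model(
            input_ids       = torch.tensor([[last_tok]]),
            past_key_values = past,
            cache_position  = torch.tensor([offset_len]),
            use_cache       = True,
        )
        past       = step.past_key_values
        last_tok   = step.logits[0, -1].argmax().item()
        offset_len += 1                     # the cache grew by exactly one slot

This is also why the patched loop supplies cache_position only from the second generate() call onward: on the first call the full prompt is passed and generate() derives the positions itself, while on later calls the single fed-back token needs to be told where it sits relative to the reused cache.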