Tomtom84 committed
Commit b87ae72 · verified · Parent: 2a24991

Update app.py
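Replaces the repeated model.generate() calls in the token loop with a single prefill generate() followed by a manual decode loop: each subsequent token comes from a direct model() forward pass that reuses the KV cache via past_key_values and an explicit cache_position, applies the masker logits processor by hand, and selects the next token greedily with argmax.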

Files changed (1):
  1. app.py +73 -59
app.py CHANGED
@@ -107,73 +107,87 @@ async def tts(ws: WebSocket):
         last_tok = None
         buf = []

+        # Initial generation step using model.generate
+        with torch.no_grad():
+            gen = model.generate(
+                input_ids=ids,
+                attention_mask=attn,
+                past_key_values=None,  # initial call, no past cache
+                max_new_tokens=1,
+                logits_processor=[masker],
+                do_sample=True, temperature=0.7, top_p=0.95,
+                use_cache=True,
+                return_dict_in_generate=True,
+            )
+
+        # Get the initial cache and last token
+        past = gen.past_key_values
+        if isinstance(past, tuple):
+            past = DynamicCache.from_legacy_cache(past)  # convert legacy tuple cache
+        last_tok = gen.sequences[0].tolist()[-1]
+        offset_len += 1  # increment offset for the first generated token
+
+        print(f"DEBUG: After initial generate - type of past: {type(past)}", flush=True)
+        print("new tokens:", [last_tok], flush=True)  # log the first token
+
+        # Handle the first generated token
+        if last_tok == EOS_TOKEN:
+            raise StopIteration
+        if last_tok == NEW_BLOCK:
+            buf.clear()
+        else:
+            buf.append(last_tok - AUDIO_BASE)
+            if len(buf) == 7:
+                await ws.send_bytes(decode_block(buf))
+                buf.clear()
+                masker.sent_blocks = 1
+
+        # Manual generation loop for subsequent tokens
         while True:
-            print(f"DEBUG: Before generate - past is None: {past is None}", flush=True)  # added logging
-            print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True)  # added logging
-
-            if past is None:
-                # First generation step
-                gen = model.generate(
-                    input_ids=ids,
-                    attention_mask=attn,
-                    past_key_values=past,  # this will be None
-                    max_new_tokens=1,
-                    logits_processor=[masker],
-                    do_sample=True, temperature=0.7, top_p=0.95,
+            print(f"DEBUG: Before forward - type of past: {type(past)}", flush=True)
+
+            # Prepare inputs for the next token
+            current_input_ids = torch.tensor([[last_tok]], device=device)
+            current_attention_mask = torch.ones_like(current_input_ids)
+            current_cache_position = torch.tensor([offset_len], device=device)
+
+            # Perform forward pass
+            with torch.no_grad():
+                outputs = model(
+                    input_ids=current_input_ids,
+                    attention_mask=current_attention_mask,
+                    past_key_values=past,
+                    cache_position=current_cache_position,
                     use_cache=True,
-                    return_dict_in_generate=True,
-                )
-            else:
-                # Subsequent generation steps
-                current_input_ids = torch.tensor([[last_tok]], device=device)
-                current_attention_mask = torch.ones_like(current_input_ids)
-                gen = model.generate(
-                    input_ids=current_input_ids,
-                    attention_mask=current_attention_mask,
-                    past_key_values=past,  # this will be a Cache object
-                    max_new_tokens=1,
-                    logits_processor=[masker],
-                    do_sample=True, temperature=0.7, top_p=0.95,
-                    use_cache=True,
-                    return_dict_in_generate=True,
-                    cache_position=torch.tensor([offset_len], device=device)  # explicitly pass cache_position
-                )
-
-            print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True)  # added logging
+                )

-            # Convert legacy tuple cache to DynamicCache if necessary (only after the first step)
-            if past is None and isinstance(gen.past_key_values, tuple):
-                past = DynamicCache.from_legacy_cache(gen.past_key_values)
-            else:
-                # For subsequent steps, just update past with the new cache object
-                past = gen.past_key_values
+            # Pick the next token (greedy)
+            next_token_logits = outputs.logits[:, -1, :]
+            # Apply the logits processor manually
+            processed_logits = masker(current_input_ids, next_token_logits)
+            next_token_id = torch.argmax(processed_logits).item()

-            print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True)  # added logging
+            print(f"DEBUG: After forward - type of outputs.past_key_values: {type(outputs.past_key_values)}", flush=True)

-            # ----- slice out the new tokens -------------------------------
-            seq = gen.sequences[0].tolist()
-            new = seq[offset_len:]
-            if not new:  # nothing -> done
-                break
-            offset_len += len(new)
+            # Update cache and last token
+            past = outputs.past_key_values
+            last_tok = next_token_id
+            offset_len += 1  # increment offset for the new token

-            # ----- update last_tok ----------------------------------------
-            last_tok = new[-1]
-
-            print("new tokens:", new[:25], flush=True)
+            print("new tokens:", [last_tok], flush=True)  # log the new token

             # ----- Token handling -----------------------------------------
-            for t in new:
-                if t == EOS_TOKEN:  # re-enabled EOS check
-                    raise StopIteration
-                if t == NEW_BLOCK:
-                    buf.clear()
-                    continue
-                buf.append(t - AUDIO_BASE)  # append block-relative token
-                if len(buf) == 7:
-                    await ws.send_bytes(decode_block(buf))
-                    buf.clear()
-                    masker.sent_blocks = 1  # EOS allowed from now on
+            if last_tok == EOS_TOKEN:
+                raise StopIteration
+            if last_tok == NEW_BLOCK:
+                buf.clear()
+                continue  # keep looping for the next token
+            buf.append(last_tok - AUDIO_BASE)
+            if len(buf) == 7:
+                await ws.send_bytes(decode_block(buf))
+                buf.clear()
+                masker.sent_blocks = 1  # EOS allowed from now on
+

     except (StopIteration, WebSocketDisconnect):
         pass
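For reference, below is a minimal, self-contained sketch of the prefill-then-manual-decode pattern the new loop is built on. It is a sketch under assumptions, not the app's code: the gpt2 checkpoint, the prompt, and the 16-token budget are placeholders, cache_position requires a reasonably recent transformers release, and the WebSocket streaming, masker, and audio-block decoding from app.py are left out.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")                  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

ids = tok("Hello", return_tensors="pt").input_ids
offset = ids.shape[1]                                        # tokens cached so far

# Prefill: one forward pass over the whole prompt builds the KV cache.
with torch.no_grad():
    out = model(input_ids=ids, use_cache=True)
past = out.past_key_values
last = int(out.logits[:, -1, :].argmax())                    # greedy, like the new loop

# Decode: feed back only the newest token, reusing the cache each step.
for _ in range(16):                                          # placeholder budget
    with torch.no_grad():
        out = model(
            input_ids=torch.tensor([[last]]),
            attention_mask=torch.ones(1, offset + 1, dtype=torch.long),
            past_key_values=past,
            cache_position=torch.tensor([offset]),           # index of the new token
            use_cache=True,
        )
    past = out.past_key_values
    last = int(out.logits[:, -1, :].argmax())
    offset += 1
    print(tok.decode([last]), end="", flush=True)

The payoff of the manual loop is per-token control: the server can inspect every token as it is produced, run the logits processor itself, and push each finished 7-token audio block over the WebSocket immediately instead of waiting for generate() to return.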