Tomtom84 committed
Commit fd51bc6 · verified · 1 Parent(s): b87ae72

Update app.py

Files changed (1):
  app.py  +28 -74
app.py CHANGED
@@ -3,8 +3,6 @@ import os, json, torch, asyncio
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
-from transformers.generation.utils import Cache   # Added import
-from transformers.cache_utils import DynamicCache # Added import
 from snac import SNAC
 
 # 0) Login + Device ---------------------------------------------------
@@ -102,92 +100,48 @@ async def tts(ws: WebSocket):
         voice = req.get("voice", "Jakob")
 
         ids, attn = build_prompt(text, voice)
-        past = None
         offset_len = ids.size(1)               # how many tokens exist already
-        last_tok = None
         buf = []
 
-        # Initial generation step using model.generate
-        with torch.no_grad():
+        while True:
+            # --- Mini-Generate (Cache Disabled) ------------------------------
             gen = model.generate(
                 input_ids       = ids,
                 attention_mask  = attn,
-                past_key_values = None,        # Initial call, no past cache
+                past_key_values = None,        # Cache disabled
                 max_new_tokens  = 1,
                 logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
-                use_cache=True,
+                use_cache=False,               # Cache disabled
                 return_dict_in_generate=True,
+                return_legacy_cache=True
             )
 
-            # Get the initial cache and last token
-            past = gen.past_key_values
-            if isinstance(past, tuple):
-                past = DynamicCache.from_legacy_cache(past)  # Convert legacy tuple cache
-            last_tok = gen.sequences[0].tolist()[-1]
-            offset_len += 1                    # Increment offset for the first generated token
-
-        print(f"DEBUG: After initial generate - type of past: {type(past)}", flush=True)
-        print("new tokens:", [last_tok], flush=True)
-
-        # Handle the first generated token
-        if last_tok == EOS_TOKEN:
-            raise StopIteration
-        if last_tok == NEW_BLOCK:
-            buf.clear()
-        else:
-            buf.append(last_tok - AUDIO_BASE)
-            if len(buf) == 7:
-                await ws.send_bytes(decode_block(buf))
-                buf.clear()
-                masker.sent_blocks = 1
-
-        # Manual generation loop for subsequent tokens
-        while True:
-            print(f"DEBUG: Before forward - type of past: {type(past)}", flush=True)
-
-            # Prepare inputs for the next token
-            current_input_ids      = torch.tensor([[last_tok]], device=device)
-            current_attention_mask = torch.ones_like(current_input_ids)
-            current_cache_position = torch.tensor([offset_len], device=device)
-
-            # Perform forward pass
-            with torch.no_grad():
-                outputs = model(
-                    input_ids=current_input_ids,
-                    attention_mask=current_attention_mask,
-                    past_key_values=past,
-                    cache_position=current_cache_position,
-                    use_cache=True,
-                )
-
-            # Sample the next token (greedy sampling)
-            next_token_logits = outputs.logits[:, -1, :]
-            # Apply logits processor manually
-            processed_logits = masker(current_input_ids, next_token_logits.unsqueeze(0))[0]
-            next_token_id = torch.argmax(processed_logits).item()
-
-            print(f"DEBUG: After forward - type of outputs.past_key_values: {type(outputs.past_key_values)}", flush=True)
-
-            # Update cache and last token
-            past = outputs.past_key_values
-            last_tok = next_token_id
-            offset_len += 1                    # Increment offset for the new token
-
-            print("new tokens:", [last_tok], flush=True)
+            # ----- slice out the new tokens ----------------------------------
+            seq = gen.sequences[0].tolist()
+            new = seq[offset_len:]
+            if not new:                        # nothing -> done
+                break
+            offset_len += len(new)
+
+            # ----- update ids and attn for the next iteration (Cache Disabled)
+            ids  = torch.tensor([seq], device=device)
+            attn = torch.ones_like(ids)
+
+            print("new tokens:", new[:25], flush=True)
 
-            # ----- Token handling --------------------------------------------
-            if last_tok == EOS_TOKEN:
-                raise StopIteration
-            if last_tok == NEW_BLOCK:
-                buf.clear()
-                continue                       # Continue loop to generate the next token
-            buf.append(last_tok - AUDIO_BASE)
-            if len(buf) == 7:
-                await ws.send_bytes(decode_block(buf))
-                buf.clear()
-                masker.sent_blocks = 1         # EOS allowed from now on
+            # ----- Token handling --------------------------------------------
+            for t in new:
+                if t == EOS_TOKEN:
+                    raise StopIteration
+                if t == NEW_BLOCK:
+                    buf.clear()
+                    continue
+                buf.append(t - AUDIO_BASE)
+                if len(buf) == 7:
+                    await ws.send_bytes(decode_block(buf))
+                    buf.clear()
+                    masker.sent_blocks = 1     # EOS allowed from now on
 
     except (StopIteration, WebSocketDisconnect):
         pass
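
The control flow of the new loop is easier to read outside the diff. Below is a minimal, self-contained sketch of the same pattern: regenerate from the full sequence with the cache off, slice off whatever is new, and frame audio tokens into 7-token SNAC blocks. `fake_generate`, `decode_block`, and the token ids here are made-up stand-ins for the model, the SNAC decoder, and the constants defined elsewhere in app.py; only the loop structure mirrors the committed code.

```python
import random

# Placeholder token ids; the real values live elsewhere in app.py.
AUDIO_BASE = 1000    # assumed first id of the audio-token range
NEW_BLOCK  = 900     # assumed control token that opens a block
EOS_TOKEN  = 901     # assumed end-of-stream control token

def fake_generate(seq):
    """Stand-in for model.generate(..., max_new_tokens=1, use_cache=False):
    takes the full sequence, returns it with exactly one new token."""
    if len(seq) >= 36:                    # stop after a few blocks
        return seq + [EOS_TOKEN]
    if (len(seq) - 3) % 8 == 0:           # open a new block every 8 tokens
        return seq + [NEW_BLOCK]
    return seq + [AUDIO_BASE + random.randrange(100)]

def decode_block(block):
    """Stand-in for the SNAC decode + PCM packing step."""
    return bytes(str(block), "ascii")

seq        = [1, 2, 3]        # "prompt" tokens
offset_len = len(seq)         # how many tokens exist already
buf        = []

try:
    while True:
        seq = fake_generate(seq)      # full sequence in, full sequence out
        new = seq[offset_len:]        # slice out only this step's tokens
        if not new:
            break
        offset_len += len(new)
        for t in new:
            if t == EOS_TOKEN:
                raise StopIteration
            if t == NEW_BLOCK:
                buf.clear()           # drop any partial block
                continue
            buf.append(t - AUDIO_BASE)
            if len(buf) == 7:         # one complete 7-token SNAC frame
                print(decode_block(buf))
                buf.clear()
except StopIteration:
    pass
```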
 
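One trade-off worth noting: with `use_cache=False` and `max_new_tokens=1`, every iteration re-encodes the prompt plus everything generated so far, so attention work grows quadratically with output length. A back-of-envelope count with assumed sizes (a 64-token prompt and 350 generated tokens are illustrative numbers, not taken from app.py):

```python
# Positions processed by the cache-free loop vs. a KV cache,
# for a hypothetical 64-token prompt and 350 generated tokens.
prompt_len, n_new = 64, 7 * 50

# Step k re-encodes the whole sequence of prompt_len + k tokens.
no_cache = sum(prompt_len + k for k in range(1, n_new + 1))

# With a cache, each token is encoded exactly once.
with_cache = prompt_len + n_new

print(no_cache, with_cache)   # 83825 vs. 414
```

The removed DynamicCache path avoided that recompute but, judging by the DEBUG type prints it carried, was the part being debugged; this commit trades throughput for a loop with no cache bookkeeping at all.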