Tomtom84 committed
Commit 7db0e09 · verified · 1 Parent(s): fee868f

Update app.py

Files changed (1)
1. app.py  +20 -35
app.py CHANGED
@@ -99,49 +99,34 @@ async def tts(ws: WebSocket):
         text  = req.get("text", "")
         voice = req.get("voice", "Jakob")
 
-        ids, attn = build_prompt(text, voice)
-        past = None                        # Reverted past initialization
-        offset_len = ids.size(1)           # how many tokens already exist
-        last_tok = None                    # Initialized last_tok
-        buf = []
-        past_key_values = DynamicCache()
+        ids, attn  = build_prompt(text, voice)
+        past       = None
+        offset_len = ids.size(1)
+        cache_pos  = offset_len - 1        # 0-based
+        last_tok   = None
+        buf        = []
+
         while True:
-            print(f"DEBUG: Before generate - past is None: {past is None}", flush=True)  # Added logging
-            print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True)  # Added logging
-            # --- Mini-Generate (StaticCache via cache_implementation) -------------------------------------------
             gen = model.generate(
-                input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),  # Use past is None check
-                attention_mask  = attn if past is None else None,  # Use past is None check
-                past_key_values = past_key_values,  # Pass past (will be None initially, then the cache object)
-                max_new_tokens  = 1,  # Set max_new_tokens to 1 for debugging cache
+                input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),
+                attention_mask  = attn if past is None else None,
+                past_key_values = past,
+                cache_position  = None if past is None else torch.tensor([cache_pos], device=device),  # ← NEW
+                max_new_tokens  = CHUNK_TOKENS,
                 logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
-                use_cache=True,  # Re-enabled cache
-                return_dict_in_generate=True,
-                #return_legacy_cache=True,
-                #cache_implementation="static"  # Enabled StaticCache via implementation
+                use_cache=True, return_dict_in_generate=True,
+                return_legacy_cache=False
             )
-            print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True)  # Added logging
 
-            # ----- slice out the new tokens --------------------------
-            seq = gen.sequences[0].tolist()
-            new = seq[offset_len:]
-            if not new:  # nothing -> done
+            new = gen.sequences[0, offset_len:].tolist()
+            if not new:
                 break
-            offset_len += len(new)
-
-            # ----- Update past and last_tok (Cache Re-enabled) ---------
-            # ids  = torch.tensor([seq], device=device)  # Removed full sequence update
-            # attn = torch.ones_like(ids)                # Removed full sequence update
-            #pkv = gen.past_key_values  # Update past with the cache object returned by generate
-            print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True)  # Added logging
-            #if isinstance(pkv, StaticCache): pkv = pkv.to_legacy()
-            past = gen.past_key_values
-            print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True)  # Added logging
 
-            last_tok = new[-1]
-
-            print("new tokens:", new[:25], flush=True)
+            offset_len += len(new)
+            cache_pos   = offset_len - 1   # ← NEW
+            past        = gen.past_key_values
+            last_tok    = new[-1]
 
             # ----- Token‑Handling ----------------------------------------
             for t in new:
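
For context, below is a minimal sketch of the cache-reuse pattern this commit works toward, assuming a recent transformers release and using a generic gpt2 causal LM as a hypothetical stand-in for the app's TTS model (the masker logits processor is omitted and CHUNK_TOKENS is an assumed value). Unlike the commit, it feeds the full sequence back into generate() on every call and lets the library derive cache_position internally, which is a documented way to continue generation from an existing KV cache.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical stand-ins: the real app.py defines its own model, device,
# masker logits processor, and CHUNK_TOKENS constant.
device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
CHUNK_TOKENS = 16  # assumed chunk size

ids = tok("Hello there", return_tensors="pt").input_ids.to(device)
past = None

for _ in range(4):  # emit four chunks
    out = model.generate(
        input_ids=ids,                        # full sequence generated so far
        attention_mask=torch.ones_like(ids),
        past_key_values=past,                 # reuse the KV cache from the previous call
        max_new_tokens=CHUNK_TOKENS,
        do_sample=True, temperature=0.7, top_p=0.95,
        use_cache=True, return_dict_in_generate=True,
        pad_token_id=tok.eos_token_id,
    )
    new = out.sequences[0, ids.size(1):].tolist()
    if not new:
        break
    ids = out.sequences                       # sequences already contain the new tokens
    past = out.past_key_values                # only the uncached tail is recomputed next call
    print("new tokens:", new[:25])

The commit's variant instead passes only the last sampled token together with an explicit cache_position; in both cases the prompt's keys/values stay resident in the cache, so each chunk only pays for the newly generated tokens.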