Tomtom84 committed
Commit 29123b3 · verified · Parent: b17f5cd

Update app.py

Files changed (1): app.py (+8, -8)
app.py CHANGED
@@ -100,8 +100,7 @@ async def tts(ws: WebSocket):
     voice = req.get("voice", "Jakob")
 
     ids, attn = build_prompt(text, voice)
-    # Initialized StaticCache
-    past = StaticCache(config=model.config, max_batch_size=1, max_cache_len=ids.size(1) + CHUNK_TOKENS, device=device, dtype=model.dtype)
+    past = None  # Reverted past initialization
     offset_len = ids.size(1)  # how many tokens already exist
     last_tok = None  # Initialized last_tok
     buf = []
@@ -109,17 +108,18 @@ async def tts(ws: WebSocket):
     while True:
         print(f"DEBUG: Before generate - past is None: {past is None}", flush=True)  # Added logging
         print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True)  # Added logging
-        # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
+        # --- Mini‑Generate (StaticCache via cache_implementation) -------------------------------------------
         gen = model.generate(
-            input_ids = ids if past.get_seq_length() == 0 else torch.tensor([[last_tok]], device=device),  # Use cache seq length
-            attention_mask = attn if past.get_seq_length() == 0 else None,  # Use cache seq length
-            past_key_values = past,  # Use StaticCache instance
+            input_ids = ids if past is None else torch.tensor([[last_tok]], device=device),  # Use past is None check
+            attention_mask = attn if past is None else None,  # Use past is None check
+            past_key_values = past,  # Pass past (will be None initially, then the cache object)
             max_new_tokens = 1,  # Set max_new_tokens to 1 for debugging cache
             logits_processor=[masker],
             do_sample=True, temperature=0.7, top_p=0.95,
             use_cache=True,  # Re-enabled cache
             return_dict_in_generate=True,
-            return_legacy_cache=True
+            return_legacy_cache=True,
+            cache_implementation="static"  # Enabled StaticCache via implementation
         )
         print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True)  # Added logging
 
@@ -133,7 +133,7 @@ async def tts(ws: WebSocket):
         # ----- Update past and last_tok (Cache Re-enabled) ---------
         # ids = torch.tensor([seq], device=device)  # Removed full sequence update
         # attn = torch.ones_like(ids)  # Removed full sequence update
-        past = gen.past_key_values  # Re-enabled cache update
+        past = gen.past_key_values  # Update past with the cache object returned by generate
         print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True)  # Added logging
         last_tok = new[-1]
 
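
Note (not part of the commit): the pattern this change is working toward — growing the KV cache across repeated small generate() calls instead of re-encoding the whole prompt each time — can be reduced to a minimal standalone sketch. The snippet below is illustrative only: "gpt2" stands in for the repo's model, the repo's build_prompt, masker, and WebSocket plumbing are omitted, and it reuses the library's default dynamic cache by passing the full running sequence back together with past_key_values, which is the documented way to continue generation across calls.

# Minimal sketch of chunked decoding that reuses the cache returned by generate().
# Assumes a recent transformers release with Cache objects; this is NOT the repo's app.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")                       # stand-in model/tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

ids = tok("Streaming tokens, one at a time:", return_tensors="pt").input_ids.to(device)
past = None  # no cache yet, mirroring `past = None` in this commit

for _ in range(8):  # eight 1-token "chunks", analogous to the while-True loop in app.py
    gen = model.generate(
        ids,                             # full running sequence; cached positions are not recomputed
        past_key_values=past,            # None on the first call, then the cache from the previous call
        max_new_tokens=1,
        do_sample=True, temperature=0.7, top_p=0.95,
        use_cache=True,
        return_dict_in_generate=True,
        pad_token_id=tok.eos_token_id,   # silence the missing-pad-token warning for gpt2
    )
    ids = gen.sequences                  # sequence grown by the newly sampled token
    past = gen.past_key_values           # hand the cache back in on the next iteration
    print(tok.decode(gen.sequences[0, -1:]))

Two differences from the diff above: the sketch feeds the full sequence back each call rather than only the last token, and it does not request cache_implementation="static". As far as I can tell, a StaticCache created inside generate() is sized to that single call's max_length, so growing it across repeated 1-token calls may still need the explicit StaticCache(..., max_cache_len=...) construction that this commit removed.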