Update app.py
app.py CHANGED
@@ -100,8 +100,7 @@ async def tts(ws: WebSocket):
     voice = req.get("voice", "Jakob")
 
     ids, attn = build_prompt(text, voice)
-    #
-    past = StaticCache(config=model.config, max_batch_size=1, max_cache_len=ids.size(1) + CHUNK_TOKENS, device=device, dtype=model.dtype)
+    past = None  # Reverted past initialization
     offset_len = ids.size(1)  # how many tokens already exist
     last_tok = None  # Initialized last_tok
     buf = []
@@ -109,17 +108,18 @@ async def tts(ws: WebSocket):
     while True:
         print(f"DEBUG: Before generate - past is None: {past is None}", flush=True)  # Added logging
         print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True)  # Added logging
-        # --- Mini‑Generate (
+        # --- Mini‑Generate (StaticCache via cache_implementation) -------------------------------------------
         gen = model.generate(
-            input_ids = ids if past
-            attention_mask = attn if past
-            past_key_values = past,  #
+            input_ids = ids if past is None else torch.tensor([[last_tok]], device=device),  # Use past is None check
+            attention_mask = attn if past is None else None,  # Use past is None check
+            past_key_values = past,  # Pass past (will be None initially, then the cache object)
             max_new_tokens = 1,  # Set max_new_tokens to 1 for debugging cache
             logits_processor=[masker],
             do_sample=True, temperature=0.7, top_p=0.95,
             use_cache=True,  # Re-enabled cache
             return_dict_in_generate=True,
-            return_legacy_cache=True
+            return_legacy_cache=True,
+            cache_implementation="static"  # Enabled StaticCache via implementation
         )
         print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True)  # Added logging
 
@@ -133,7 +133,7 @@ async def tts(ws: WebSocket):
         # ----- Update past and last_tok (Cache Re-enabled) ---------
         # ids = torch.tensor([seq], device=device)  # Removed full sequence update
         # attn = torch.ones_like(ids)  # Removed full sequence update
-        past = gen.past_key_values  #
+        past = gen.past_key_values  # Update past with the cache object returned by generate
         print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True)  # Added logging
         last_tok = new[-1]
 
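For context, the new code follows the standard incremental-decoding pattern: run the full prompt once, keep the past_key_values the model returns, and feed only the newest token on later steps so the prompt is not re-encoded for every chunk. Below is a minimal, self-contained sketch of that pattern using the model's forward pass directly; the "gpt2" checkpoint, greedy token selection, and fixed step count are illustration-only assumptions and not part of this app, which samples through a logits masker and calls generate() with cache_implementation="static".

# Minimal sketch of KV-cache reuse across decoding steps (assumptions: any
# Hugging Face causal LM works; "gpt2" is a placeholder, greedy pick replaces sampling).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

ids = tok("Hello", return_tensors="pt").input_ids.to(device)

past = None       # no cache yet, mirrors `past = None` in this commit
last_tok = None
generated = []

with torch.no_grad():
    for _ in range(16):
        # First step: feed the whole prompt. Later steps: feed only the new
        # token, because the cache already holds keys/values for the rest.
        step_ids = ids if past is None else torch.tensor([[last_tok]], device=device)
        out = model(input_ids=step_ids, past_key_values=past, use_cache=True)
        past = out.past_key_values                      # reuse on the next step
        last_tok = int(out.logits[:, -1].argmax(dim=-1))
        generated.append(last_tok)

print(tok.decode(generated))

The DEBUG prints added in this commit track the type of gen.past_key_values in the analogous generate()-based loop, since that object is what gets passed back in as past on the next iteration.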