Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
import os, json, torch, asyncio
|
3 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
4 |
from huggingface_hub import login
|
5 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor,
|
6 |
from snac import SNAC
|
7 |
|
8 |
# 0) Login + Device ---------------------------------------------------
|
@@ -104,7 +104,7 @@ async def tts(ws: WebSocket):
|
|
104 |
offset_len = ids.size(1) # wie viele Tokens existieren schon
|
105 |
last_tok = None # Initialized last_tok
|
106 |
buf = []
|
107 |
-
|
108 |
while True:
|
109 |
print(f"DEBUG: Before generate - past is None: {past is None}", flush=True) # Added logging
|
110 |
print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging
|
@@ -112,14 +112,14 @@ async def tts(ws: WebSocket):
|
|
112 |
gen = model.generate(
|
113 |
input_ids = ids if past is None else torch.tensor([[last_tok]], device=device), # Use past is None check
|
114 |
attention_mask = attn if past is None else None, # Use past is None check
|
115 |
-
past_key_values =
|
116 |
max_new_tokens = 1, # Set max_new_tokens to 1 for debugging cache
|
117 |
logits_processor=[masker],
|
118 |
do_sample=True, temperature=0.7, top_p=0.95,
|
119 |
use_cache=True, # Re-enabled cache
|
120 |
return_dict_in_generate=True,
|
121 |
-
return_legacy_cache=True,
|
122 |
-
cache_implementation="static" # Enabled StaticCache via implementation
|
123 |
)
|
124 |
print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging
|
125 |
|
@@ -133,8 +133,12 @@ async def tts(ws: WebSocket):
|
|
133 |
# ----- Update past and last_tok (Cache Re-enabled) ---------
|
134 |
# ids = torch.tensor([seq], device=device) # Removed full sequence update
|
135 |
# attn = torch.ones_like(ids) # Removed full sequence update
|
136 |
-
|
|
|
|
|
|
|
137 |
print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
|
|
|
138 |
last_tok = new[-1]
|
139 |
|
140 |
print("new tokens:", new[:25], flush=True)
|
|
|
2 |
import os, json, torch, asyncio
|
3 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
4 |
from huggingface_hub import login
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, DynamicCache # Added StaticCache
|
6 |
from snac import SNAC
|
7 |
|
8 |
# 0) Login + Device ---------------------------------------------------
|
|
|
104 |
offset_len = ids.size(1) # wie viele Tokens existieren schon
|
105 |
last_tok = None # Initialized last_tok
|
106 |
buf = []
|
107 |
+
past_key_values = DynamicCache()
|
108 |
while True:
|
109 |
print(f"DEBUG: Before generate - past is None: {past is None}", flush=True) # Added logging
|
110 |
print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging
|
|
|
112 |
gen = model.generate(
|
113 |
input_ids = ids if past is None else torch.tensor([[last_tok]], device=device), # Use past is None check
|
114 |
attention_mask = attn if past is None else None, # Use past is None check
|
115 |
+
past_key_values = past_key_values, # Pass past (will be None initially, then the cache object)
|
116 |
max_new_tokens = 1, # Set max_new_tokens to 1 for debugging cache
|
117 |
logits_processor=[masker],
|
118 |
do_sample=True, temperature=0.7, top_p=0.95,
|
119 |
use_cache=True, # Re-enabled cache
|
120 |
return_dict_in_generate=True,
|
121 |
+
#return_legacy_cache=True,
|
122 |
+
#cache_implementation="static" # Enabled StaticCache via implementation
|
123 |
)
|
124 |
print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging
|
125 |
|
|
|
133 |
# ----- Update past and last_tok (Cache Re-enabled) ---------
|
134 |
# ids = torch.tensor([seq], device=device) # Removed full sequence update
|
135 |
# attn = torch.ones_like(ids) # Removed full sequence update
|
136 |
+
#pkv = gen.past_key_values # Update past with the cache object returned by generate
|
137 |
+
print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
|
138 |
+
#if isinstance(pkv, StaticCache): pkv = pkv.to_legacy()
|
139 |
+
past = gen.past_key_values
|
140 |
print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
|
141 |
+
|
142 |
last_tok = new[-1]
|
143 |
|
144 |
print("new tokens:", new[:25], flush=True)
|