Tomtom84 committed (verified) · Commit fee868f · Parent(s): 29123b3

Update app.py

Files changed (1):
  1. app.py +10 -6
app.py CHANGED
@@ -2,7 +2,7 @@
 import os, json, torch, asyncio
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
-from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StaticCache # Added StaticCache
+from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, DynamicCache # Added DynamicCache
 from snac import SNAC
 
 # 0) Login + Device ---------------------------------------------------
@@ -104,7 +104,7 @@ async def tts(ws: WebSocket):
     offset_len = ids.size(1) # how many tokens already exist
     last_tok = None # Initialized last_tok
     buf = []
-
+    past_key_values = DynamicCache()
     while True:
         print(f"DEBUG: Before generate - past is None: {past is None}", flush=True) # Added logging
         print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging
@@ -112,14 +112,14 @@ async def tts(ws: WebSocket):
         gen = model.generate(
             input_ids = ids if past is None else torch.tensor([[last_tok]], device=device), # Use past is None check
             attention_mask = attn if past is None else None, # Use past is None check
-            past_key_values = past, # Pass past (will be None initially, then the cache object)
+            past_key_values = past_key_values, # Pass the DynamicCache object
             max_new_tokens = 1, # Set max_new_tokens to 1 for debugging cache
             logits_processor=[masker],
             do_sample=True, temperature=0.7, top_p=0.95,
             use_cache=True, # Re-enabled cache
             return_dict_in_generate=True,
-            return_legacy_cache=True,
-            cache_implementation="static" # Enabled StaticCache via implementation
+            #return_legacy_cache=True,
+            #cache_implementation="static" # Enabled StaticCache via implementation
         )
         print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging
 
@@ -133,8 +133,12 @@ async def tts(ws: WebSocket):
         # ----- Update past and last_tok (Cache Re-enabled) ---------
         # ids = torch.tensor([seq], device=device) # Removed full sequence update
         # attn = torch.ones_like(ids) # Removed full sequence update
-        past = gen.past_key_values # Update past with the cache object returned by generate
+        #pkv = gen.past_key_values # Update past with the cache object returned by generate
+        print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
+        #if isinstance(pkv, StaticCache): pkv = pkv.to_legacy()
+        past = gen.past_key_values
         print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
+
         last_tok = new[-1]
 
         print("new tokens:", new[:25], flush=True)