Tomtom84 committed
Commit 2a24991 · verified · 1 Parent(s): 29f8312

Update app.py

Files changed (1)
  1. app.py +19 -13
app.py CHANGED
@@ -4,6 +4,7 @@ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
 from transformers.generation.utils import Cache # Added import
+from transformers.cache_utils import DynamicCache # Added import
 from snac import SNAC
 
 # 0) Login + Device ---------------------------------------------------
@@ -108,14 +109,14 @@ async def tts(ws: WebSocket):
 
     while True:
         print(f"DEBUG: Before generate - past is None: {past is None}", flush=True) # Added logging
-        print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging
-        # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
-        # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
+        print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True) # Added logging)
+
         if past is None:
+            # First generation step
             gen = model.generate(
                 input_ids = ids,
                 attention_mask = attn,
-                past_key_values = past,
+                past_key_values = past, # This will be None
                 max_new_tokens = 1,
                 logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
@@ -123,14 +124,13 @@ async def tts(ws: WebSocket):
                 return_dict_in_generate=True,
             )
         else:
-            # Provide attention mask for the single new token
+            # Subsequent generation steps
             current_input_ids = torch.tensor([[last_tok]], device=device)
             current_attention_mask = torch.ones_like(current_input_ids)
-
             gen = model.generate(
                 input_ids = current_input_ids,
                 attention_mask = current_attention_mask,
-                past_key_values = past,
+                past_key_values = past, # This will be a Cache object
                 max_new_tokens = 1,
                 logits_processor=[masker],
                 do_sample=True, temperature=0.7, top_p=0.95,
@@ -138,7 +138,17 @@ async def tts(ws: WebSocket):
                 return_dict_in_generate=True,
                 cache_position=torch.tensor([offset_len], device=device) # Explicitly pass cache_position
             )
-        print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging
+
+        print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True) # Added logging)
+
+        # Convert legacy tuple cache to DynamicCache if necessary (only after the first step)
+        if past is None and isinstance(gen.past_key_values, tuple):
+            past = DynamicCache.from_legacy_cache(gen.past_key_values)
+        else:
+            # For subsequent steps, just update past with the new cache object
+            past = gen.past_key_values
+
+        print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging)
 
         # ----- neue Tokens heraus schneiden --------------------------
         seq = gen.sequences[0].tolist()
@@ -147,11 +157,7 @@ async def tts(ws: WebSocket):
             break
         offset_len += len(new)
 
-        # ----- Update past and last_tok (Cache Re-enabled) ---------
-        # ids = torch.tensor([seq], device=device) # Removed full sequence update
-        # attn = torch.ones_like(ids) # Removed full sequence update
-        past = gen.past_key_values # Re-enabled cache update
-        print(f"DEBUG: After cache update - type of past: {type(past)}", flush=True) # Added logging
+        # ----- Update last_tok ---------
         last_tok = new[-1]
 
         print("new tokens:", new[:25], flush=True)
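The pattern this commit settles on (keep past_key_values across iterations, convert a legacy tuple cache once with DynamicCache.from_legacy_cache, and track offset_len to slice out the newly generated tokens) can be exercised on its own. The sketch below is not the repo's app.py: the checkpoint is an arbitrary small stand-in, the masker LogitsProcessor and the SNAC decoding are omitted, and, unlike the diff, it feeds the full sequence generated so far back into generate() together with the cache, which is the documented way to resume generation; generate() then recomputes only the uncached tail. It also assumes a transformers release whose generate() output carries past_key_values, which the commit already relies on.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

device = "cuda" if torch.cuda.is_available() else "cpu"
name = "HuggingFaceTB/SmolLM2-135M"       # placeholder checkpoint, not the repo's model
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).to(device).eval()

ids = tok("Hello", return_tensors="pt").input_ids.to(device)
attn = torch.ones_like(ids)
past = None
offset_len = ids.shape[1]                 # tokens already handed to the caller

for _ in range(8):                        # one token per generate() call, as in app.py
    gen = model.generate(
        input_ids=ids,                    # full sequence so far
        attention_mask=attn,
        past_key_values=past,             # None on the first step, a Cache afterwards
        max_new_tokens=1,
        do_sample=True, temperature=0.7, top_p=0.95,
        return_dict_in_generate=True,
    )

    # Normalize the returned cache: older transformers hand back a legacy tuple of
    # per-layer (key, value) pairs, newer ones a Cache object. This mirrors the
    # commit's from_legacy_cache branch.
    pkv = gen.past_key_values
    past = DynamicCache.from_legacy_cache(pkv) if isinstance(pkv, tuple) else pkv

    seq = gen.sequences[0].tolist()
    new = seq[offset_len:]                # tokens produced by this step
    offset_len += len(new)

    # Feed the full sequence (prompt plus everything generated) back in next round;
    # generate() only recomputes the part not covered by the cache.
    ids = gen.sequences
    attn = torch.ones_like(ids)

print(tok.decode(seq))

With the full sequence passed in, recent transformers releases derive cache_position from the cache length on their own, so the explicit cache_position=torch.tensor([offset_len], ...) argument used in the diff is not needed in this variant.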
 
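For reference, the tuple-vs-Cache distinction that the new isinstance(gen.past_key_values, tuple) check guards against can be shown in isolation. The snippet below is a standalone illustration: the tensors are random stand-ins for real key/value states, and the shapes (2 layers, 4 heads, 6 cached positions, head size 16) are arbitrary.

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each tensor shaped
# (batch, num_heads, seq_len, head_dim).
legacy = tuple(
    (torch.randn(1, 4, 6, 16), torch.randn(1, 4, 6, 16))
    for _ in range(2)                            # 2 layers
)

cache = DynamicCache.from_legacy_cache(legacy)
print(type(cache).__name__)                      # DynamicCache
print(cache.get_seq_length())                    # 6 cached positions

# The conversion keeps the same tensors and is reversible.
roundtrip = cache.to_legacy_cache()
print(torch.equal(roundtrip[0][0], legacy[0][0]))  # True

Because the conversion keeps the underlying tensors, converting once after the first step, as the diff does, is enough; later steps already receive a Cache object back from generate().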