Tomtom84 committed
Commit 55515cc · verified · 1 Parent(s): 93bffbb

Update app.py

Files changed (1): app.py (+28 -11)
app.py CHANGED
@@ -3,6 +3,7 @@ import os, json, torch, asyncio
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
 from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
+from transformers.generation.utils import Cache  # Added import
 from snac import SNAC
 
 # 0) Login + Device ---------------------------------------------------
@@ -109,17 +110,33 @@ async def tts(ws: WebSocket):
         print(f"DEBUG: Before generate - past is None: {past is None}", flush=True)  # Added logging
         print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True)  # Added logging
         # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
-        gen = model.generate(
-            input_ids = ids if past is None else torch.tensor([[last_tok]], device=device),  # Re-enabled cache input
-            attention_mask = attn if past is None else None,  # Re-enabled cache attention
-            past_key_values = past,  # Re-enabled cache
-            max_new_tokens = 1,  # Set max_new_tokens to 1 for debugging cache
-            logits_processor=[masker],
-            do_sample=True, temperature=0.7, top_p=0.95,
-            use_cache=True,  # Re-enabled cache
-            return_dict_in_generate=True,
-            return_legacy_cache=True
-        )
+        # --- Mini‑Generate (Cache Re-enabled) -------------------------------------------
+        if past is None:
+            gen = model.generate(
+                input_ids = ids,
+                attention_mask = attn,
+                past_key_values = past,
+                max_new_tokens = 1,
+                logits_processor=[masker],
+                do_sample=True, temperature=0.7, top_p=0.95,
+                use_cache=True,
+                return_dict_in_generate=True,
+            )
+        else:
+            # Provide attention mask for the single new token
+            current_input_ids = torch.tensor([[last_tok]], device=device)
+            current_attention_mask = torch.ones_like(current_input_ids)
+
+            gen = model.generate(
+                input_ids = current_input_ids,
+                attention_mask = current_attention_mask,
+                past_key_values = past,
+                max_new_tokens = 1,
+                logits_processor=[masker],
+                do_sample=True, temperature=0.7, top_p=0.95,
+                use_cache=True,
+                return_dict_in_generate=True,
+            )
         print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True)  # Added logging
 
         # ----- slice out the new tokens --------------------------
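
Context for the change above: the first call feeds the full prompt (`ids`/`attn`) and builds the KV cache, and every later call feeds only the single newly sampled token together with the returned `past_key_values`. The sketch below is not the app's code; it illustrates the same incremental-decoding idea using the model's forward pass directly. The "gpt2" checkpoint, the toy prompt, the temperature-only sampling, and the fixed five-step loop are placeholder assumptions standing in for the real SNAC token stream and the `masker` logits processor.

# Minimal sketch of KV-cache reuse across single-token steps (assumed "gpt2" stand-in).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

ids = tok("Hello", return_tensors="pt").input_ids.to(device)

past = None
step_ids = ids                                   # first step: the full prompt
with torch.no_grad():
    for _ in range(5):
        out = model(input_ids=step_ids, past_key_values=past, use_cache=True)
        past = out.past_key_values               # reuse the cache on the next step
        probs = torch.softmax(out.logits[:, -1, :] / 0.7, dim=-1)
        step_ids = torch.multinomial(probs, num_samples=1)   # shape (1, 1): next token only
        print(tok.decode(step_ids[0]))

The commit's version keeps `model.generate` instead of a manual sampling loop, which is why it passes `past_key_values`, `use_cache=True`, and `return_dict_in_generate=True` so that `gen.past_key_values` can be carried into the next one-token call.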