Tomtom84 committed (verified)
Commit c417a58 · 1 Parent(s): 7f16630

Update app.py

Files changed (1)
  1. app.py +14 -8
app.py CHANGED
@@ -2,7 +2,7 @@
 import os, json, torch, asyncio
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from huggingface_hub import login
-from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
+from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StaticCache  # Added StaticCache
 from snac import SNAC

 # 0) Login + Device ---------------------------------------------------
@@ -100,22 +100,28 @@ async def tts(ws: WebSocket):
         voice = req.get("voice", "Jakob")

         ids, attn = build_prompt(text, voice)
+        # Initialize StaticCache
+        past = StaticCache(config=model.config, max_batch_size=1, max_cache_len=ids.size(1) + CHUNK_TOKENS, device=device, dtype=model.dtype)
         offset_len = ids.size(1)  # how many tokens already exist
+        last_tok = None  # initialize last_tok
         buf = []

         while True:
-            # --- Mini-Generate (Cache Disabled) -------------------------------------------
+            print(f"DEBUG: Before generate - past is None: {past is None}", flush=True)  # Added logging
+            print(f"DEBUG: Before generate - type of past: {type(past) if past is not None else 'None'}", flush=True)  # Added logging
+            # --- Mini-Generate (Cache Re-enabled) -------------------------------------------
             gen = model.generate(
-                input_ids = ids,
-                attention_mask = attn,
-                past_key_values = None,  # Cache disabled
-                max_new_tokens = 1,
+                input_ids = ids if past.get_seq_length() == 0 else torch.tensor([[last_tok]], device=device),  # Use cache seq length
+                attention_mask = attn if past.get_seq_length() == 0 else None,  # Use cache seq length
+                past_key_values = past,  # Use StaticCache instance
+                max_new_tokens = 1,  # Set max_new_tokens to 1 for debugging cache
                 logits_processor=[masker],
-                do_sample=False, temperature=0.7, top_p=0.95,
-                use_cache=False,  # Cache disabled
+                do_sample=True, temperature=0.7, top_p=0.95,
+                use_cache=True,  # Re-enabled cache
                 return_dict_in_generate=True,
                 return_legacy_cache=True
             )
+            print(f"DEBUG: After generate - type of gen.past_key_values: {type(gen.past_key_values)}", flush=True)  # Added logging

             # ----- slice out the new tokens --------------------------
             seq = gen.sequences[0].tolist()
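
For context, here is a minimal sketch of the caching pattern this commit switches to: a StaticCache sized for the prompt plus the tokens still to be generated, filled on the first generate() call and then re-used by every following one-token call. This is an illustration under assumptions, not the repository's code: the helper name stream_tokens and the parameter n_steps are made up, the masker logits processor and the SNAC decoding are omitted, the model is assumed to be a decoder-only checkpoint on a recent transformers release that supports static caches (e.g. a Llama-style model), and, unlike the diff (which manually feeds only last_tok once the cache is warm), the sketch passes the full running sequence and lets generate() skip the already-cached prefix, which is the cache-reuse pattern described in the transformers documentation.

import torch
from transformers import PreTrainedModel, StaticCache


def stream_tokens(model: PreTrainedModel, prompt_ids: torch.Tensor, n_steps: int):
    """Yield one sampled token id per step, re-using one StaticCache across generate() calls (sketch)."""
    past = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=prompt_ids.size(1) + n_steps,  # prompt + everything we plan to generate
        device=prompt_ids.device,
        dtype=model.dtype,
    )
    ids = prompt_ids
    for _ in range(n_steps):
        out = model.generate(
            input_ids=ids,             # full sequence so far; the cached prefix is skipped internally
            past_key_values=past,      # same StaticCache object on every call
            max_new_tokens=1,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            use_cache=True,
            return_dict_in_generate=True,
        )
        ids = out.sequences            # prompt + all tokens generated so far
        yield ids[0, -1].item()        # newest token id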
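
Two design notes on the change, stated as observations rather than anything the commit message claims: first, the sampling knobs now matter, since temperature=0.7 and top_p=0.95 only take effect together with do_sample=True, whereas the removed code passed them alongside do_sample=False, where transformers ignores them and decodes greedily. Second, a StaticCache pre-allocates exactly max_cache_len positions and cannot grow past them, so max_cache_len=ids.size(1) + CHUNK_TOKENS bounds the loop to roughly CHUNK_TOKENS one-token steps per request; whatever exit or reset logic the rest of the while True loop contains (not visible in this hunk) has to stay within that budget.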