Tomtom84 committed on
Commit
06a62cb
·
verified ·
1 Parent(s): 53012c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -247,6 +247,13 @@ async def load_models_startup():
247
  print("StoppingCriteria initialized.")
248
 
249
  print("✅ Modelle geladen und bereit!", flush=True)
 
 
 
 
 
 
 
250
 
251
  @app.get("/")
252
  def hello():
@@ -294,15 +301,20 @@ async def tts(ws: WebSocket):
294
 
295
  print("Starting generation in background thread...")
296
  await asyncio.to_thread(
297
- model.generate,
298
- input_ids=ids,
299
- attention_mask=attn,
300
- max_new_tokens=1500,
301
- logits_processor=[masker],
302
- stopping_criteria=stopping_criteria,
303
- do_sample=False, # Using greedy decoding
304
- use_cache=True,
305
- streamer=streamer
 
 
 
 
 
306
  )
307
  print("Generation thread finished.")
308
 
 
247
  print("StoppingCriteria initialized.")
248
 
249
  print("✅ Modelle geladen und bereit!", flush=True)
250
+ print(f"Tokenizer EOS ID: {tok.eos_token_id}")
251
+ print(f"Model Config EOS ID: {model.config.eos_token_id}")
252
+ print(f"Constant EOS_TOKEN: {EOS_TOKEN}")
253
+ if tok.eos_token_id != EOS_TOKEN or model.config.eos_token_id != EOS_TOKEN:
254
+ print("⚠️ WARNING: EOS_TOKEN constant might not match model/tokenizer configuration!")
255
+ # Consider updating EOS_TOKEN if they differ, e.g.:
256
+ # EOS_TOKEN = model.config.eos_token_id
257
 
258
  @app.get("/")
259
  def hello():
 
301
 
302
  print("Starting generation in background thread...")
303
  await asyncio.to_thread(
304
+ model.generate,
305
+ input_ids=ids,
306
+ attention_mask=attn,
307
+ max_new_tokens=2500, # Keep or increase later if needed
308
+ logits_processor=[masker],
309
+ stopping_criteria=stopping_criteria,
310
+ # --- Changes ---
311
+ do_sample=True, # Enable sampling
312
+ temperature=0.6, # Introduce some randomness (adjust as needed)
313
+ top_p=0.9, # Focus sampling on more likely tokens (adjust as needed)
314
+ repetition_penalty=1.15, # Penalize recently generated tokens (adjust > 1.0)
315
+ # --- End Changes ---
316
+ use_cache=True,
317
+ streamer=streamer
318
  )
319
  print("Generation thread finished.")
320