husseinelsaadi commited on
Commit
4c6a61f
Β·
verified Β·
1 Parent(s): 83570c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -1482,17 +1482,27 @@ bark_voice_preset = "v2/en_speaker_5"
1482
  def bark_tts(text):
1483
  print(f"πŸ” Synthesizing TTS for: {text}")
1484
  inputs = processor_bark(text, return_tensors="pt", voice_preset=bark_voice_preset)
1485
- inputs = {k: v.to(model_bark.device) for k, v in inputs.items()}
1486
- inputs["max_new_tokens"] = 100 # Add this to the input dictionary
 
 
 
1487
  start = time.time()
1488
- speech_values = model_bark.generate(**inputs)
 
 
 
 
 
1489
  print(f"βœ… Bark finished in {round(time.time() - start, 2)}s")
 
1490
  speech = speech_values.cpu().numpy().squeeze()
1491
  speech = (speech * 32767).astype(np.int16)
1492
  temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
1493
  wavfile.write(temp_wav.name, 22050, speech)
1494
  return temp_wav.name
1495
 
 
1496
  # Whisper STT
1497
  print("πŸ” Loading Whisper model...")
1498
  whisper_model = whisper.load_model("base", device="cuda")
 
1482
  def bark_tts(text):
1483
  print(f"πŸ” Synthesizing TTS for: {text}")
1484
  inputs = processor_bark(text, return_tensors="pt", voice_preset=bark_voice_preset)
1485
+ input_ids = inputs["input_ids"].to(model_bark.device)
1486
+ attention_mask = inputs.get("attention_mask", None)
1487
+ if attention_mask is not None:
1488
+ attention_mask = attention_mask.to(model_bark.device)
1489
+
1490
  start = time.time()
1491
+ speech_values = model_bark.generate(
1492
+ input_ids=input_ids,
1493
+ attention_mask=attention_mask,
1494
+ max_new_tokens=100, # βœ… Correctly passed outside inputs
1495
+ pad_token_id=10000 # βœ… Optional to avoid warnings
1496
+ )
1497
  print(f"βœ… Bark finished in {round(time.time() - start, 2)}s")
1498
+
1499
  speech = speech_values.cpu().numpy().squeeze()
1500
  speech = (speech * 32767).astype(np.int16)
1501
  temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
1502
  wavfile.write(temp_wav.name, 22050, speech)
1503
  return temp_wav.name
1504
 
1505
+
1506
  # Whisper STT
1507
  print("πŸ” Loading Whisper model...")
1508
  whisper_model = whisper.load_model("base", device="cuda")