hashhac commited on
Commit
35e9187
·
1 Parent(s): 3f5ef1f

have int 16 outputs

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -90,18 +90,22 @@ def text_to_speech(text):
90
 
91
  # Generate speech with SpeechT5
92
  with torch.no_grad():
 
 
 
93
  # Generate speech
94
  speech = tts_model.generate_speech(
95
  inputs["input_ids"].to(device),
96
- speaker_embeddings.to(device),
97
  vocoder=tts_vocoder
98
  )
99
 
100
  # Convert to numpy array
101
- audio_array = speech.cpu().numpy()
 
102
 
103
- # Normalize and convert to int16
104
- audio_array = (audio_array / np.max(np.abs(audio_array)) * 32767).astype(np.int16)
105
 
106
  # Reshape for fastrtc
107
  audio_array = audio_array.reshape(1, -1)
 
90
 
91
  # Generate speech with SpeechT5
92
  with torch.no_grad():
93
+ # Convert speaker embeddings to correct dtype and move to device
94
+ speaker_embeddings_device = speaker_embeddings.to(device).to(torch_dtype)
95
+
96
  # Generate speech
97
  speech = tts_model.generate_speech(
98
  inputs["input_ids"].to(device),
99
+ speaker_embeddings_device,
100
  vocoder=tts_vocoder
101
  )
102
 
103
  # Convert to numpy array
104
+ # Make sure speech is float32 before any conversion to avoid the error
105
+ audio_array = speech.cpu().numpy().astype(np.float32)
106
 
107
+ # Normalize and convert to int16 for output
108
+ audio_array = (audio_array / np.max(np.abs(audio_array) + 1e-6) * 32767).astype(np.int16)
109
 
110
  # Reshape for fastrtc
111
  audio_array = audio_array.reshape(1, -1)