hashhac committed on
Commit
be00791
·
1 Parent(s): a70a34d
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -15,17 +15,23 @@ asr_model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr").to
15
  tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
16
 
17
  # Function to convert speech to text
18
- def speech_to_text(audio):
19
- inputs = processor(audio, sampling_rate=16000, return_tensors="pt").input_values.to(device)
 
 
 
 
 
20
  with torch.no_grad():
21
  logits = asr_model(inputs).logits
 
22
  predicted_ids = torch.argmax(logits, dim=-1)
23
  transcription = processor.batch_decode(predicted_ids)[0]
24
  return transcription
25
 
26
  # Function to convert text to speech
27
  def text_to_speech(text):
28
- inputs = processor(text, return_tensors="pt").input_ids.to(device)
29
  with torch.no_grad():
30
  speech = tts_model.generate_speech(inputs)
31
  return speech
 
15
  tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
16
 
17
  # Function to convert speech to text
18
def speech_to_text(audio_dict):
    """Transcribe one audio sample to text with the SpeechT5 ASR model.

    Args:
        audio_dict: Mapping whose ``"array"`` entry holds the raw waveform.
            NOTE(review): the waveform is assumed to be sampled at 16 kHz,
            because that rate is passed to the processor unconditionally;
            confirm against the caller.

    Returns:
        The decoded transcription of the first item in the batch.
    """
    # Extract the audio array from the dictionary
    audio_array = audio_dict["array"]

    # Keep the whole processor output (not only ``input_values``) so that
    # any attention mask it produces is forwarded to ``generate`` as well.
    features = processor(
        audio=audio_array, sampling_rate=16000, return_tensors="pt"
    ).to(device)

    # SpeechT5 ASR is an encoder-decoder model: a bare forward pass has no
    # decoder inputs to condition on, so the transcript has to come from
    # autoregressive generation rather than an argmax over forward logits.
    with torch.no_grad():
        predicted_ids = asr_model.generate(**features)

    # Generated sequences carry special tokens (e.g. end-of-sequence);
    # strip them so only the transcript text is returned.
    transcription = processor.batch_decode(
        predicted_ids, skip_special_tokens=True
    )[0]
    return transcription
31
 
32
  # Function to convert text to speech
33
  def text_to_speech(text):
34
+ inputs = processor(text=text, return_tensors="pt").input_ids.to(device)
35
  with torch.no_grad():
36
  speech = tts_model.generate_speech(inputs)
37
  return speech