Futuresony committed
Commit 7b1a576 · verified · 1 Parent(s): d7c7caa

Update app.py
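Replace the external asr.transcribe_auto helper with an in-file Wav2Vec2 ASR pipeline (Futuresony/Future-sw_ASR-24-02-2025), add a format_prompt wrapper for the chat prompt, and drop the is_uncertain response filter.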

Files changed (1)
  1. app.py +39 -36
app.py CHANGED
@@ -1,76 +1,79 @@
 import gradio as gr
-from asr import transcribe_auto
+import torch
+import torchaudio
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 from huggingface_hub import InferenceClient
 from ttsmms import download, TTS
 from langdetect import detect
 
-# Initialize text generation client
+# Load ASR Model
+asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"
+processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
+asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
+
+# Load Text Generation Model
 client = InferenceClient("Futuresony/future_ai_12_10_2024.gguf")
 
-# Download and load TTS models for Swahili and English
+def format_prompt(user_input):
+    return f"### User: {user_input}\n### Assistant:"
+
+# Load TTS Models
 swahili_dir = download("swh", "./data/swahili")
-english_dir = download("eng", "./data/english")  # Ensure an English TTS model is available
+english_dir = download("eng", "./data/english")
 
 swahili_tts = TTS(swahili_dir)
 english_tts = TTS(english_dir)
 
-def is_uncertain(question, response):
-    """Check if the model's response is unreliable."""
-    if len(response.split()) < 4 or response.lower() in question.lower():
-        return True
-    uncertain_phrases = ["Kulingana na utafiti", "Inaaminika kuwa", "Ninadhani", "It is believed that", "Some people say"]
-    return any(phrase.lower() in response.lower() for phrase in uncertain_phrases)
-
+# ASR Function
+def transcribe(audio_file):
+    speech_array, sample_rate = torchaudio.load(audio_file)
+    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+    speech_array = resampler(speech_array).squeeze().numpy()
+    input_values = processor(speech_array, sampling_rate=16000, return_tensors="pt").input_values
+    with torch.no_grad():
+        logits = asr_model(input_values).logits
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.batch_decode(predicted_ids)[0]
+    return transcription
+
+# Text Generation Function
 def generate_text(prompt):
-    """Generate a response from the text generation model."""
-    messages = [{"role": "user", "content": prompt}]
-
+    formatted_prompt = format_prompt(prompt)
+    messages = [{"role": "user", "content": formatted_prompt}]
     response = ""
     for message in client.chat_completion(messages, max_tokens=512, stream=True, temperature=0.7, top_p=0.95):
         token = message.choices[0].delta.content
         response += token
-
-    if is_uncertain(prompt, response):
-        return "AI is uncertain about the response."
-
-    return response
+    return response.strip()
 
-# Function to detect language and generate speech
+# TTS Function
 def text_to_speech(text):
-    lang = detect(text)  # Detect language
+    lang = detect(text)
     wav_path = "./output.wav"
-
-    if lang == "sw":  # Swahili
+    if lang == "sw":
         swahili_tts.synthesis(text, wav_path=wav_path)
-    else:  # Default to English if not Swahili
+    else:
         english_tts.synthesis(text, wav_path=wav_path)
-
     return wav_path
 
+# Combined Processing Function
 def process_audio(audio):
-    # Step 1: Transcribe the audio
-    transcription = transcribe_auto(audio)
-
-    # Step 2: Generate text based on the transcription
+    transcription = transcribe(audio)
     generated_text = generate_text(transcription)
-
-    # Step 3: Convert the generated text to speech
     speech = text_to_speech(generated_text)
-
     return transcription, generated_text, speech
 
 # Gradio Interface
 with gr.Blocks() as demo:
     gr.Markdown("<p align='center' style='font-size: 20px;'>End-to-End ASR, Text Generation, and TTS</p>")
     gr.HTML("<center>Upload or record audio. The model will transcribe, generate a response, and read it out.</center>")
-
+
     audio_input = gr.Audio(label="Input Audio", type="filepath")
     text_output = gr.Textbox(label="Transcription")
     generated_text_output = gr.Textbox(label="Generated Text")
     audio_output = gr.Audio(label="Output Speech")
-
     submit_btn = gr.Button("Submit")
-
+
     submit_btn.click(
         fn=process_audio,
         inputs=audio_input,
@@ -78,4 +81,4 @@ with gr.Blocks() as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
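The rewritten ASR path can be smoke-tested without launching the Gradio UI or calling the remote text-generation endpoint. A minimal sketch, not part of the commit, assuming the model downloads succeed and that sample.wav is a hypothetical mono recording at any sample rate:

# smoke_test.py: exercise the new in-file Wav2Vec2 pipeline.
# Importing app runs its top-level model loading; demo.launch() stays
# behind the __main__ guard, so no server starts.
from app import transcribe, format_prompt

text = transcribe("sample.wav")  # load, resample to 16 kHz, CTC argmax decode
print("Transcription:", text)

# The prompt exactly as generate_text() now wraps it before chat_completion:
print(format_prompt(text))

One caveat worth noting: transcribe() only squeezes the waveform that torchaudio.load returns, so a stereo file stays two-channel and the processor treats it as a batch of two inputs; mixing down first (for example speech_array.mean(dim=0)) is a common guard.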