Futuresony committed on
Commit
f75518e
·
verified ·
1 Parent(s): af87fba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -17
app.py CHANGED
@@ -1,41 +1,47 @@
1
  import gradio as gr
2
  import torch
3
  import torchaudio
 
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
 
6
- # Load the Swahili ASR model
7
  model_name = "Futuresony/Future-sw_ASR-24-02-2025"
8
  processor = Wav2Vec2Processor.from_pretrained(model_name)
9
  model = Wav2Vec2ForCTC.from_pretrained(model_name)
10
 
11
- # Function to process live audio stream
12
- def transcribe_live(microphone_audio):
13
- speech_array, sample_rate = torchaudio.load(microphone_audio)
 
14
 
15
- # Resample to 16kHz
16
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
17
- speech_array = resampler(speech_array).squeeze().numpy()
18
 
19
- # Process and transcribe
20
- input_values = processor(speech_array, sampling_rate=16000, return_tensors="pt").input_values
 
 
 
 
 
21
  with torch.no_grad():
22
  logits = model(input_values).logits
23
  predicted_ids = torch.argmax(logits, dim=-1)
24
-
25
- # Decode the text
26
  transcription = processor.batch_decode(predicted_ids)[0]
27
  return transcription
28
 
29
- # Create Gradio interface with live microphone input
30
  interface = gr.Interface(
31
  fn=transcribe_live,
32
- inputs=gr.Audio(sources=["microphone"], type="filepath"),
33
- outputs="text",
34
  live=True,
35
- title="Live Swahili ASR Transcription",
36
- description="Speak into your microphone, and the model will transcribe in real-time.",
37
  )
38
 
39
- # Launch the app
40
  if __name__ == "__main__":
41
  interface.launch()
 
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
+ import numpy as np
5
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
6
 
7
# Hub repository id of the trained Swahili wav2vec2 CTC checkpoint. The
# acoustic model and its processor are both loaded from this one repository
# at import time.
model_name = "Futuresony/Future-sw_ASR-24-02-2025"
model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name)
11
 
12
+ # Process microphone input in real-time
13
def transcribe_live(audio):
    """Transcribe one chunk of microphone audio to Swahili text.

    Args:
        audio: What ``gr.Audio(type="numpy")`` hands to the callback: a
            ``(sample_rate, samples)`` tuple, or ``None`` when nothing has
            been recorded yet. A bare NumPy array is also accepted and is
            assumed to already be sampled at 16 kHz.

    Returns:
        The decoded transcription, or ``""`` when there is no audio.
    """
    target_rate = 16000  # rate passed to the processor below

    if audio is None:
        return ""

    # Gradio delivers (sample_rate, samples); keep accepting a raw array too.
    if isinstance(audio, tuple):
        sample_rate, samples = audio
    else:
        sample_rate, samples = target_rate, audio

    samples = np.asarray(samples)
    if samples.size == 0:
        # Nothing to run the model on (e.g. an empty streaming chunk).
        return ""

    # Integer PCM must be scaled into [-1, 1]; a plain float cast would feed
    # the model values in the tens of thousands.
    if np.issubdtype(samples.dtype, np.integer):
        full_scale = float(np.iinfo(samples.dtype).max)
        waveform = torch.from_numpy(samples.astype(np.float32) / full_scale)
    else:
        waveform = torch.from_numpy(samples.astype(np.float32))

    # Multi-channel input is assumed to be laid out (samples, channels), as
    # Gradio provides it; average the channels down to mono.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=-1)

    # Honor the rate reported with the audio instead of assuming 16 kHz.
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_rate
        )

    input_values = processor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    ).input_values

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: most likely token per frame, then collapse to text.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]
34
 
35
# Gradio UI: a streaming audio input wired to transcribe_live, with the
# transcription shown in a labelled text box. live=True re-runs the callback
# as input arrives instead of waiting for a submit click.
interface = gr.Interface(
    title="Live Swahili ASR Streaming",
    description="Talk and see real-time Swahili subtitles appear below!",
    fn=transcribe_live,
    inputs=gr.Audio(type="numpy", streaming=True),
    outputs=gr.Textbox(label="Live Transcription"),
    live=True,
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()