Noumida committed on
Commit
069b4ed
·
verified ·
1 Parent(s): 3bb090a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -32
app.py CHANGED
@@ -1,10 +1,11 @@
1
  from __future__ import annotations
2
- import os
3
- import gradio as gr
4
  import torch
5
  import torchaudio
 
6
  import spaces
7
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoModelForCTC
 
 
8
 
9
  LANGUAGE_NAME_TO_CODE = {
10
  "Assamese": "as", "Bengali": "bn", "Bodo": "br", "Dogri": "doi",
@@ -15,55 +16,48 @@ LANGUAGE_NAME_TO_CODE = {
15
  "Telugu": "te", "Urdu": "ur"
16
  }
17
 
18
- DESCRIPTION = "IndicConformer-600M Multilingual ASR (CTC + RNNT)"
19
-
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
- # Load processor and models
23
- processor = AutoProcessor.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
24
-
25
- model_ctc = AutoModelForCTC.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True).to(device)
26
- model_ctc.eval()
27
-
28
- model_rnnt = AutoModelForSpeechSeq2Seq.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True).to(device)
29
- model_rnnt.eval()
30
 
31
  @spaces.GPU
32
  def transcribe_ctc_and_rnnt(audio_path, language_name):
33
- lang_id = LANGUAGE_NAME_TO_CODE[language_name]
34
 
 
35
  waveform, sr = torchaudio.load(audio_path)
36
  waveform = waveform.mean(dim=0, keepdim=True) if waveform.shape[0] > 1 else waveform
37
- waveform = torchaudio.functional.resample(waveform, sr, 16000)
38
-
39
- input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_values.to(device)
40
-
41
- with torch.no_grad():
42
- # CTC decoding
43
- ctc_logits = model_ctc(input_values).logits
44
- ctc_ids = torch.argmax(ctc_logits, dim=-1)
45
- ctc_output = processor.batch_decode(ctc_ids)[0]
46
 
47
- # RNNT decoding
48
- rnnt_output = processor.batch_decode(model_rnnt.generate(input_values, decoder_input_ids=torch.tensor([[processor.tokenizer.lang2id[lang_id]]]).to(device)))[0]
 
 
 
 
 
49
 
50
- return ctc_output.strip(), rnnt_output.strip()
51
 
52
- # Gradio interface
53
  with gr.Blocks() as demo:
54
  gr.Markdown(f"## {DESCRIPTION}")
55
  with gr.Row():
56
  with gr.Column():
57
- audio = gr.Audio(label="Upload or record audio", type="filepath")
58
  lang = gr.Dropdown(
59
- label="Select language",
60
- choices=LANGUAGE_NAME_TO_CODE.keys(),
61
  value="Hindi"
62
  )
63
  transcribe_btn = gr.Button("Transcribe (CTC + RNNT)")
64
  with gr.Column():
65
- ctc_output = gr.Textbox(label="CTC Transcription")
66
- rnnt_output = gr.Textbox(label="RNNT Transcription")
 
 
67
 
68
  transcribe_btn.click(fn=transcribe_ctc_and_rnnt, inputs=[audio, lang], outputs=[ctc_output, rnnt_output])
69
 
 
1
  from __future__ import annotations
 
 
2
  import torch
3
  import torchaudio
4
+ import gradio as gr
5
  import spaces
6
+ from transformers import AutoModel
7
+
8
+ DESCRIPTION = "IndicConformer-600M Multilingual ASR (CTC + RNNT)"
9
 
10
  LANGUAGE_NAME_TO_CODE = {
11
  "Assamese": "as", "Bengali": "bn", "Bodo": "br", "Dogri": "doi",
 
16
  "Telugu": "te", "Urdu": "ur"
17
  }
18
 
 
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ # Load Indic Conformer model (assumes custom forward handles decoding strategy)
22
+ model = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True).to(device)
23
+ model.eval()
 
 
 
 
 
24
 
25
  @spaces.GPU
26
  def transcribe_ctc_and_rnnt(audio_path, language_name):
27
+ lang_code = LANGUAGE_NAME_TO_CODE[language_name]
28
 
29
+ # Load and preprocess audio
30
  waveform, sr = torchaudio.load(audio_path)
31
  waveform = waveform.mean(dim=0, keepdim=True) if waveform.shape[0] > 1 else waveform
32
+ waveform = torchaudio.functional.resample(waveform, sr, 16000).to(device)
 
 
 
 
 
 
 
 
33
 
34
+ try:
35
+ # Assume model's forward method takes waveform, language code, and decoding type
36
+ with torch.no_grad():
37
+ transcription_ctc = model(waveform, lang_code, "ctc")
38
+ transcription_rnnt = model(waveform, lang_code, "rnnt")
39
+ except Exception as e:
40
+ return f"Error: {str(e)}", ""
41
 
42
+ return transcription_ctc.strip(), transcription_rnnt.strip()
43
 
44
+ # Gradio UI
45
  with gr.Blocks() as demo:
46
  gr.Markdown(f"## {DESCRIPTION}")
47
  with gr.Row():
48
  with gr.Column():
49
+ audio = gr.Audio(label="Upload or Record Audio", type="filepath")
50
  lang = gr.Dropdown(
51
+ label="Select Language",
52
+ choices=list(LANGUAGE_NAME_TO_CODE.keys()),
53
  value="Hindi"
54
  )
55
  transcribe_btn = gr.Button("Transcribe (CTC + RNNT)")
56
  with gr.Column():
57
+ gr.Markdown("### CTC Transcription")
58
+ ctc_output = gr.Textbox(lines=3)
59
+ gr.Markdown("### RNNT Transcription")
60
+ rnnt_output = gr.Textbox(lines=3)
61
 
62
  transcribe_btn.click(fn=transcribe_ctc_and_rnnt, inputs=[audio, lang], outputs=[ctc_output, rnnt_output])
63