rohanmiriyala committed on
Commit a05081b · verified · 1 Parent(s): e6699ca

Update app.py

Files changed (1):
  app.py +213 -115
app.py CHANGED
@@ -1,144 +1,242 @@
- # main.py
  from __future__ import annotations
  import os
- import io
  import torch
- import numpy as np
  import torchaudio
- import nltk
- import gradio as gr
- from pydub import AudioSegment

  from transformers import (
      SeamlessM4TFeatureExtractor,
      SeamlessM4TTokenizer,
      SeamlessM4Tv2ForSpeechToText,
-     AutoTokenizer,
-     AutoFeatureExtractor
  )
- from parler_tts import ParlerTTSForConditionalGeneration

- nltk.download('punkt')

- # === CONFIG ===
- HF_TOKEN = os.getenv("HF_TOKEN")
- device = "cuda" if torch.cuda.is_available() else "cpu"
  torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
- SAMPLE_RATE = 16000
  DEFAULT_TARGET_LANGUAGE = "Hindi"

- # === Load translation model ===
- trans_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
-     "ai4bharat/indic-seamless", torch_dtype=torch_dtype, token=HF_TOKEN
- ).to(device)
- processor = SeamlessM4TFeatureExtractor.from_pretrained("ai4bharat/indic-seamless", token=HF_TOKEN)
- tokenizer = SeamlessM4TTokenizer.from_pretrained("ai4bharat/indic-seamless", token=HF_TOKEN)
-
- # === Load TTS models ===
- tts_repo = "ai4bharat/indic-parler-tts-pretrained"
- tts_finetuned_repo = "ai4bharat/indic-parler-tts"
- tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
-     tts_repo, attn_implementation="eager", torch_dtype=torch_dtype
- ).to(device)
- tts_finetuned_model = ParlerTTSForConditionalGeneration.from_pretrained(
-     tts_finetuned_repo, attn_implementation="eager", torch_dtype=torch_dtype
- ).to(device)
- desc_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
- text_tokenizer = AutoTokenizer.from_pretrained(tts_repo)
-
- tts_sampling_rate = tts_model.audio_encoder.config.sampling_rate
-
- # === Utilities ===
- def numpy_to_mp3(audio_array, sampling_rate):
-     if np.issubdtype(audio_array.dtype, np.floating):
-         audio_array = (audio_array / np.max(np.abs(audio_array))) * 32767
-         audio_array = audio_array.astype(np.int16)
-     segment = AudioSegment(
-         audio_array.tobytes(),
-         frame_rate=sampling_rate,
-         sample_width=audio_array.dtype.itemsize,
-         channels=1
-     )
-     mp3_io = io.BytesIO()
-     segment.export(mp3_io, format="mp3", bitrate="320k")
-     return mp3_io.getvalue()
-
- def chunk_text(text, max_words=25):
-     sentences = nltk.sent_tokenize(text)
-     chunks, curr = [], ""
-     for s in sentences:
-         candidate = f"{curr} {s}".strip()
-         if len(candidate.split()) > max_words:
-             if curr: chunks.append(curr)
-             curr = s
-         else:
-             curr = candidate
-     if curr: chunks.append(curr)
-     return chunks
-
- # === Translation ===
- def translate_audio(input_audio, target_language):
-     audio, orig_sr = torchaudio.load(input_audio)
-     audio = torchaudio.functional.resample(audio, orig_sr, SAMPLE_RATE)
-     inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt").to(device, dtype=torch_dtype)
-     target_lang_code = "hin"  # default Hindi, change as needed
-     gen_ids = trans_model.generate(**inputs, tgt_lang=target_lang_code)[0]
-     return tokenizer.decode(gen_ids, skip_special_tokens=True)
-
- # === TTS generation ===
- def generate_tts(text, description, use_finetuned=False):
-     model = tts_finetuned_model if use_finetuned else tts_model
-     inputs = desc_tokenizer(description, return_tensors="pt").to(device)
-     chunks = chunk_text(text)
-
-     all_audio = []
-     for chunk in chunks:
-         prompt = text_tokenizer(chunk, return_tensors="pt").to(device)
-         gen = model.generate(
-             input_ids=inputs.input_ids,
-             attention_mask=inputs.attention_mask,
-             prompt_input_ids=prompt.input_ids,
-             prompt_attention_mask=prompt.attention_mask,
-             do_sample=True,
-             return_dict_in_generate=True
-         )
-         if hasattr(gen, 'sequences') and hasattr(gen, 'audios_length'):
-             audio = gen.sequences[0, :gen.audios_length[0]]
-             audio_np = audio.float().cpu().numpy().flatten()
-             all_audio.append(audio_np)
-     combined = np.concatenate(all_audio)
-     return numpy_to_mp3(combined, sampling_rate=tts_sampling_rate)
-
- # === Gradio UI ===
- with gr.Blocks() as demo:
-     gr.Markdown("## 🎙️ Speech-to-Text → Text-to-Speech Demo")

      with gr.Row():
          with gr.Column():
-             input_audio = gr.Audio(label="Upload or record audio", type="filepath")
-             target_language = gr.Textbox(label="Target language (default Hindi)", value="Hindi")
-             btn_translate = gr.Button("Translate to text")
          with gr.Column():
-             translated_text = gr.Textbox(label="Translated text")

-     btn_translate.click(
-         translate_audio,
-         inputs=[input_audio, target_language],
-         outputs=translated_text
      )

      with gr.Row():
          with gr.Column():
-             voice_desc = gr.Textbox(label="Voice description", value="A calm, neutral Indian voice, clear audio.")
-             use_finetuned = gr.Checkbox(label="Use fine-tuned TTS", value=True)
-             btn_tts = gr.Button("Generate speech")
          with gr.Column():
-             generated_audio = gr.Audio(label="Generated speech", format="mp3", autoplay=True)

-     btn_tts.click(
-         generate_tts,
-         inputs=[translated_text, voice_desc, use_finetuned],
-         outputs=generated_audio
      )

  demo.launch(share=True)

  from __future__ import annotations
+
  import os
+
+ import gradio as gr
+ import spaces
  import torch
  import torchaudio

  from transformers import (
      SeamlessM4TFeatureExtractor,
      SeamlessM4TTokenizer,
      SeamlessM4Tv2ForSpeechToText,
  )

+ from lang_list import (
+     ASR_TARGET_LANGUAGE_NAMES,
+     LANGUAGE_NAME_TO_CODE,
+     S2ST_TARGET_LANGUAGE_NAMES,
+     S2TT_TARGET_LANGUAGE_NAMES,
+     T2ST_TARGET_LANGUAGE_NAMES,
+     TEXT_SOURCE_LANGUAGE_NAMES,
+ )
+

+ DESCRIPTION = """\
+ ### **IndicSeamless: Speech-to-Text Translation Model for Indian Languages** 🎙️➡️📜
+ This Gradio demo showcases **IndicSeamless**, a fine-tuned **SeamlessM4T-v2-large** model for **speech-to-text translation** across **13 Indian languages and English**. Trained on **BhasaAnuvaad**, the largest open-source speech translation dataset for Indian languages, it delivers **accurate and robust translations** across diverse linguistic and acoustic conditions.
+ 🔗 **Model Checkpoint:** [ai4bharat/indic-seamless](https://huggingface.co/ai4bharat/indic-seamless)
+ #### **How to Use:**
+ 1. **Upload or record** an audio clip in any supported Indian language.
+ 2. Click **"Translate"** to generate the corresponding text in the target language.
+ 3. View or copy the output for further use.
+ 🚀 Try it out and experience seamless speech translation for Indian languages!
+ """
+
+ hf_token = os.getenv("HF_TOKEN")
+ device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
  torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
+
+ model = SeamlessM4Tv2ForSpeechToText.from_pretrained("ai4bharat/indic-seamless", torch_dtype=torch_dtype, token=hf_token).to(device)
+ processor = SeamlessM4TFeatureExtractor.from_pretrained("ai4bharat/indic-seamless", token=hf_token)
+ tokenizer = SeamlessM4TTokenizer.from_pretrained("ai4bharat/indic-seamless", token=hf_token)
+
+ CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
+
+ AUDIO_SAMPLE_RATE = 16000
+ MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
  DEFAULT_TARGET_LANGUAGE = "Hindi"

+ def preprocess_audio(input_audio: str) -> None:
+     arr, org_sr = torchaudio.load(input_audio)
+     new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
+     max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
+     if new_arr.shape[1] > max_length:
+         new_arr = new_arr[:, :max_length]
+         gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
+     torchaudio.save(input_audio, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))

+ @spaces.GPU
+ def run_s2tt(input_audio: str, source_language: str, target_language: str) -> str:
+     # preprocess_audio(input_audio)
+     # source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
+     target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
+
+     input_audio, orig_freq = torchaudio.load(input_audio)
+     input_audio = torchaudio.functional.resample(input_audio, orig_freq=orig_freq, new_freq=16000)
+     audio_inputs = processor(input_audio, sampling_rate=16000, return_tensors="pt").to(device=device, dtype=torch_dtype)
+
+     text_out = model.generate(**audio_inputs, tgt_lang=target_language_code)[0].float().cpu().numpy().squeeze()
+
+     return tokenizer.decode(text_out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+
+ @spaces.GPU
+ def run_asr(input_audio: str, target_language: str) -> str:
+     # preprocess_audio(input_audio)
+     target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
+
+     input_audio, orig_freq = torchaudio.load(input_audio)
+     input_audio = torchaudio.functional.resample(input_audio, orig_freq=orig_freq, new_freq=16000)
+     audio_inputs = processor(input_audio, sampling_rate=16000, return_tensors="pt").to(device=device, dtype=torch_dtype)
+
+     text_out = model.generate(**audio_inputs, tgt_lang=target_language_code)[0].float().cpu().numpy().squeeze()
+
+     return tokenizer.decode(text_out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+
+
+ with gr.Blocks() as demo_s2st:
      with gr.Row():
          with gr.Column():
+             with gr.Group():
+                 input_audio = gr.Audio(label="Input speech", type="filepath")
+                 source_language = gr.Dropdown(
+                     label="Source language",
+                     choices=ASR_TARGET_LANGUAGE_NAMES,
+                     value="English",
+                 )
+                 target_language = gr.Dropdown(
+                     label="Target language",
+                     choices=S2ST_TARGET_LANGUAGE_NAMES,
+                     value=DEFAULT_TARGET_LANGUAGE,
+                 )
+                 btn = gr.Button("Translate")
          with gr.Column():
+             with gr.Group():
+                 output_audio = gr.Audio(
+                     label="Translated speech",
+                     autoplay=False,
+                     streaming=False,
+                     type="numpy",
+                 )
+                 output_text = gr.Textbox(label="Translated text")
+
+ with gr.Blocks() as demo_s2tt:
+     with gr.Row():
+         with gr.Column():
+             with gr.Group():
+                 input_audio = gr.Audio(label="Input speech", type="filepath")
+                 source_language = gr.Dropdown(
+                     label="Source language",
+                     choices=ASR_TARGET_LANGUAGE_NAMES,
+                     value="English",
+                 )
+                 target_language = gr.Dropdown(
+                     label="Target language",
+                     choices=S2TT_TARGET_LANGUAGE_NAMES,
+                     value=DEFAULT_TARGET_LANGUAGE,
+                 )
+                 btn = gr.Button("Translate")
+         with gr.Column():
+             output_text = gr.Textbox(label="Translated text")
+
+     gr.Examples(
+         examples=[
+             ["assets/Bengali.wav", "Bengali", "English"],
+             ["assets/Gujarati.wav", "Gujarati", "Hindi"],
+             ["assets/Punjabi.wav", "Punjabi", "Hindi"],
+         ],
+         inputs=[input_audio, source_language, target_language],
+         outputs=output_text,
+         fn=run_s2tt,
+         cache_examples=CACHE_EXAMPLES,
+         api_name=False,
+     )

+     btn.click(
+         fn=run_s2tt,
+         inputs=[input_audio, source_language, target_language],
+         outputs=output_text,
+         api_name="s2tt",
      )

+ with gr.Blocks() as demo_t2st:
      with gr.Row():
          with gr.Column():
+             with gr.Group():
+                 input_text = gr.Textbox(label="Input text")
+                 with gr.Row():
+                     source_language = gr.Dropdown(
+                         label="Source language",
+                         choices=TEXT_SOURCE_LANGUAGE_NAMES,
+                         value="English",
+                     )
+                     target_language = gr.Dropdown(
+                         label="Target language",
+                         choices=T2ST_TARGET_LANGUAGE_NAMES,
+                         value=DEFAULT_TARGET_LANGUAGE,
+                     )
+                 btn = gr.Button("Translate")
          with gr.Column():
+             with gr.Group():
+                 output_audio = gr.Audio(
+                     label="Translated speech",
+                     autoplay=False,
+                     streaming=False,
+                     type="numpy",
+                 )
+                 output_text = gr.Textbox(label="Translated text")
+
+
+ with gr.Blocks() as demo_asr:
+     with gr.Row():
+         with gr.Column():
+             with gr.Group():
+                 input_audio = gr.Audio(label="Input speech", type="filepath")
+                 target_language = gr.Dropdown(
+                     label="Target language",
+                     choices=ASR_TARGET_LANGUAGE_NAMES,
+                     value=DEFAULT_TARGET_LANGUAGE,
+                 )
+                 btn = gr.Button("Transcribe")
+         with gr.Column():
+             output_text = gr.Textbox(label="Transcribed text")
+
+     # Each example row must match the two inputs below (audio, target language);
+     # the stray third column from the S2TT examples is dropped here.
+     gr.Examples(
+         examples=[
+             ["assets/Bengali.wav", "Bengali"],
+             ["assets/Gujarati.wav", "Gujarati"],
+             ["assets/Punjabi.wav", "Punjabi"],
+         ],
+         inputs=[input_audio, target_language],
+         outputs=output_text,
+         fn=run_asr,
+         cache_examples=CACHE_EXAMPLES,
+         api_name=False,
+     )
+
+     btn.click(
+         fn=run_asr,
+         inputs=[input_audio, target_language],
+         outputs=output_text,
+         api_name="asr",
+     )
+
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
      )

+     with gr.Tabs():
+         # with gr.Tab(label="S2ST"):
+         #     demo_s2st.render()
+         with gr.Tab(label="S2TT"):
+             demo_s2tt.render()
+         # with gr.Tab(label="T2ST"):
+         #     demo_t2st.render()
+         # with gr.Tab(label="T2TT"):
+         #     demo_t2tt.render()
+         with gr.Tab(label="ASR"):
+             demo_asr.render()
+
+
  demo.launch(share=True)
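
For context, the S2TT path this commit wires into `run_s2tt` can also be exercised outside the Space. A minimal sketch, assuming a hypothetical local file `sample.wav` and `"hin"` as the Hindi target code (the default used by the removed `translate_audio`):

```python
# Minimal sketch of the same speech-to-text translation path, outside Gradio.
# Assumptions: "sample.wav" is a hypothetical local file; "hin" is the Hindi
# target code, as used in the previous version of this app.
import torch
import torchaudio
from transformers import (
    SeamlessM4TFeatureExtractor,
    SeamlessM4TTokenizer,
    SeamlessM4Tv2ForSpeechToText,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32

model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
    "ai4bharat/indic-seamless", torch_dtype=torch_dtype
).to(device)
processor = SeamlessM4TFeatureExtractor.from_pretrained("ai4bharat/indic-seamless")
tokenizer = SeamlessM4TTokenizer.from_pretrained("ai4bharat/indic-seamless")

# Load the clip and resample to the 16 kHz rate the feature extractor expects.
waveform, orig_freq = torchaudio.load("sample.wav")
waveform = torchaudio.functional.resample(waveform, orig_freq=orig_freq, new_freq=16000)

inputs = processor(waveform, sampling_rate=16000, return_tensors="pt").to(
    device=device, dtype=torch_dtype
)
token_ids = model.generate(**inputs, tgt_lang="hin")[0]
print(tokenizer.decode(token_ids, skip_special_tokens=True))
```

The ASR tab relies on the same `generate` call, with `tgt_lang` set to the code of the language spoken in the clip.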