# app.py
from __future__ import annotations
import os
import io
import torch
import numpy as np
import torchaudio
import nltk
import gradio as gr
from pydub import AudioSegment
from transformers import (
    SeamlessM4TFeatureExtractor,
    SeamlessM4TTokenizer,
    SeamlessM4Tv2ForSpeechToText,
    AutoTokenizer,
)
from parler_tts import ParlerTTSForConditionalGeneration
# sent_tokenize needs the punkt models; recent NLTK releases also look for "punkt_tab"
nltk.download("punkt")
nltk.download("punkt_tab")
# === CONFIG ===
HF_TOKEN = os.getenv("HF_TOKEN")
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
SAMPLE_RATE = 16000
DEFAULT_TARGET_LANGUAGE = "Hindi"
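
# Map UI language names to SeamlessM4T target-language codes (ISO 639-3).
# A minimal sketch assuming a common subset -- consult the ai4bharat/indic-seamless
# model card for the authoritative list of supported languages.
LANG_CODES = {
    "Hindi": "hin",
    "Bengali": "ben",
    "Tamil": "tam",
    "Telugu": "tel",
    "Marathi": "mar",
    "Gujarati": "guj",
    "Kannada": "kan",
    "Malayalam": "mal",
    "Punjabi": "pan",
    "Urdu": "urd",
    "English": "eng",
}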
# === Load translation model ===
trans_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
    "ai4bharat/indic-seamless", torch_dtype=torch_dtype, token=HF_TOKEN
).to(device)
processor = SeamlessM4TFeatureExtractor.from_pretrained("ai4bharat/indic-seamless", token=HF_TOKEN)
tokenizer = SeamlessM4TTokenizer.from_pretrained("ai4bharat/indic-seamless", token=HF_TOKEN)
# === Load TTS models ===
tts_repo = "ai4bharat/indic-parler-tts-pretrained"
tts_finetuned_repo = "ai4bharat/indic-parler-tts"
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
    tts_repo, attn_implementation="eager", torch_dtype=torch_dtype
).to(device)
tts_finetuned_model = ParlerTTSForConditionalGeneration.from_pretrained(
    tts_finetuned_repo, attn_implementation="eager", torch_dtype=torch_dtype
).to(device)
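# Parler-TTS conditions on two token streams: the voice *description* is encoded
# with the Flan-T5 tokenizer (matching the checkpoint's text encoder), while the
# *prompt* text to be spoken goes through the TTS repo's own tokenizer.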
desc_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
text_tokenizer = AutoTokenizer.from_pretrained(tts_repo)
tts_sampling_rate = tts_model.audio_encoder.config.sampling_rate
# === Utilities ===
def numpy_to_mp3(audio_array, sampling_rate):
    # Peak-normalize floating-point audio into the int16 range; integer input
    # is assumed to be scaled already.
    if np.issubdtype(audio_array.dtype, np.floating):
        peak = np.max(np.abs(audio_array))
        if peak > 0:  # guard against divide-by-zero on a silent buffer
            audio_array = audio_array / peak
        audio_array = audio_array * 32767
    audio_array = audio_array.astype(np.int16)
    segment = AudioSegment(
        audio_array.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_array.dtype.itemsize,  # 2 bytes for int16
        channels=1,
    )
    mp3_io = io.BytesIO()
    segment.export(mp3_io, format="mp3", bitrate="320k")
    return mp3_io.getvalue()
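# Quick sanity check (illustrative): a one-second 440 Hz tone should round-trip
# to non-empty MP3 bytes. Note that pydub's MP3 export shells out to ffmpeg,
# which must be on PATH.
# tone = np.sin(2 * np.pi * 440 * np.arange(SAMPLE_RATE) / SAMPLE_RATE)
# assert len(numpy_to_mp3(tone, SAMPLE_RATE)) > 0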
def chunk_text(text, max_words=25):
    # Greedily pack whole sentences into chunks of at most max_words words,
    # so each TTS call stays within a comfortable prompt length.
    sentences = nltk.sent_tokenize(text)
    chunks, curr = [], ""
    for s in sentences:
        candidate = f"{curr} {s}".strip()
        if len(candidate.split()) > max_words:
            if curr:
                chunks.append(curr)
            curr = s
        else:
            curr = candidate
    if curr:
        chunks.append(curr)
    return chunks
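# Example (illustrative):
# chunk_text("Hello there. How are you. I am fine.", max_words=5)
# -> ["Hello there. How are you.", "I am fine."]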
# === Translation ===
def translate_audio(input_audio, target_language):
    # Load, downmix to mono, and resample to the 16 kHz the model expects.
    audio, orig_sr = torchaudio.load(input_audio)
    audio = torchaudio.functional.resample(audio, orig_sr, SAMPLE_RATE)
    audio = audio.mean(dim=0)  # collapse multi-channel input to mono
    inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt").to(device, dtype=torch_dtype)
    # Honour the UI selection instead of hard-coding Hindi; fall back to "hin".
    target_lang_code = LANG_CODES.get(target_language, "hin")
    gen_ids = trans_model.generate(**inputs, tgt_lang=target_lang_code)[0]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)
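# Illustrative call (assumes a speech recording on disk):
# translate_audio("sample.wav", "Tamil")  # -> Tamil-script text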
# === TTS generation ===
def generate_tts(text, description, use_finetuned=False):
    model = tts_finetuned_model if use_finetuned else tts_model
    inputs = desc_tokenizer(description, return_tensors="pt").to(device)
    chunks = chunk_text(text)
    all_audio = []
    for chunk in chunks:
        prompt = text_tokenizer(chunk, return_tensors="pt").to(device)
        gen = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            prompt_input_ids=prompt.input_ids,
            prompt_attention_mask=prompt.attention_mask,
            do_sample=True,
            return_dict_in_generate=True,
        )
        # Trim the generated waveform to the valid length reported by the model.
        if hasattr(gen, "sequences") and hasattr(gen, "audios_length"):
            audio = gen.sequences[0, :gen.audios_length[0]]
            audio_np = audio.float().cpu().numpy().flatten()
            all_audio.append(audio_np)
    if not all_audio:
        raise gr.Error("TTS produced no audio for this input.")
    combined = np.concatenate(all_audio)
    return numpy_to_mp3(combined, sampling_rate=tts_sampling_rate)
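# Illustrative call: the returned MP3 bytes feed the gr.Audio output below,
# which is configured with format="mp3". The description steers voice and style;
# the text is what gets spoken.
# mp3_bytes = generate_tts("नमस्ते, आप कैसे हैं?", "A calm female voice, clear audio.")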
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Speech-to-Text → Text-to-Speech Demo")
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Upload or record audio", type="filepath")
            target_language = gr.Dropdown(
                label="Target language",
                choices=sorted(LANG_CODES),
                value=DEFAULT_TARGET_LANGUAGE,
            )
            btn_translate = gr.Button("Translate to text")
        with gr.Column():
            translated_text = gr.Textbox(label="Translated text")
    btn_translate.click(
        translate_audio,
        inputs=[input_audio, target_language],
        outputs=translated_text,
    )
    with gr.Row():
        with gr.Column():
            voice_desc = gr.Textbox(
                label="Voice description",
                value="A calm, neutral Indian voice, clear audio.",
            )
            use_finetuned = gr.Checkbox(label="Use fine-tuned TTS", value=True)
            btn_tts = gr.Button("Generate speech")
        with gr.Column():
            generated_audio = gr.Audio(label="Generated speech", format="mp3", autoplay=True)
    btn_tts.click(
        generate_tts,
        inputs=[translated_text, voice_desc, use_finetuned],
        outputs=generated_audio,
    )

if __name__ == "__main__":
    demo.launch(share=True)