# main.py
from __future__ import annotations
import os
import io
import torch
import numpy as np
import torchaudio
import nltk
import gradio as gr
from pydub import AudioSegment
from transformers import (
    SeamlessM4TFeatureExtractor,
    SeamlessM4TTokenizer,
    SeamlessM4Tv2ForSpeechToText,
    AutoTokenizer,
)
from parler_tts import ParlerTTSForConditionalGeneration
# sent_tokenize needs "punkt"; newer NLTK releases also require "punkt_tab".
nltk.download("punkt")
nltk.download("punkt_tab")
# === CONFIG ===
HF_TOKEN = os.getenv("HF_TOKEN")
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
SAMPLE_RATE = 16000
DEFAULT_TARGET_LANGUAGE = "Hindi"
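# Name -> ISO 639-3 code map for the translation model's tgt_lang argument.
# A minimal sketch assuming SeamlessM4T-style language codes; extend it to
# cover the full language list supported by ai4bharat/indic-seamless.
LANG_CODES = {
    "Hindi": "hin",
    "Bengali": "ben",
    "Tamil": "tam",
    "Telugu": "tel",
    "Marathi": "mar",
    "Gujarati": "guj",
    "Kannada": "kan",
    "Malayalam": "mal",
    "Punjabi": "pan",
    "Urdu": "urd",
    "English": "eng",
}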
# === Load translation model ===
trans_model = SeamlessM4Tv2ForSpeechToText.from_pretrained(
    "ai4bharat/indic-seamless", torch_dtype=torch_dtype, token=HF_TOKEN
).to(device)
processor = SeamlessM4TFeatureExtractor.from_pretrained("ai4bharat/indic-seamless", token=HF_TOKEN)
tokenizer = SeamlessM4TTokenizer.from_pretrained("ai4bharat/indic-seamless", token=HF_TOKEN)
# === Load TTS models ===
tts_repo = "ai4bharat/indic-parler-tts-pretrained"
tts_finetuned_repo = "ai4bharat/indic-parler-tts"
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
    tts_repo, attn_implementation="eager", torch_dtype=torch_dtype
).to(device)
tts_finetuned_model = ParlerTTSForConditionalGeneration.from_pretrained(
    tts_finetuned_repo, attn_implementation="eager", torch_dtype=torch_dtype
).to(device)
desc_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
text_tokenizer = AutoTokenizer.from_pretrained(tts_repo)
tts_sampling_rate = tts_model.audio_encoder.config.sampling_rate
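# (The rate above is read from the model's audio codec config rather than
# hard-coded; for Parler-style models it is typically 44.1 kHz.)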
# === Utilities ===
def numpy_to_mp3(audio_array, sampling_rate):
    """Convert a mono waveform to 16-bit PCM and encode it as MP3 bytes."""
    if np.issubdtype(audio_array.dtype, np.floating):
        peak = float(np.max(np.abs(audio_array))) or 1.0  # guard against all-silence input
        audio_array = (audio_array / peak) * 32767
    audio_array = audio_array.astype(np.int16)
    segment = AudioSegment(
        audio_array.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_array.dtype.itemsize,
        channels=1,
    )
    mp3_io = io.BytesIO()
    segment.export(mp3_io, format="mp3", bitrate="320k")
    return mp3_io.getvalue()
def chunk_text(text, max_words=25):
    """Greedily pack whole sentences into chunks of at most ~max_words words."""
    sentences = nltk.sent_tokenize(text)
    chunks, curr = [], ""
    for s in sentences:
        candidate = f"{curr} {s}".strip()
        if len(candidate.split()) > max_words:
            if curr:
                chunks.append(curr)
            curr = s
        else:
            curr = candidate
    if curr:
        chunks.append(curr)
    return chunks
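# Example: with max_words=5, "Hello there. How are you today? Fine." becomes
# ["Hello there.", "How are you today? Fine."]; sentences are never split, so
# a single very long sentence can still exceed the word budget.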
# === Translation ===
def translate_audio(input_audio, target_language):
    """Transcribe-and-translate speech into text in the requested language."""
    if input_audio is None:
        raise gr.Error("Please upload or record audio first.")
    audio, orig_sr = torchaudio.load(input_audio)
    audio = torchaudio.functional.resample(audio, orig_sr, SAMPLE_RATE)
    audio = audio.mean(dim=0)  # downmix to mono; the feature extractor rejects multi-channel input
    inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt").to(device, dtype=torch_dtype)
    target_lang_code = LANG_CODES.get(target_language.strip().title(), "hin")  # fall back to Hindi
    gen_ids = trans_model.generate(**inputs, tgt_lang=target_lang_code)[0]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)
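# Illustrative call (hypothetical file): translate_audio("clip.wav", "Tamil")
# returns the Tamil text of whatever was spoken in clip.wav.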
# === TTS generation ===
def generate_tts(text, description, use_finetuned=False):
    """Synthesize speech for `text`, chunk by chunk, in the voice `description`."""
    model = tts_finetuned_model if use_finetuned else tts_model
    inputs = desc_tokenizer(description, return_tensors="pt").to(device)
    chunks = chunk_text(text)
    all_audio = []
    for chunk in chunks:
        prompt = text_tokenizer(chunk, return_tensors="pt").to(device)
        gen = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            prompt_input_ids=prompt.input_ids,
            prompt_attention_mask=prompt.attention_mask,
            do_sample=True,
            return_dict_in_generate=True,
        )
        if hasattr(gen, "sequences") and hasattr(gen, "audios_length"):
            # Trim the padded sequence to the true audio length for this sample.
            audio = gen.sequences[0, :gen.audios_length[0]]
            all_audio.append(audio.float().cpu().numpy().flatten())
    if not all_audio:
        raise gr.Error("TTS generation produced no audio.")
    combined = np.concatenate(all_audio)
    return numpy_to_mp3(combined, sampling_rate=tts_sampling_rate)
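# Note: do_sample=True makes the output nondeterministic; call
# torch.manual_seed() first if you need reproducible audio across runs.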
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Speech-to-Text → Text-to-Speech Demo")
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Upload or record audio", type="filepath")
            target_language = gr.Textbox(label="Target language (default Hindi)", value="Hindi")
            btn_translate = gr.Button("Translate to text")
        with gr.Column():
            translated_text = gr.Textbox(label="Translated text")
    btn_translate.click(
        translate_audio,
        inputs=[input_audio, target_language],
        outputs=translated_text,
    )
    with gr.Row():
        with gr.Column():
            voice_desc = gr.Textbox(label="Voice description", value="A calm, neutral Indian voice, clear audio.")
            use_finetuned = gr.Checkbox(label="Use fine-tuned TTS", value=True)
            btn_tts = gr.Button("Generate speech")
        with gr.Column():
            generated_audio = gr.Audio(label="Generated speech", format="mp3", autoplay=True)
    btn_tts.click(
        generate_tts,
        inputs=[translated_text, voice_desc, use_finetuned],
        outputs=generated_audio,
    )
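
# share=True requests a temporary public gradio.live URL; when the app runs on
# Hugging Face Spaces the flag is unnecessary and Gradio ignores it.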
demo.launch(share=True)