import torch
import soundfile as sf
import uuid
import gradio as gr
import numpy as np
import re
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
# Load model and tokenizers
model_name = "ai4bharat/indic-parler-tts"
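# "ai4bharat/indic-parler-tts" is a multilingual Parler-TTS checkpoint covering English and several Indic languages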
device = "cpu"
print("Loading model...")
model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
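# The voice description is tokenized with the text encoder's own tokenizer,
# while the text to be spoken uses the main tokenizer loaded above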
desc_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
print("Applying dynamic quantization...")
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8
)
# Sentence splitter (splits by full stop, exclamation, or question mark)
def split_text(text, max_len=150):
    # First, try to split on sentence punctuation
    chunks = re.split(r'(?<=[.!?]) +', text)
    # If any chunk is still too long, break it on spaces so no piece exceeds max_len
    refined_chunks = []
    for chunk in chunks:
        if len(chunk) <= max_len:
            refined_chunks.append(chunk)
        else:
            words = chunk.split()
            buffer = []
            length = 0
            for word in words:
                # Flush the buffer before adding a word that would overflow max_len
                if buffer and length + len(word) + 1 > max_len:
                    refined_chunks.append(' '.join(buffer))
                    buffer = []
                    length = 0
                buffer.append(word)
                length += len(word) + 1
            if buffer:
                refined_chunks.append(' '.join(buffer))
    return refined_chunks
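# For example, split_text("Hello there. How are you?") returns
# ["Hello there.", "How are you?"]; longer sentences are further broken on spaces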
# Main synthesis function
def synthesize(language, text, gender, emotion, speed, pitch, quality):
    # Build the natural-language voice description that conditions the model
    description = (
        f"A native {language.lower()} {gender.lower()} speaker with a {emotion.lower()} and expressive tone, "
        f"speaking at a {speed.lower()} rate with a {pitch.lower()} pitch."
    )
    # Note: the 'quality' selection is collected by the UI but not yet used in the description
    description_input = desc_tokenizer(description, return_tensors="pt").to(device)
    chunks = split_text(text)
    audio_pieces = []
    for chunk in chunks:
        prompt_input = tokenizer(chunk, return_tensors="pt").to(device)
        with torch.no_grad():
            generation = quantized_model.generate(
                input_ids=description_input.input_ids,
                attention_mask=description_input.attention_mask,
                prompt_input_ids=prompt_input.input_ids,
                prompt_attention_mask=prompt_input.attention_mask,
            )
        audio_chunk = generation.cpu().numpy().squeeze()
        audio_pieces.append(audio_chunk)
    # Concatenate all audio chunks and write a single WAV file
    final_audio = np.concatenate(audio_pieces)
    filename = f"{uuid.uuid4().hex}.wav"
    sf.write(filename, final_audio, quantized_model.config.sampling_rate)
    return filename  # filepath consumed by the gr.Audio output below
# Gradio Interface
iface = gr.Interface(
    fn=synthesize,
    inputs=[
        gr.Dropdown(["Malayalam", "Hindi", "Tamil", "English"], label="Language"),
        gr.Textbox(label="Text to Synthesize", lines=6, placeholder="Enter your sentence here..."),
        gr.Radio(["Male", "Female"], label="Speaker Gender"),
        gr.Dropdown(["Neutral", "Happy", "Sad", "Angry"], label="Emotion"),
        gr.Dropdown(["Slow", "Moderate", "Fast"], label="Speaking Rate"),
        gr.Dropdown(["Low", "Normal", "High"], label="Pitch"),
        gr.Dropdown(["Basic", "Refined"], label="Voice Quality"),
    ],
    outputs=gr.Audio(type="filepath", label="Synthesized Speech"),
    title="Multilingual Indic TTS (Quantized + Chunked)",
    description="Fast CPU-based TTS with quantized Parler-TTS and text chunking for Malayalam, Hindi, Tamil, and English.",
)
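# launch() starts the Gradio server; on Hugging Face Spaces this serves as the app entry point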
iface.launch()