leenag committed (verified)
Commit c573494 · Parent: 6716bc1

Update app.py

Files changed (1): app.py (+41 -87)

app.py CHANGED
@@ -1,95 +1,49 @@
-import torch
-import soundfile as sf
-import uuid
 import gradio as gr
-import numpy as np
-import re
-from parler_tts import ParlerTTSForConditionalGeneration
-from transformers import AutoTokenizer
-
-# Load model and tokenizers
-model_name = "ai4bharat/indic-parler-tts"
-device = "cpu"
-
-print("Loading model...")
-model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device).eval()
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-desc_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
-
-print("Applying dynamic quantization...")
-quantized_model = torch.quantization.quantize_dynamic(
-    model,
-    {torch.nn.Linear},
-    dtype=torch.qint8
-)
-
-# Sentence splitter
-def split_text(text, max_len=150):
-    chunks = re.split(r'(?<=[.!?]) +', text)
-    refined = []
-    for chunk in chunks:
-        if len(chunk) <= max_len:
-            refined.append(chunk)
-        else:
-            words = chunk.split()
-            temp = []
-            buf_len = 0
-            for word in words:
-                temp.append(word)
-                buf_len += len(word) + 1
-                if buf_len > max_len:
-                    refined.append(' '.join(temp))
-                    temp = []
-                    buf_len = 0
-            if temp:
-                refined.append(' '.join(temp))
-    return refined
-
-# Core TTS function
-def synthesize(language, text, gender, emotion, speed):
-    description = (
-        f"A native {language.lower()} female speaker with an expressive tone."
-    )
-
-    audio_chunks = []
-    text_chunks = split_text(text)
-
-    for chunk in text_chunks:
-        # New tokenization for each chunk
-        desc_input = desc_tokenizer(description, return_tensors="pt").to(device)
-        prompt_input = tokenizer(chunk, return_tensors="pt").to(device)
-
-        with torch.no_grad():
-            output = quantized_model.generate(
-                input_ids=desc_input.input_ids,
-                attention_mask=desc_input.attention_mask,
-                prompt_input_ids=prompt_input.input_ids,
-                prompt_attention_mask=torch.ones_like(prompt_input.input_ids).to(device)
-            )
-
-        audio = output.cpu().numpy().squeeze()
-        audio_chunks.append(audio)
-
-    full_audio = np.concatenate(audio_chunks)
-    filename = f"{uuid.uuid4().hex}.wav"
-    sf.write(filename, full_audio, quantized_model.config.sampling_rate)
-    return filename
 
-# Gradio UI
 iface = gr.Interface(
     fn=synthesize,
     inputs=[
-        gr.Dropdown(["Malayalam", "Hindi", "Tamil", "English", "Kannada"], label="Language"),
-        gr.Textbox(label="Text to Synthesize", lines=6, placeholder="Enter your sentence here..."),
-        # gr.Radio(["Male", "Female"], label="Speaker Gender"),
-        # gr.Dropdown(["Neutral", "Happy", "Sad", "Angry"], label="Emotion"),
-        # gr.Dropdown(["Slow", "Moderate", "Fast"], label="Speaking Rate"),
-        #gr.Dropdown(["Low", "Normal", "High"], label="Pitch"),
-        #gr.Dropdown(["Basic", "Refined"], label="Voice Quality"),
     ],
-    outputs=gr.Audio(type="filepath", label="Synthesized Speech"),
-    title="Multilingual Indic TTS (Quantized + Chunked)",
-    description="CPU-based TTS with quantized Parler-TTS and chunked input for Malayalam, Hindi, Tamil, and English.",
 )
 
-iface.launch()
 import gradio as gr
+from transformers import VitsModel, AutoTokenizer
+import torch
+
+LANG_MODEL_MAP = {
+    "English": "facebook/mms-tts-eng",
+    "Hindi": "facebook/mms-tts-hin",
+    "Tamil": "facebook/mms-tts-tam",
+    "Malayalam": "facebook/mms-tts-mal",
+    "Kannada": "facebook/mms-tts-kan"
+}
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+cache = {}
+
+def load_model_and_tokenizer(language):
+    # Load each language's model once, then serve it from the in-memory cache.
+    model_name = LANG_MODEL_MAP[language]
+    if model_name not in cache:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = VitsModel.from_pretrained(model_name).to(device).eval()
+        cache[model_name] = (tokenizer, model)
+    return cache[model_name]
+
+def synthesize(language, text):
+    tokenizer, model = load_model_and_tokenizer(language)
+
+    inputs = tokenizer(text, return_tensors="pt").to(device)
+    # The MMS TTS checkpoints are VITS models: a single forward pass
+    # returns the waveform directly, with no generate/decode step.
+    with torch.no_grad():
+        waveform = model(**inputs).waveform
+
+    # gr.Audio(type="numpy") expects a (sampling_rate, ndarray) tuple.
+    return model.config.sampling_rate, waveform.squeeze().cpu().numpy()
 
 iface = gr.Interface(
     fn=synthesize,
     inputs=[
+        gr.Dropdown(choices=list(LANG_MODEL_MAP.keys()), label="Select Language"),
+        gr.Textbox(label="Enter Text", placeholder="Type something...")
     ],
+    outputs=gr.Audio(label="Synthesized Speech", type="numpy"),
+    title="Multilingual TTS - MMS Facebook",
+    description="A Gradio demo for multilingual TTS using Meta's MMS models. Supports English, Hindi, Tamil, Malayalam, and Kannada."
 )
 
+if __name__ == "__main__":
+    iface.launch()
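
For reference, a minimal smoke test of the new MMS pipeline, run outside Gradio. It assumes app.py exposes synthesize as defined in this commit; the file name smoke_test.py, the sample sentence, and the use of scipy.io.wavfile for saving are illustrative choices, not part of the commit.

# smoke_test.py - illustrative check of the synthesize() function above.
import scipy.io.wavfile as wavfile

from app import synthesize

# synthesize returns a (sampling_rate, float32 numpy waveform) tuple.
sr, audio = synthesize("English", "Hello from the MMS demo.")
wavfile.write("sample.wav", sr, audio)  # write the waveform to a WAV file
print(f"Wrote sample.wav: {audio.shape[0]} samples at {sr} Hz")

Because iface.launch() is now guarded by __main__, importing app for a test like this builds the interface but no longer starts the server.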