leenag committed on
Commit 9ce846a · verified · 1 Parent(s): 99f9cdf

Update app.py

Files changed (1)
  1. app.py +20 -20
app.py CHANGED
@@ -1,8 +1,8 @@
 import gradio as gr
-from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
-from transformers import VitsModel, AutoTokenizer
 import torch
-import torchaudio
+from transformers import VitsModel, AutoTokenizer
+import soundfile as sf
+import tempfile
 
 LANG_MODEL_MAP = {
     "English": "facebook/mms-tts-eng",
@@ -15,35 +15,35 @@ LANG_MODEL_MAP = {
 device = "cuda" if torch.cuda.is_available() else "cpu"
 cache = {}
 
-def load_model_and_processor(language):
+def load_model_and_tokenizer(language):
     model_name = LANG_MODEL_MAP[language]
     if model_name not in cache:
-        processor = AutoProcessor.from_pretrained(model_name)
-        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name).to(device)
-        cache[model_name] = (processor, model)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = VitsModel.from_pretrained(model_name).to(device)
+        cache[model_name] = (tokenizer, model)
     return cache[model_name]
 
-def synthesize(language, text):
-    processor, model = load_model_and_processor(language)
-
-    inputs = processor(text=text, return_tensors="pt").to(device)
+def tts(language, text):
+    tokenizer, model = load_model_and_tokenizer(language)
+    inputs = tokenizer(text, return_tensors="pt").to(device)
+
     with torch.no_grad():
-        generated_ids = model.generate(**inputs)
-        audio = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        output = model(**inputs)
 
-    # Decode and return waveform
-    waveform, sr = torchaudio.load(audio)
-    return sr, waveform.squeeze().numpy()
+    # Save waveform to temp file
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+        sf.write(f.name, output.waveform.cpu().numpy(), samplerate=16000)
+        return f.name
 
 iface = gr.Interface(
-    fn=synthesize,
+    fn=tts,
     inputs=[
         gr.Dropdown(choices=list(LANG_MODEL_MAP.keys()), label="Select Language"),
         gr.Textbox(label="Enter Text", placeholder="Type something...")
     ],
-    outputs=gr.Audio(label="Synthesized Speech", type="numpy"),
-    title="Multilingual TTS - MMS Facebook",
-    description="A Gradio demo for multilingual TTS using Meta's MMS models. Supports English, Hindi, Tamil, Malayalam, and Kannada."
+    outputs=gr.Audio(type="filepath", label="Synthesized Audio"),
+    title="Multilingual Text-to-Speech (MMS)",
+    description="Generate speech in English, Hindi, Tamil, Malayalam, or Kannada using Meta's MMS TTS models."
 )
 
 if __name__ == "__main__":
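
A note on the new write path: VitsModel returns output.waveform with a leading batch dimension (shape [1, num_samples]), while soundfile treats a 2-D array as (frames, channels), so the committed sf.write(...) call may produce an unplayable WAV. Below is a minimal sketch, not the committed code, of a more defensive tts(); it assumes the load_model_and_tokenizer() cache, device, and LANG_MODEL_MAP from the diff above, and reads the sample rate from the model config instead of hard-coding 16000.

import tempfile

import soundfile as sf
import torch


def tts(language, text):
    # Assumes load_model_and_tokenizer() and device from the diff above.
    tokenizer, model = load_model_and_tokenizer(language)
    inputs = tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model(**inputs)

    # Drop the batch dimension: VitsModel yields shape [1, num_samples],
    # but soundfile expects 1-D data (or [frames, channels]).
    waveform = output.waveform.squeeze().cpu().numpy()
    sample_rate = model.config.sampling_rate  # 16000 for the MMS checkpoints

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        sf.write(f.name, waveform, samplerate=sample_rate)
    return f.name

Returning the path after the with block also guarantees the temp file is flushed and closed before gr.Audio(type="filepath") tries to read it.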