FarmerlineML committed on
Commit
5a8d3bb
·
verified ·
1 Parent(s): e0a1a99

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import scipy.io.wavfile as wavfile
7
+ from pydub import AudioSegment
8
+ from transformers import VitsModel, AutoTokenizer
9
+
10
# ---------- Configuration --------------------------------------------------
# Registry of selectable TTS voices: display name -> Hugging Face repo ids for
# the tokenizer and the model checkpoint. Insertion order is the order in
# which the names are listed wherever this mapping is iterated.
TTS_MODELS = {
    "Swahili": dict(
        tokenizer="FarmerlineML/swahili-tts-2025",
        checkpoint="FarmerlineML/Swahili-tts-2025_part4",
    ),
    "Krio": dict(
        tokenizer="facebook/mms-tts-kri",
        checkpoint="facebook/mms-tts-kri",
    ),
}

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
24
+
25
# ---------- Load all models & tokenizers -----------------------------------
# Inference attributes overridden on every loaded model (the author's
# "clear-speech" settings; tweak per model if desired).
_INFERENCE_OVERRIDES = (
    ("noise_scale", 0.7),
    ("noise_scale_duration", 0.667),
    ("speaking_rate", 0.75),
)

# Both dicts are keyed by the display names used in TTS_MODELS.
models = {}
tokenizers = {}
for name, paths in TTS_MODELS.items():
    print(f"Loading {name} model...")
    model = VitsModel.from_pretrained(paths["checkpoint"]).to(device)
    model.eval()
    for _attr, _value in _INFERENCE_OVERRIDES:
        setattr(model, _attr, _value)
    models[name] = model
    tokenizers[name] = AutoTokenizer.from_pretrained(paths["tokenizer"])
38
+
39
# ---------- Utility: WAV ➔ MP3 Conversion -----------------------------------
def _wav_to_mp3(wave_np: np.ndarray, sr: int) -> str:
    """Encode a waveform as a 64 kbit/s MP3 temp file and return its path.

    Args:
        wave_np: Audio samples. ``int16`` data is written unchanged; any
            other dtype is treated as floating-point audio with a nominal
            range of ``[-1.0, 1.0]`` and is clipped to that range before
            being scaled to ``int16``.
        sr: Sampling rate in Hz.

    Returns:
        Path of the generated ``.mp3`` file. The file is not deleted here;
        cleanup is left to the caller.
    """
    if wave_np.dtype != np.int16:
        # Clip first: out-of-range floats do not saturate when cast to int16,
        # so an overshooting sample would otherwise turn into a loud glitch.
        wave_np = (np.clip(wave_np, -1.0, 1.0) * 32767).astype(np.int16)

    # Only reserve a unique name here and close the handle right away;
    # reopening a still-open NamedTemporaryFile by name is not portable.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
        wav_path = tf.name

    # Swap only the extension (str.replace would touch every ".wav" in the path).
    mp3_path = os.path.splitext(wav_path)[0] + ".mp3"
    try:
        wavfile.write(wav_path, sr, wave_np)
        AudioSegment.from_wav(wav_path).export(mp3_path, format="mp3", bitrate="64k")
    finally:
        # The intermediate WAV is removed even when writing or encoding fails.
        os.remove(wav_path)
    return mp3_path
54
+
55
# ---------- TTS Generation ---------------------------------------------------
def tts_generate(model_name: str, text: str):
    """Generate speech for `text` using the selected model.

    Args:
        model_name: Key into the module-level ``models`` / ``tokenizers``
            dicts.
        text: Text to synthesize.

    Returns:
        Path of the MP3 file produced by ``_wav_to_mp3``, or ``None`` when
        ``text`` is empty or contains only whitespace.

    Raises:
        gr.Error: If ``model_name`` is not one of the loaded models (for
            example when nothing is selected in the dropdown).
    """
    # Whitespace-only input carries nothing to speak; treat it like empty
    # input instead of passing it on to the tokenizer and model.
    if not text or not text.strip():
        return None
    if model_name not in models:
        # Show a readable message in the UI rather than a bare KeyError.
        raise gr.Error(f"Unknown TTS model: {model_name!r}")

    model = models[model_name]
    tokenizer = tokenizers[model_name]
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        # First item of the batch, moved to host memory for encoding.
        wave = model(**inputs).waveform[0].cpu().numpy()
    return _wav_to_mp3(wave, model.config.sampling_rate)
66
+
67
# ---------- Gradio Interface ------------------------------------------------
# Each example row is [model name, input text], matching the order of `inputs`.
examples = [
    ["Swahili", "zao kusaidia kuondoa umaskini na kujenga kampeni za mwamko wa virusi vya ukimwi amezitembelea"],
    ["Swahili", "Kidole hiki ni tofauti na vidole vingine kwa sababu mwelekeo wake ni wa pekee."],
    ["Swahili", "Tafadhali hakikisha umefunga mlango kabla ya kuondoka."],
    ["Krio", "Wetin na yu nem?"],
    ["Krio", "Usai yu kɔmɔt?"],
    ["Krio", "A gladi fɔ mit yu."],
]

demo = gr.Interface(
    fn=tts_generate,
    inputs=[
        # The initial selection is passed as `value`; `default` is not a
        # documented Dropdown parameter.
        gr.Dropdown(choices=list(TTS_MODELS), value="Swahili", label="Choose TTS Model"),
        gr.Textbox(lines=3, placeholder="Enter text here", label="Input Text"),
    ],
    outputs=gr.Audio(type="filepath", label="Audio", autoplay=True),
    title="Multi‐Model Text-to-Speech",
    description=(
        "Select a TTS model from the dropdown and enter text to generate speech."
    ),
    examples=examples,
    cache_examples=True,
)

if __name__ == "__main__":
    demo.launch()