Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,8 @@ from huggingface_hub import hf_hub_download
|
|
11 |
repo_id = "Athspi/Gg"
|
12 |
|
13 |
# Download the ONNX model file from the repository.
|
14 |
-
# This will download "mms_tts_eng.onnx" from:
|
|
|
15 |
onnx_model_path = hf_hub_download(repo_id=repo_id, filename="mms_tts_eng.onnx")
|
16 |
|
17 |
# Load the tokenizer from the repository.
|
@@ -33,7 +34,7 @@ def tts_inference(text: str):
|
|
33 |
text (str): Input text to synthesize.
|
34 |
|
35 |
Returns:
|
36 |
-
waveform (np.ndarray): Synthesized audio waveform.
|
37 |
sampling_rate (int): The sampling rate of the waveform.
|
38 |
"""
|
39 |
# Tokenize the input text.
|
@@ -46,16 +47,19 @@ def tts_inference(text: str):
|
|
46 |
onnx_outputs = ort_session.run(None, {"input_ids": input_ids})
|
47 |
waveform = onnx_outputs[0]
|
48 |
|
49 |
-
# Ensure
|
|
|
|
|
|
|
|
|
50 |
waveform = waveform.astype(np.float32)
|
51 |
-
|
52 |
-
# Remove
|
53 |
waveform = np.squeeze(waveform)
|
54 |
|
55 |
-
# Return the waveform and its sampling rate.
|
56 |
return waveform, sampling_rate
|
57 |
|
58 |
-
# Build
|
59 |
iface = gr.Interface(
|
60 |
fn=tts_inference,
|
61 |
inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
|
|
|
11 |
repo_id = "Athspi/Gg"
|
12 |
|
13 |
# Download the ONNX model file from the repository.
|
14 |
+
# This will download "mms_tts_eng.onnx" from:
|
15 |
+
# https://huggingface.co/Athspi/Gg/resolve/main/mms_tts_eng.onnx
|
16 |
onnx_model_path = hf_hub_download(repo_id=repo_id, filename="mms_tts_eng.onnx")
|
17 |
|
18 |
# Load the tokenizer from the repository.
|
|
|
34 |
text (str): Input text to synthesize.
|
35 |
|
36 |
Returns:
|
37 |
+
waveform (np.ndarray): Synthesized audio waveform in float32 format.
|
38 |
sampling_rate (int): The sampling rate of the waveform.
|
39 |
"""
|
40 |
# Tokenize the input text.
|
|
|
47 |
onnx_outputs = ort_session.run(None, {"input_ids": input_ids})
|
48 |
waveform = onnx_outputs[0]
|
49 |
|
50 |
+
# Ensure the output is a NumPy array.
|
51 |
+
if not isinstance(waveform, np.ndarray):
|
52 |
+
waveform = np.array(waveform)
|
53 |
+
|
54 |
+
# Convert waveform to float32 (required by Gradio's Audio component).
|
55 |
waveform = waveform.astype(np.float32)
|
56 |
+
|
57 |
+
# Remove any extra dimensions.
|
58 |
waveform = np.squeeze(waveform)
|
59 |
|
|
|
60 |
return waveform, sampling_rate
|
61 |
|
62 |
+
# Build the Gradio interface.
|
63 |
iface = gr.Interface(
|
64 |
fn=tts_inference,
|
65 |
inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
|