import os
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime
import scipy.io.wavfile
from huggingface_hub import hf_hub_download
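
# All of the packages imported above must be available in the environment (for a
# Hugging Face Space, list gradio, torch, numpy, transformers, onnxruntime, scipy
# and huggingface_hub in requirements.txt).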

# Define the Hugging Face repository/model ID.
repo_id = "Athspi/Gg"

# Download the ONNX model file from the repository.
# This will download "mms_tts_eng.onnx" from: https://huggingface.co/Athspi/Gg/resolve/main/mms_tts_eng.onnx
onnx_model_path = hf_hub_download(repo_id=repo_id, filename="mms_tts_eng.onnx")

# Load the tokenizer from the repository.
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Initialize the ONNX Runtime session for CPU inference.
ort_session = onnxruntime.InferenceSession(
    onnx_model_path, providers=['CPUExecutionProvider']
)
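
# Optional sanity check: ONNX Runtime's InferenceSession.get_inputs() lists the
# graph's input names and shapes. If the model's input is not named "input_ids",
# adjust the feed dict in tts_inference below accordingly.
#   print([inp.name for inp in ort_session.get_inputs()])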

# Fixed output sampling rate (MMS TTS models generate 16 kHz audio; adjust if your model differs).
sampling_rate = 16000


def tts_inference(text: str):
    """
    Convert input text to a speech waveform using the ONNX model.

    Parameters:
        text (str): Input text to synthesize.

    Returns:
        tuple: (sampling_rate, waveform), the format expected by gr.Audio(type="numpy").
    """
    # Tokenize the input text.
    inputs = tokenizer(text, return_tensors="pt")

    # Prepare inputs for the ONNX model (int64 token IDs as a NumPy array).
    input_ids = inputs.input_ids.cpu().to(torch.long).numpy()

    # Run inference on the ONNX model; the first output is the waveform.
    onnx_outputs = ort_session.run(None, {"input_ids": input_ids})
    waveform = onnx_outputs[0]

    # Ensure the waveform is in float32 format (required by Gradio).
    waveform = waveform.astype(np.float32)

    # Remove unnecessary dimensions (e.g. the batch axis).
    waveform = np.squeeze(waveform)

    # gr.Audio(type="numpy") expects a (sampling_rate, waveform) tuple.
    return sampling_rate, waveform
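
# Optional quick test without launching the UI: the commented-out lines below are a
# minimal sketch that synthesizes one sentence and writes it to disk with the
# scipy.io.wavfile module imported above (the text and "sample.wav" filename are
# arbitrary examples, not part of the app itself).
#   sr, wav = tts_inference("Hello from the ONNX text-to-speech demo.")
#   scipy.io.wavfile.write("sample.wav", sr, wav)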


# Build a Gradio interface.
iface = gr.Interface(
    fn=tts_inference,
    inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
    outputs=gr.Audio(type="numpy"),
    title="ONNX TTS Demo",
    description="Text-to-Speech synthesis using an ONNX model from the Athspi/Gg repository on Hugging Face.",
)

if __name__ == "__main__":
    iface.launch()