Tttt / app.py
Athspi's picture
Update app.py
b6b39ee verified
raw
history blame
2.14 kB
import os
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime
import scipy.io.wavfile
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts both the tokenizer files and the
# exported ONNX model used by this demo.
repo_id = "Athspi/Gg"

# Fetch the exported model file from the repository and remember where it
# landed on disk. The file resolves to:
# https://huggingface.co/Athspi/Gg/resolve/main/mms_tts_eng.onnx
onnx_model_path = hf_hub_download(
    repo_id=repo_id,
    filename="mms_tts_eng.onnx",
)

# The tokenizer is loaded from the same repository as the model.
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Create the ONNX Runtime session once at import time; inference is
# restricted to the CPU execution provider.
_execution_providers = ['CPUExecutionProvider']
ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=_execution_providers,
)

# Fixed sampling rate (Hz) reported alongside every synthesized waveform.
# Adjust this if the model generates audio at a different rate.
sampling_rate = 16000
def tts_inference(text: str):
    """Synthesize speech for ``text`` with the ONNX TTS model.

    Parameters:
        text (str): Input text to synthesize.

    Returns:
        tuple: ``(sampling_rate, waveform)`` where ``sampling_rate`` is the
        module-level sampling rate (int, Hz) and ``waveform`` is a float32
        ``np.ndarray`` with singleton dimensions removed. The rate comes
        first because that is the order a ``gr.Audio(type="numpy")`` output
        component expects.
    """
    # Tokenize the input text; the tokenizer hands back PyTorch tensors.
    encoded = tokenizer(text, return_tensors="pt")

    # The ONNX session is fed NumPy arrays, so move the ids to CPU, force
    # int64, and convert.
    input_ids = encoded.input_ids.cpu().to(torch.long).numpy()

    # Run the model; the first output is taken as the generated waveform.
    onnx_outputs = ort_session.run(None, {"input_ids": input_ids})
    waveform = onnx_outputs[0]

    # Normalize the dtype to float32 and drop singleton dimensions so the
    # audio component receives a plain sample array.
    waveform = np.squeeze(waveform.astype(np.float32))

    # Gradio's numpy audio format is (sample_rate, data). The previous
    # (data, sample_rate) order made the output component misread the array
    # as the rate and the int as the samples.
    return sampling_rate, waveform
# Assemble the Gradio UI: a two-line text box feeds tts_inference and the
# result is rendered by a numpy-typed audio component.
_text_input = gr.Textbox(lines=2, placeholder="Enter text here...")
_audio_output = gr.Audio(type="numpy")

iface = gr.Interface(
    fn=tts_inference,
    inputs=_text_input,
    outputs=_audio_output,
    title="ONNX TTS Demo",
    description="Text-to-Speech synthesis using an ONNX model from the Athspi/Gg repository on Hugging Face.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()