# Source: Tttt / app.py — last commit 652ba8a ("Update app.py") by Athspi, 2.03 kB.
import os

import gradio as gr
import numpy as np
import onnxruntime
import scipy.io.wavfile
import torch
from transformers import AutoTokenizer
from transformers.utils import cached_file
# Hugging Face repository id (or local model directory) that holds both the
# tokenizer files and the ONNX model file.
model_dir = "Athspi/Gg"

# Name of the ONNX model file inside `model_dir`. Adjust if needed.
onnx_model_filename = "mms_tts_eng.onnx"

# ONNX Runtime needs a real path on disk, but `model_dir` may be a Hub
# repository id rather than a directory. Use the local file when it exists;
# otherwise let transformers resolve (and, if necessary, download) the file
# into its cache and hand back the cached path.
_local_onnx_path = os.path.join(model_dir, onnx_model_filename)
if os.path.isfile(_local_onnx_path):
    onnx_model_path = _local_onnx_path
else:
    onnx_model_path = cached_file(model_dir, onnx_model_filename)

# Load the tokenizer from the same repository/directory.
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Create the ONNX Runtime session used for inference (CPU only).
ort_session = onnxruntime.InferenceSession(
    onnx_model_path, providers=["CPUExecutionProvider"]
)

# Fixed sampling rate reported alongside every synthesized waveform
# (adjust if your model uses a different rate).
sampling_rate = 16000
def tts_inference(text: str):
    """Synthesize speech for ``text`` with the module-level ONNX session.

    Args:
        text: Input text to synthesize.

    Returns:
        A ``(waveform, sampling_rate)`` tuple: the first output of the ONNX
        session with all size-1 dimensions squeezed out, followed by the
        module-level ``sampling_rate``. Note the order — waveform first, then
        rate — which is the reverse of the ``(rate, data)`` order used by
        Gradio audio outputs.
    """
    # Tokenize to PyTorch tensors, then convert the ids to a 64-bit integer
    # NumPy array, which is what gets fed to ONNX Runtime.
    encoded = tokenizer(text, return_tensors="pt")
    token_ids = encoded.input_ids.cpu().to(torch.long).numpy()

    # Only the "input_ids" input is supplied; passing None as the first
    # argument asks the session for all of its outputs.
    feed = {"input_ids": token_ids}
    outputs = ort_session.run(None, feed)

    # The first output is taken as the audio; drop its singleton dimensions.
    return np.squeeze(outputs[0]), sampling_rate
def _gradio_tts(text: str):
    """Adapt ``tts_inference`` to the tuple order a Gradio audio output expects.

    ``tts_inference`` returns ``(waveform, sampling_rate)``, whereas a
    ``gr.Audio(type="numpy")`` output takes ``(sampling_rate, waveform)``.
    Swapping here keeps the signature and return contract of
    ``tts_inference`` untouched.
    """
    waveform, rate = tts_inference(text)
    return rate, waveform


# Build a Gradio interface: one text box in, one audio player out.
iface = gr.Interface(
    fn=_gradio_tts,
    inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
    outputs=gr.Audio(type="numpy"),
    title="ONNX TTS Demo",
    description="Text-to-Speech synthesis using an ONNX model from the Athspi/Gg repository on Hugging Face.",
)

if __name__ == "__main__":
    iface.launch()