import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime
from huggingface_hub import hf_hub_download
import os
# --- Configuration ---
repo_id = "Athspi/Gg"
onnx_filename = "mms_tts_eng.onnx"
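# MMS-TTS (VITS-based) models generate audio at 16 kHz.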
sampling_rate = 16000
# --- Download ONNX Model ---
onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
print(f"ONNX model downloaded to (cache): {onnx_model_path}")
# --- Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# --- ONNX Runtime Session Setup ---
session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
try:
    import psutil
    # psutil.cpu_count(logical=False) may return None; fall back to 4 in that case.
    num_physical_cores = psutil.cpu_count(logical=False) or 4
except ImportError:
    print("psutil not installed. Install with: pip install psutil")
    num_physical_cores = 4
    print(f"Using default thread count: {num_physical_cores}")
session_options.intra_op_num_threads = num_physical_cores
session_options.inter_op_num_threads = 1
ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=['CPUExecutionProvider'],
    sess_options=session_options,
)
# --- IO Binding Setup ---
io_binding = ort_session.io_binding()
input_meta = ort_session.get_inputs()[0]
output_meta = ort_session.get_outputs()[0]
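# Pre-allocate fixed CPU buffers and bind them once, so ONNX Runtime reads inputs
# from and writes outputs into these tensors directly instead of allocating new
# buffers on every call. A dummy tokenization is used only to discover the input
# dtype and a starting shape.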
dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.long)
input_shape = tuple(dummy_input.shape)
input_type = dummy_input.numpy().dtype
input_tensor = torch.empty(input_shape, dtype=torch.int64, device="cpu").contiguous()
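# Rough heuristic: reserve 10 output samples per input token; the buffer is grown
# at inference time if a longer text needs more room.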
max_output_length = input_shape[1] * 10
output_shape = (1, 1, max_output_length)
output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
# Initial binding
io_binding.bind_input(
    name=input_meta.name, device_type="cpu", device_id=0,
    element_type=input_type, shape=input_shape, buffer_ptr=input_tensor.data_ptr(),
)
io_binding.bind_output(
    name=output_meta.name, device_type="cpu", device_id=0,
    element_type=np.float32, shape=output_shape, buffer_ptr=output_tensor.data_ptr(),
)
# --- Inference Function ---
def tts_inference_io_binding(text: str):
    """TTS inference with IO Binding."""
    global input_tensor, output_tensor

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids.to(torch.long)
    current_input_shape = tuple(input_ids.shape)

    # Grow the pre-allocated input buffer if this text is longer than any seen so far.
    if current_input_shape[1] > input_tensor.shape[1]:
        input_tensor = torch.empty(current_input_shape, dtype=torch.int64, device="cpu").contiguous()

    # Copy the token ids into the pre-allocated buffer, then (re)bind the input
    # with the actual shape of this request so the model only sees valid tokens.
    input_tensor[:current_input_shape[0], :current_input_shape[1]].copy_(input_ids)
    io_binding.bind_input(
        name=input_meta.name, device_type="cpu", device_id=0,
        element_type=input_type, shape=current_input_shape,
        buffer_ptr=input_tensor.data_ptr(),
    )

    # Grow the output buffer if necessary, then re-bind it: clear the previous
    # output binding first so exactly one output binding is active for the run.
    required_output_length = current_input_shape[1] * 10  # heuristic estimate
    if required_output_length > output_tensor.shape[2]:
        output_shape = (1, 1, required_output_length)
        output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
    io_binding.clear_binding_outputs()
    io_binding.bind_output(
        name=output_meta.name, device_type="cpu", device_id=0,
        element_type=np.float32, shape=tuple(output_tensor.shape),
        buffer_ptr=output_tensor.data_ptr(),
    )

    # Run inference; the result is written directly into the bound output buffer.
    ort_session.run_with_iobinding(io_binding)

    # Fetch the bound output as a NumPy array and return it with the sampling rate.
    ort_outputs = io_binding.get_outputs()
    output_data = ort_outputs[0].numpy()
    return (sampling_rate, output_data.squeeze())
# --- Gradio Interface ---
iface = gr.Interface(
    fn=tts_inference_io_binding,
    inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
    title="Optimized MMS-TTS (English)",
    description="Fast TTS with ONNX Runtime and IO Binding (Hugging Face Hub).",
    examples=[
        ["Hello, this is a demonstration."],
        ["This uses ONNX Runtime and IO Binding."],
        ["The quick brown fox jumps over the lazy dog."],
        ["Try your own text!"],
    ],
    cache_examples=False,
)
if __name__ == "__main__":
iface.launch() |