Tg / app.py
Athspi's picture
Update app.py
b805946 verified
raw
history blame
4.95 kB
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime
from huggingface_hub import hf_hub_download
import os # Import the 'os' module
# --- Configuration ---
# Hub repository that hosts both the ONNX export and the tokenizer files.
repo_id = "Athspi/Gg"
# File name of the ONNX model inside that repository.
onnx_filename = "mms_tts_eng.onnx"
# Sample rate (Hz) returned alongside the waveform by the inference function.
# NOTE(review): assumed to match the model's native output rate -- confirm.
sampling_rate = 16000
# --- Download ONNX Model (and handle location) ---
# Option 1 (active, recommended): let huggingface_hub manage its default cache
# and use the path it returns.
onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
print(f"ONNX model downloaded to (cache): {onnx_model_path}")
# Option 2 (disabled): keep the download under the current working directory.
# NOTE(review): `cache_dir` relocates the Hub cache root, so the returned path
# still sits inside the cache folder layout; check the hf_hub_download docs
# (`local_dir`) if a flat file in the directory is what is wanted.
# output_dir = "." # Current directory
# onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename, cache_dir=output_dir)
# print(f"ONNX model downloaded to: {onnx_model_path}")
# Option 3 (disabled): same as option 2 but under a dedicated directory that
# is created on demand.
# output_dir = "models" # Or any directory you want
# os.makedirs(output_dir, exist_ok=True) # Create directory if it doesn't exist
# onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename, cache_dir=output_dir)
# print(f"ONNX model downloaded to: {onnx_model_path}")
# --- Load Tokenizer ---
# The tokenizer is loaded from the same repository as the ONNX model.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# --- ONNX Runtime Session Setup ---
# Build a CPU-only inference session with all graph optimizations enabled.
session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

# Size the intra-op thread pool from the physical core count when psutil can
# provide it. Only the import is guarded so unrelated errors are not swallowed.
try:
    import psutil
except ImportError:
    print("psutil not installed. Install with: pip install psutil")
    num_physical_cores = None
else:
    # psutil documents that this returns None when the count is undetermined.
    num_physical_cores = psutil.cpu_count(logical=False)

# Fall back to a fixed default both when psutil is missing and when it could
# not determine the count; assigning None to intra_op_num_threads would raise.
if not num_physical_cores:
    num_physical_cores = 4
    print(f"Using default: {num_physical_cores}")

session_options.intra_op_num_threads = num_physical_cores
# Operators run one at a time; parallelism comes from the intra-op pool only.
session_options.inter_op_num_threads = 1

ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"],
    sess_options=session_options,
)
# --- IO Binding Setup ---
# Create the session's IO binding and register pre-allocated CPU buffers for
# the model's first input and first output.
io_binding = ort_session.io_binding()
input_meta = ort_session.get_inputs()[0]
output_meta = ort_session.get_outputs()[0]

# A one-character prompt is tokenized only to obtain an initial shape and
# numpy dtype for the input buffer.
dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.int64)
input_shape = tuple(dummy_input.shape)
input_type = dummy_input.numpy().dtype

# Input side: uninitialized int64 buffer with the dummy prompt's shape.
input_tensor = torch.empty(input_shape, dtype=torch.long, device="cpu")
input_tensor = input_tensor.contiguous()

# Output side: float32 buffer of shape (1, 1, 10 * token_count).
# NOTE(review): the factor of 10 is a sizing guess -- confirm it against the
# shape the model actually produces.
max_output_length = 10 * input_shape[1]
output_shape = (1, 1, max_output_length)
output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu")
output_tensor = output_tensor.contiguous()

# Hand the raw buffer addresses to ONNX Runtime.
io_binding.bind_input(
    name=input_meta.name,
    device_type="cpu",
    device_id=0,
    element_type=input_type,
    shape=input_shape,
    buffer_ptr=input_tensor.data_ptr(),
)
io_binding.bind_output(
    name=output_meta.name,
    device_type="cpu",
    device_id=0,
    element_type=np.float32,
    shape=output_shape,
    buffer_ptr=output_tensor.data_ptr(),
)
# --- Inference Function ---
def tts_inference_io_binding(text: str):
    """Synthesize speech for *text* using ONNX Runtime IO binding.

    Args:
        text: The text to speak. Must contain at least one non-whitespace
            character.

    Returns:
        A ``(sampling_rate, waveform)`` tuple, where ``waveform`` is the
        model's first output as a NumPy array with singleton dimensions
        squeezed out.

    Raises:
        gr.Error: If *text* is empty or whitespace only.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")

    # Keep a reference to the contiguous tensor until the run finishes: ONNX
    # Runtime reads the token ids straight from its memory.
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(torch.long).contiguous()

    # Use a fresh binding per call so nothing bound by an earlier (possibly
    # longer) request can leak into this one.
    binding = ort_session.io_binding()

    # Bind the input with the exact shape of this request. Reusing a larger
    # buffer would feed the model stale token ids past the real sequence.
    binding.bind_input(
        name=input_meta.name,
        device_type="cpu",
        device_id=0,
        element_type=np.int64,
        shape=tuple(input_ids.shape),
        buffer_ptr=input_ids.data_ptr(),
    )

    # Bind the output by name only so ONNX Runtime allocates it: the waveform
    # length is not derivable from the token count before inference.
    binding.bind_output(output_meta.name, "cpu")

    ort_session.run_with_iobinding(binding)

    # Copy the result into a NumPy array that outlives the binding.
    waveform = binding.copy_outputs_to_cpu()[0]
    return (sampling_rate, waveform.squeeze())
# --- Gradio Interface ---
# One text box in, one audio player out; the example prompts are offered but
# not pre-computed (cache_examples=False).
iface = gr.Interface(
    fn=tts_inference_io_binding,
    inputs=gr.Textbox(placeholder="Enter text here...", lines=3),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    examples=[
        [prompt]
        for prompt in (
            "Hello, this is a demonstration.",
            "This uses ONNX Runtime and IO Binding.",
            "The quick brown fox jumps over the lazy dog.",
            "Try your own text!",
        )
    ],
    cache_examples=False,
    title="Optimized MMS-TTS (English)",
    description="Fast TTS with ONNX Runtime and IO Binding (Hugging Face Hub).",
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()