"""Gradio demo: English MMS-TTS served through ONNX Runtime with IO binding.

At import time this script downloads the ONNX model and tokenizer from the
Hugging Face Hub, creates a CPU inference session, and builds the Gradio
interface. Running it as a script launches the web UI.
"""

import os

import gradio as gr
import numpy as np
import onnxruntime
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# --- Configuration ---
repo_id = "Athspi/Gg"
onnx_filename = "mms_tts_eng.onnx"
sampling_rate = 16000  # Sample rate (Hz) reported to Gradio with the waveform.

# --- Download ONNX Model ---
onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
print(f"ONNX model downloaded to (cache): {onnx_model_path}")

# --- Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# --- ONNX Runtime Session Setup ---
session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = (
    onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)

try:
    import psutil
except ImportError:
    print("psutil not installed. Install with: pip install psutil")
    num_physical_cores = 4
    print(f"Using default: {num_physical_cores}")
else:
    # psutil returns None when the physical core count cannot be determined;
    # never hand None to ONNX Runtime as a thread count.
    num_physical_cores = psutil.cpu_count(logical=False) or os.cpu_count() or 4

session_options.intra_op_num_threads = num_physical_cores
session_options.inter_op_num_threads = 1

ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"],
    sess_options=session_options,
)

# --- IO Binding Setup ---
# NOTE(review): only the model's first input and first output are used, as in
# the original script. If the exported model also expects e.g. an attention
# mask, it has to be bound as well -- confirm against the model's signature.
input_meta = ort_session.get_inputs()[0]
output_meta = ort_session.get_outputs()[0]
input_type = np.int64  # Token ids are bound as 64-bit integers.


# --- Inference Function ---
def tts_inference_io_binding(text: str):
    """Synthesize speech for ``text`` using ONNX Runtime IO binding.

    A fresh IO binding is created for every call so that the bound input
    shape always matches the current token sequence and no mutable state is
    shared between requests. The output is bound by name only, which lets
    ONNX Runtime allocate a buffer of whatever length the model produces
    instead of relying on a guessed size.

    Args:
        text: The text to convert to speech.

    Returns:
        A ``(sampling_rate, waveform)`` tuple where ``waveform`` is the
        model's first output as a NumPy array with size-1 axes squeezed out,
        in the form expected by ``gr.Audio(type="numpy")``.

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")

    encoded = tokenizer(text, return_tensors="pt")
    # Must stay referenced until the run finishes: ONNX Runtime reads the
    # token ids directly from this tensor's memory.
    input_ids = encoded.input_ids.to(torch.long).contiguous()

    binding = ort_session.io_binding()
    binding.bind_input(
        name=input_meta.name,
        device_type="cpu",
        device_id=0,
        element_type=input_type,
        shape=tuple(input_ids.shape),
        buffer_ptr=input_ids.data_ptr(),
    )
    # No shape/buffer given: the output length is only known after the run,
    # so ONNX Runtime allocates the output itself.
    binding.bind_output(name=output_meta.name, device_type="cpu")

    ort_session.run_with_iobinding(binding)

    output_data = binding.copy_outputs_to_cpu()[0]
    return (sampling_rate, output_data.squeeze())


# --- Gradio Interface ---
iface = gr.Interface(
    fn=tts_inference_io_binding,
    inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
    title="Optimized MMS-TTS (English)",
    description="Fast TTS with ONNX Runtime and IO Binding (Hugging Face Hub).",
    examples=[
        ["Hello, this is a demonstration."],
        ["This uses ONNX Runtime and IO Binding."],
        ["The quick brown fox jumps over the lazy dog."],
        ["Try your own text!"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch()