Athspi committed on
Commit f07e098 · verified · 1 Parent(s): 745e842

Create app.py

Files changed (1)
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime
from huggingface_hub import hf_hub_download

# --- Configuration ---
repo_id = "Athspi/Gg"  # Hugging Face Hub repository ID
onnx_filename = "mms_tts_eng.onnx"  # Name of the ONNX file in the repository
sampling_rate = 16000  # Sampling rate of the model (adjust if needed)

# --- Load Model and Tokenizer ---

# Download the ONNX model (hf_hub_download caches the file locally)
onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# --- ONNX Runtime Session Setup with Optimization ---

session_options = onnxruntime.SessionOptions()
# Enable all available graph optimizations
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

# Threading: set intra_op_num_threads to the number of *physical* cores.
# psutil can report this programmatically, though not 100% reliably on
# every system.
try:
    import psutil
    # cpu_count(logical=False) may return None on some systems; fall back to 4
    num_physical_cores = psutil.cpu_count(logical=False) or 4
except ImportError:
    print("psutil not installed. You can install it with: pip install psutil")
    num_physical_cores = 4  # A reasonable default
    print(f"Using default number of physical cores: {num_physical_cores}")

session_options.intra_op_num_threads = num_physical_cores
session_options.inter_op_num_threads = 1  # 1 (or 2) is usually best for TTS

# Create the ONNX Runtime inference session
ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"],  # Other providers can be used if available
    sess_options=session_options,
)

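# Optional sanity check (a sketch added here, not part of the original file):
# get_available_providers() lists the execution providers this onnxruntime
# build supports, useful before swapping out CPUExecutionProvider above.
print(f"Available ONNX Runtime providers: {onnxruntime.get_available_providers()}")
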
# --- IO Binding Setup ---

io_binding = ort_session.io_binding()

# Get input/output metadata
input_meta = ort_session.get_inputs()[0]
output_meta = ort_session.get_outputs()[0]

# Dummy input to establish the input element type
dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.long)
input_type = dummy_input.numpy().dtype

# Pre-allocate a reusable input buffer (CPU, contiguous). It is grown on
# demand and re-bound on every call, since the input shape changes with the
# text. The output is left for ONNX Runtime to allocate: the waveform length
# is data-dependent, so a fixed-size output buffer cannot be bound safely.
input_tensor = torch.empty(tuple(dummy_input.shape), dtype=torch.int64, device="cpu").contiguous()

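# Optional sketch (not in the original file): print the model's IO signature
# to confirm the tensor names and dynamic axes the bindings below rely on.
print(f"Model input: {input_meta.name}, shape={input_meta.shape}, type={input_meta.type}")
print(f"Model output: {output_meta.name}, shape={output_meta.shape}, type={output_meta.type}")
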
# --- Inference Function (with IO Binding) ---

def tts_inference_io_binding(text: str):
    """TTS inference with IO Binding."""
    global input_tensor

    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs.input_ids.to(torch.long)
    current_input_shape = tuple(input_ids.shape)

    # Grow the input buffer if this text is longer than anything seen so far
    if current_input_shape[1] > input_tensor.shape[1]:
        input_tensor = torch.empty(current_input_shape, dtype=torch.int64, device="cpu").contiguous()

    # Copy the token ids into the reusable buffer
    input_tensor[:current_input_shape[0], :current_input_shape[1]].copy_(input_ids)

    # Re-bind the input on every call so the bound shape matches this request
    io_binding.bind_input(
        name=input_meta.name,
        device_type="cpu",
        device_id=0,
        element_type=input_type,
        shape=current_input_shape,
        buffer_ptr=input_tensor.data_ptr(),
    )

    # Bind the output by name only and let ONNX Runtime allocate a CPU buffer
    # of the correct (data-dependent) length
    io_binding.clear_binding_outputs()
    io_binding.bind_output(output_meta.name, device_type="cpu")

    # Run inference
    ort_session.run_with_iobinding(io_binding)

    # Retrieve the output as a NumPy array
    output_data = io_binding.get_outputs()[0].numpy()

    return (sampling_rate, output_data.squeeze())
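# Usage sketch (hedged; scipy is an assumed extra dependency, not imported
# above): the function can be exercised without the UI and the waveform
# written to disk.
#
#   from scipy.io import wavfile
#   sr, wav = tts_inference_io_binding("Hello world.")
#   wavfile.write("output.wav", sr, wav)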

# --- Gradio Interface ---

iface = gr.Interface(
    fn=tts_inference_io_binding,
    inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),  # Slightly larger textbox
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
    title="Optimized MMS-TTS (English) with ONNX Runtime",
    description="Fast text-to-speech using the facebook/mms-tts-eng model, optimized with ONNX Runtime and IO Binding. Model loaded from the Hugging Face Hub.",
    examples=[
        ["Hello, this is a demonstration of optimized text-to-speech."],
        ["This model uses ONNX Runtime and IO Binding for fast CPU inference."],
        ["The quick brown fox jumps over the lazy dog."],
        ["Try entering your own text to hear how it sounds!"]
    ],
    cache_examples=False,  # Disable example caching (important for dynamic TTS)
)

if __name__ == "__main__":
    iface.launch()
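# To run the app locally (a hedged note, not part of the original commit):
#   pip install gradio torch numpy transformers onnxruntime huggingface_hub psutil
#   python app.py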