Athspi committed · verified
Commit b805946 · 1 Parent(s): 87a1ec1

Update app.py

Files changed (1):
  app.py  +43 -72
app.py CHANGED
--- a/app.py
@@ -4,90 +4,76 @@ import numpy as np
from transformers import AutoTokenizer
import onnxruntime
from huggingface_hub import hf_hub_download

# --- Configuration ---
repo_id = "Athspi/Gg"  # Your Hugging Face Hub repository ID
- onnx_filename = "mms_tts_eng.onnx"  # Name of the ONNX file in the repository
- sampling_rate = 16000  # Sampling rate of the model (adjust if needed)

- # --- Load Model and Tokenizer ---

- # Download the ONNX model (using hf_hub_download for caching)
onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)

- # Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(repo_id)

- # --- ONNX Runtime Session Setup with Optimization ---

session_options = onnxruntime.SessionOptions()
- # Optimization level: Use all available optimizations
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
- # Threading: Set intra_op_num_threads to the number of *physical* cores
- # (You'll need to determine this for your system). Here's a
- # way to get it programmatically (but it might not be 100%
- # reliable on all systems).
try:
    import psutil
    num_physical_cores = psutil.cpu_count(logical=False)
except ImportError:
-     print("psutil not installed. You can install it with: pip install psutil")
-     num_physical_cores = 4  # Set a reasonable default (e.g., 4)
-     print(f"Using default number of physical cores: {num_physical_cores}")
-
session_options.intra_op_num_threads = num_physical_cores
- session_options.inter_op_num_threads = 1  # Usually best for TTS to be 1 or 2

- # Create the ONNX Runtime inference session
ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
-     providers=['CPUExecutionProvider'],  # You can try other providers if available
    sess_options=session_options,
)

-
# --- IO Binding Setup ---
-
io_binding = ort_session.io_binding()
-
- # Get input/output metadata
input_meta = ort_session.get_inputs()[0]
output_meta = ort_session.get_outputs()[0]
-
- # Dummy input for shape/type
dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.long)
input_shape = tuple(dummy_input.shape)
input_type = dummy_input.numpy().dtype
-
- # Pre-allocate input tensor (CPU, contiguous)
input_tensor = torch.empty(input_shape, dtype=torch.int64, device="cpu").contiguous()
-
- # Pre-allocate output tensor (CPU, contiguous) - estimate max size
- max_output_length = input_shape[1] * 10  # Adjust factor as needed
output_shape = (1, 1, max_output_length)
output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()

- # Bind the pre-allocated tensors
io_binding.bind_input(
-     name=input_meta.name,
-     device_type="cpu",
-     device_id=0,
-     element_type=input_type,
-     shape=input_shape,
-     buffer_ptr=input_tensor.data_ptr(),
)
-
io_binding.bind_output(
-     name=output_meta.name,
-     device_type="cpu",
-     device_id=0,
-     element_type=np.float32,
-     shape=output_shape,
-     buffer_ptr=output_tensor.data_ptr(),
)

-
- # --- Inference Function (with IO Binding) ---

def tts_inference_io_binding(text: str):
    """TTS inference with IO Binding."""
@@ -97,62 +83,47 @@ def tts_inference_io_binding(text: str):
    input_ids = inputs.input_ids.to(torch.long)
    current_input_shape = tuple(input_ids.shape)

-     # Resize input tensor if necessary
    if current_input_shape[1] > input_tensor.shape[1]:
        input_tensor = torch.empty(current_input_shape, dtype=torch.int64, device="cpu").contiguous()
        io_binding.bind_input(
-             name=input_meta.name,
-             device_type="cpu",
-             device_id=0,
-             element_type=input_type,
-             shape=current_input_shape,
            buffer_ptr=input_tensor.data_ptr(),
        )

-     # Copy input data
    input_tensor[:current_input_shape[0], :current_input_shape[1]].copy_(input_ids)

-     # Resize output tensor if necessary
    required_output_length = current_input_shape[1] * 10
    if required_output_length > output_tensor.shape[2]:
        output_shape = (1, 1, required_output_length)
        output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
        io_binding.bind_output(
-             name=output_meta.name,
-             device_type="cpu",
-             device_id=0,
-             element_type=np.float32,
-             shape=output_shape,
            buffer_ptr=output_tensor.data_ptr(),
        )
-
-     # Clear binding
-     io_binding.clear_binding_outputs()

-     # Run inference
    ort_session.run_with_iobinding(io_binding)
-
-     # Get output
    ort_outputs = io_binding.get_outputs()
    output_data = ort_outputs[0].numpy()
-
    return (sampling_rate, output_data.squeeze())

# --- Gradio Interface ---

iface = gr.Interface(
    fn=tts_inference_io_binding,
-     inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),  # Slightly larger textbox
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
-     title="Optimized MMS-TTS (English) with ONNX Runtime",
-     description="Fast Text-to-Speech using the facebook/mms-tts-eng model, optimized with ONNX Runtime and IO Binding. Model loaded from Hugging Face Hub.",
    examples=[
-         ["Hello, this is a demonstration of optimized text-to-speech."],
-         ["This model uses ONNX Runtime and IO Binding for fast CPU inference."],
        ["The quick brown fox jumps over the lazy dog."],
-         ["Try entering your own text to hear how it sounds!"]
    ],
-     cache_examples=False,  # Disable example caching (important for dynamic TTS)
)

if __name__ == "__main__":
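The comment block removed above pinned intra-op threads to the number of physical cores and noted that detecting that count programmatically is not fully reliable. As an illustration only, not part of this commit, a minimal way to compare psutil's physical-core count with Python's own os.cpu_count(), which reports logical cores, is:

import os

try:
    import psutil  # optional dependency; can distinguish physical from logical cores
    physical_cores = psutil.cpu_count(logical=False)
except ImportError:
    physical_cores = None  # psutil not installed

logical_cores = os.cpu_count()  # logical cores (hyper-threads included)
print(f"physical: {physical_cores}, logical: {logical_cores}")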
 
+++ b/app.py
@@ -4,90 +4,76 @@ import numpy as np
from transformers import AutoTokenizer
import onnxruntime
from huggingface_hub import hf_hub_download
+ import os  # Import the 'os' module

# --- Configuration ---
repo_id = "Athspi/Gg"  # Your Hugging Face Hub repository ID
+ onnx_filename = "mms_tts_eng.onnx"  # Name of the ONNX file
+ sampling_rate = 16000

+ # --- Download ONNX Model (and handle location) ---

+ # Option 1: Use the cached path (Recommended)
onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
+ print(f"ONNX model downloaded to (cache): {onnx_model_path}")

+ # Option 2: Download to a specific directory (e.g., the current working directory)
+ # output_dir = "."  # Current directory
+ # onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename, cache_dir=output_dir)
+ # print(f"ONNX model downloaded to: {onnx_model_path}")
+
+ # Option 3: Download to a custom directory:
+ # output_dir = "models"  # Or any directory you want
+ # os.makedirs(output_dir, exist_ok=True)  # Create directory if it doesn't exist
+ # onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename, cache_dir=output_dir)
+ # print(f"ONNX model downloaded to: {onnx_model_path}")
+
+
+ # --- Load Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(repo_id)

+ # --- ONNX Runtime Session Setup ---

session_options = onnxruntime.SessionOptions()
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
try:
    import psutil
    num_physical_cores = psutil.cpu_count(logical=False)
except ImportError:
+     print("psutil not installed. Install with: pip install psutil")
+     num_physical_cores = 4
+     print(f"Using default: {num_physical_cores}")
session_options.intra_op_num_threads = num_physical_cores
+ session_options.inter_op_num_threads = 1

ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
+     providers=['CPUExecutionProvider'],
    sess_options=session_options,
)

# --- IO Binding Setup ---
io_binding = ort_session.io_binding()
input_meta = ort_session.get_inputs()[0]
output_meta = ort_session.get_outputs()[0]
dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.long)
input_shape = tuple(dummy_input.shape)
input_type = dummy_input.numpy().dtype
input_tensor = torch.empty(input_shape, dtype=torch.int64, device="cpu").contiguous()
+ max_output_length = input_shape[1] * 10
output_shape = (1, 1, max_output_length)
output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()

io_binding.bind_input(
+     name=input_meta.name, device_type="cpu", device_id=0,
+     element_type=input_type, shape=input_shape, buffer_ptr=input_tensor.data_ptr(),
)
io_binding.bind_output(
+     name=output_meta.name, device_type="cpu", device_id=0,
+     element_type=np.float32, shape=output_shape, buffer_ptr=output_tensor.data_ptr(),
)

+ # --- Inference Function ---

def tts_inference_io_binding(text: str):
    """TTS inference with IO Binding."""
@@ -97,62 +83,47 @@ def tts_inference_io_binding(text: str):
    input_ids = inputs.input_ids.to(torch.long)
    current_input_shape = tuple(input_ids.shape)

    if current_input_shape[1] > input_tensor.shape[1]:
        input_tensor = torch.empty(current_input_shape, dtype=torch.int64, device="cpu").contiguous()
        io_binding.bind_input(
+             name=input_meta.name, device_type="cpu", device_id=0,
+             element_type=input_type, shape=current_input_shape,
            buffer_ptr=input_tensor.data_ptr(),
        )

    input_tensor[:current_input_shape[0], :current_input_shape[1]].copy_(input_ids)

    required_output_length = current_input_shape[1] * 10
    if required_output_length > output_tensor.shape[2]:
        output_shape = (1, 1, required_output_length)
        output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
        io_binding.bind_output(
+             name=output_meta.name, device_type="cpu", device_id=0,
+             element_type=np.float32, shape=output_shape,
            buffer_ptr=output_tensor.data_ptr(),
        )

+     io_binding.clear_binding_outputs()
    ort_session.run_with_iobinding(io_binding)
    ort_outputs = io_binding.get_outputs()
    output_data = ort_outputs[0].numpy()
    return (sampling_rate, output_data.squeeze())

# --- Gradio Interface ---

iface = gr.Interface(
    fn=tts_inference_io_binding,
+     inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),
    outputs=gr.Audio(type="numpy", label="Generated Speech"),
+     title="Optimized MMS-TTS (English)",
+     description="Fast TTS with ONNX Runtime and IO Binding (Hugging Face Hub).",
    examples=[
+         ["Hello, this is a demonstration."],
+         ["This uses ONNX Runtime and IO Binding."],
        ["The quick brown fox jumps over the lazy dog."],
+         ["Try your own text!"]
    ],
+     cache_examples=False,
)

if __name__ == "__main__":
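The diff stops at the __main__ guard; the launch call that follows is outside the changed hunks. For context, a rough usage sketch, not part of the commit, that sanity-checks the session without IO binding and then starts the app. It assumes the exported graph takes a single int64 input_ids tensor and returns a float32 waveform at the configured sampling rate:

if __name__ == "__main__":
    # Sanity check (assumption: single input_ids input, waveform output) -- not in the committed file.
    ids = tokenizer("Hello world.", return_tensors="np")["input_ids"].astype(np.int64)
    waveform = ort_session.run(None, {input_meta.name: ids})[0]
    print("waveform shape:", waveform.shape, "dtype:", waveform.dtype)

    # Start the Gradio app (the usual ending for a Space's app.py).
    iface.launch()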