Athspi committed
Commit 2e87863 · verified · 1 Parent(s): b805946

Update app.py

Files changed (1)
  1. app.py +9 -23
app.py CHANGED
@@ -4,31 +4,17 @@ import numpy as np
  from transformers import AutoTokenizer
  import onnxruntime
  from huggingface_hub import hf_hub_download
- import os # Import the 'os' module
+ import os

  # --- Configuration ---
- repo_id = "Athspi/Gg" # Your Hugging Face Hub repository ID
- onnx_filename = "mms_tts_eng.onnx" # Name of the ONNX file
+ repo_id = "Athspi/Gg"
+ onnx_filename = "mms_tts_eng.onnx"
  sampling_rate = 16000

- # --- Download ONNX Model (and handle location) ---
-
- # Option 1: Use the cached path (Recommended)
+ # --- Download ONNX Model ---
  onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
  print(f"ONNX model downloaded to (cache): {onnx_model_path}")

- # Option 2: Download to a specific directory (e.g., the current working directory)
- # output_dir = "." # Current directory
- # onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename, cache_dir=output_dir)
- # print(f"ONNX model downloaded to: {onnx_model_path}")
-
- # Option 3: Download to a custom directory:
- # output_dir = "models" # Or any directory you want
- # os.makedirs(output_dir, exist_ok=True) # Create directory if it doesn't exist
- # onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename, cache_dir=output_dir)
- # print(f"ONNX model downloaded to: {onnx_model_path}")
-
-
  # --- Load Tokenizer ---
  tokenizer = AutoTokenizer.from_pretrained(repo_id)

@@ -64,6 +50,7 @@ max_output_length = input_shape[1] * 10
  output_shape = (1, 1, max_output_length)
  output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()

+ # Bind BEFORE clear_binding_outputs
  io_binding.bind_input(
  name=input_meta.name, device_type="cpu", device_id=0,
  element_type=input_type, shape=input_shape, buffer_ptr=input_tensor.data_ptr(),
@@ -72,7 +59,6 @@ io_binding.bind_output(
  name=output_meta.name, device_type="cpu", device_id=0,
  element_type=np.float32, shape=output_shape, buffer_ptr=output_tensor.data_ptr(),
  )
-
  # --- Inference Function ---

  def tts_inference_io_binding(text: str):
@@ -85,7 +71,7 @@ def tts_inference_io_binding(text: str):

  if current_input_shape[1] > input_tensor.shape[1]:
  input_tensor = torch.empty(current_input_shape, dtype=torch.int64, device="cpu").contiguous()
- io_binding.bind_input(
+ io_binding.bind_input( # Re-bind input
  name=input_meta.name, device_type="cpu", device_id=0,
  element_type=input_type, shape=current_input_shape,
  buffer_ptr=input_tensor.data_ptr(),
@@ -97,16 +83,16 @@ def tts_inference_io_binding(text: str):
  if required_output_length > output_tensor.shape[2]:
  output_shape = (1, 1, required_output_length)
  output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
- io_binding.bind_output(
+ io_binding.bind_output( # Re-bind output
  name=output_meta.name, device_type="cpu", device_id=0,
  element_type=np.float32, shape=output_shape,
  buffer_ptr=output_tensor.data_ptr(),
  )

- io_binding.clear_binding_outputs()
+ io_binding.clear_binding_outputs() # Clear outputs *after* binding
  ort_session.run_with_iobinding(io_binding)
  ort_outputs = io_binding.get_outputs()
- output_data = ort_outputs[0].numpy()
+ output_data = ort_outputs[0].numpy() # Directly use the bound output
  return (sampling_rate, output_data.squeeze())

  # --- Gradio Interface ---
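
For reference, the pattern this commit touches is ONNX Runtime IO binding: input and output buffers are bound to the session before run_with_iobinding, and the result is read back from the bound outputs. Below is a minimal, self-contained sketch of that flow. repo_id, the ONNX filename, and the 16 kHz sampling rate are taken from the diff; the synthesize helper and the choice to let ONNX Runtime allocate the output (instead of pre-allocating and re-binding a fixed waveform buffer as app.py does) are illustrative assumptions, not the repository's exact code.

import numpy as np
import onnxruntime
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Values taken from the diff above; the rest of this file is an illustrative sketch.
repo_id = "Athspi/Gg"
onnx_model_path = hf_hub_download(repo_id=repo_id, filename="mms_tts_eng.onnx")
tokenizer = AutoTokenizer.from_pretrained(repo_id)
session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
sampling_rate = 16000

def synthesize(text: str):
    # Tokenize to int64 IDs and keep them in a contiguous torch tensor so the
    # raw pointer can be handed to ONNX Runtime without an extra copy.
    input_ids = tokenizer(text, return_tensors="np").input_ids.astype(np.int64)
    input_tensor = torch.from_numpy(input_ids).contiguous()

    io_binding = session.io_binding()
    # Zero-copy bind of the input buffer (same bind_input signature app.py uses:
    # name / device_type / device_id / element_type / shape / buffer_ptr).
    io_binding.bind_input(
        name=session.get_inputs()[0].name, device_type="cpu", device_id=0,
        element_type=np.int64, shape=tuple(input_tensor.shape),
        buffer_ptr=input_tensor.data_ptr(),
    )
    # Let ONNX Runtime allocate the output on CPU; app.py instead pre-allocates a
    # waveform buffer and re-binds it whenever the required length grows.
    io_binding.bind_output(name=session.get_outputs()[0].name, device_type="cpu")

    session.run_with_iobinding(io_binding)
    waveform = io_binding.get_outputs()[0].numpy().squeeze()
    return sampling_rate, waveform

The returned (sampling_rate, waveform) tuple matches what tts_inference_io_binding returns, so a function like this could be wired to a Gradio audio output in the same way as the unchanged Gradio interface section.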