Update app.py
app.py CHANGED
@@ -4,31 +4,17 @@ import numpy as np
 from transformers import AutoTokenizer
 import onnxruntime
 from huggingface_hub import hf_hub_download
-import os
+import os
 
 # --- Configuration ---
-repo_id = "Athspi/Gg"
-onnx_filename = "mms_tts_eng.onnx"
+repo_id = "Athspi/Gg"
+onnx_filename = "mms_tts_eng.onnx"
 sampling_rate = 16000
 
-# --- Download ONNX Model
-
-# Option 1: Use the cached path (Recommended)
+# --- Download ONNX Model ---
 onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
 print(f"ONNX model downloaded to (cache): {onnx_model_path}")
 
-# Option 2: Download to a specific directory (e.g., the current working directory)
-# output_dir = "." # Current directory
-# onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename, cache_dir=output_dir)
-# print(f"ONNX model downloaded to: {onnx_model_path}")
-
-# Option 3: Download to a custom directory:
-# output_dir = "models" # Or any directory you want
-# os.makedirs(output_dir, exist_ok=True) # Create directory if it doesn't exist
-# onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename, cache_dir=output_dir)
-# print(f"ONNX model downloaded to: {onnx_model_path}")
-
-
 # --- Load Tokenizer ---
 tokenizer = AutoTokenizer.from_pretrained(repo_id)
 
@@ -64,6 +50,7 @@ max_output_length = input_shape[1] * 10
 output_shape = (1, 1, max_output_length)
 output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
 
+# Bind BEFORE clear_binding_outputs
 io_binding.bind_input(
     name=input_meta.name, device_type="cpu", device_id=0,
     element_type=input_type, shape=input_shape, buffer_ptr=input_tensor.data_ptr(),
@@ -72,7 +59,6 @@ io_binding.bind_output(
     name=output_meta.name, device_type="cpu", device_id=0,
     element_type=np.float32, shape=output_shape, buffer_ptr=output_tensor.data_ptr(),
 )
-
 # --- Inference Function ---
 
 def tts_inference_io_binding(text: str):
@@ -85,7 +71,7 @@ def tts_inference_io_binding(text: str):
 
     if current_input_shape[1] > input_tensor.shape[1]:
         input_tensor = torch.empty(current_input_shape, dtype=torch.int64, device="cpu").contiguous()
-        io_binding.bind_input(
+        io_binding.bind_input(  # Re-bind input
             name=input_meta.name, device_type="cpu", device_id=0,
             element_type=input_type, shape=current_input_shape,
             buffer_ptr=input_tensor.data_ptr(),
@@ -97,16 +83,16 @@ def tts_inference_io_binding(text: str):
     if required_output_length > output_tensor.shape[2]:
         output_shape = (1, 1, required_output_length)
         output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
-        io_binding.bind_output(
+        io_binding.bind_output(  # Re-bind output
             name=output_meta.name, device_type="cpu", device_id=0,
             element_type=np.float32, shape=output_shape,
             buffer_ptr=output_tensor.data_ptr(),
         )
 
-    io_binding.clear_binding_outputs()
+    io_binding.clear_binding_outputs()  # Clear outputs *after* binding
     ort_session.run_with_iobinding(io_binding)
     ort_outputs = io_binding.get_outputs()
-    output_data = ort_outputs[0].numpy()
+    output_data = ort_outputs[0].numpy()  # Directly use the bound output
     return (sampling_rate, output_data.squeeze())
 
 # --- Gradio Interface ---
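For context, the call sequence this change is arranging looks roughly like the sketch below. It is a minimal per-request version, assuming ort_session is the onnxruntime.InferenceSession built from the downloaded model and that input_meta/output_meta come from ort_session.get_inputs()[0] / ort_session.get_outputs()[0], as elsewhere in app.py. Unlike the file, it creates a fresh binding per request and lets ONNX Runtime allocate the output buffer instead of pre-allocating and re-binding one, which sidesteps not knowing the audio length up front.

import numpy as np
import torch

def synthesize(text: str):
    # One binding object per request; app.py instead reuses a module-level binding
    # and only re-binds buffers when they have to grow.
    io_binding = ort_session.io_binding()

    # Tokenize to int64 IDs; the tensor must stay alive until the run finishes,
    # because only its raw pointer is handed to ONNX Runtime.
    input_ids = tokenizer(text, return_tensors="pt")["input_ids"].to(torch.int64).contiguous()

    # Bind the input buffer by pointer before running.
    io_binding.bind_input(
        name=input_meta.name, device_type="cpu", device_id=0,
        element_type=np.int64, shape=tuple(input_ids.shape),
        buffer_ptr=input_ids.data_ptr(),
    )

    # Bind the output by device only, so ONNX Runtime allocates a correctly sized buffer.
    io_binding.bind_output(output_meta.name, device_type="cpu")

    # Run with the bindings in place, then read the waveform back.
    ort_session.run_with_iobinding(io_binding)
    audio = io_binding.get_outputs()[0].numpy()
    return (sampling_rate, audio.squeeze())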
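The diff ends at the # --- Gradio Interface --- marker, so the UI wiring itself is not shown. A hypothetical sketch of how a function returning (sampling_rate, waveform) is typically exposed with Gradio follows; the component choices and labels are assumptions, not part of this commit.

import gradio as gr

# Hypothetical wiring: gr.Audio with type="numpy" accepts the (sample_rate, np.ndarray)
# tuple that tts_inference_io_binding returns.
demo = gr.Interface(
    fn=tts_inference_io_binding,
    inputs=gr.Textbox(label="Text"),
    outputs=gr.Audio(label="Speech", type="numpy"),
    title="MMS TTS (ONNX Runtime IO binding)",
)

if __name__ == "__main__":
    demo.launch()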