File size: 2,137 Bytes
6417dc9
c21b225
6417dc9
 
4368215
 
 
1f4b6af
c21b225
1f4b6af
 
4368215
1f4b6af
 
 
4368215
1f4b6af
 
4368215
6417dc9
 
 
 
c21b225
6417dc9
 
c21b225
6417dc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6b39ee
 
 
6417dc9
 
 
 
 
c21b225
6417dc9
4368215
6417dc9
 
 
 
 
c21b225
 
 
4368215
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime
import scipy.io.wavfile
from huggingface_hub import hf_hub_download

# Hugging Face repository that provides both the tokenizer and the ONNX file.
repo_id = "Athspi/Gg"

# Fetch "mms_tts_eng.onnx" from the repository
# (https://huggingface.co/Athspi/Gg/resolve/main/mms_tts_eng.onnx);
# the returned value is passed to ONNX Runtime below as the model path.
onnx_model_path = hf_hub_download(
    filename="mms_tts_eng.onnx",
    repo_id=repo_id,
)

# The tokenizer is loaded from the same repository as the model.
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# One shared inference session, restricted to the CPU execution provider.
ort_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"],
)

# Fixed sampling rate reported alongside every synthesized waveform
# (adjust if your model uses a different rate).
sampling_rate = 16000

def tts_inference(text: str):
    """
    Convert input text to a speech waveform using the ONNX model.

    Parameters:
        text (str): Input text to synthesize.

    Returns:
        tuple: ``(sampling_rate, waveform)`` where ``sampling_rate`` is the
        module-level sampling rate (int) and ``waveform`` is the first ONNX
        output as a float32 ``np.ndarray`` with singleton dimensions removed.
        This is the ordering ``gr.Audio(type="numpy")`` expects for an
        output value.

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    # Reject empty input up front so the UI shows a readable message instead
    # of a failure from inside the tokenizer / ONNX runtime.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")

    # Tokenize the input text.
    inputs = tokenizer(text, return_tensors="pt")

    # The session is fed a NumPy array; the token ids are cast to 64-bit
    # integers before conversion.
    input_ids = inputs.input_ids.cpu().to(torch.long).numpy()

    # Run inference; the first output is taken as the waveform.
    onnx_outputs = ort_session.run(None, {"input_ids": input_ids})
    waveform = onnx_outputs[0]

    # Cast to float32 and drop singleton dimensions.
    waveform = np.squeeze(waveform.astype(np.float32))

    # Gradio's numpy audio format is (sample_rate, data) -- rate first.
    return sampling_rate, waveform

# Gradio UI: a two-line text box in, a numpy-typed audio component out,
# with tts_inference as the callback.
iface = gr.Interface(
    title="ONNX TTS Demo",
    description="Text-to-Speech synthesis using an ONNX model from the Athspi/Gg repository on Hugging Face.",
    fn=tts_inference,
    inputs=gr.Textbox(placeholder="Enter text here...", lines=2),
    outputs=gr.Audio(type="numpy"),
)

# Launch the app only when this file is executed directly.
if __name__ == "__main__":
    iface.launch()