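"""HTTP client for a Triton Inference Server TTS deployment (spark_tts / f5_tts).

Sends a reference (prompt) audio clip, its transcript, and a target text to the
server's KServe v2 inference endpoint, then writes the synthesized speech to a
WAV file.
"""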
import argparse

import numpy as np
import requests
import soundfile as sf


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--server-url",
        type=str,
        default="localhost:8000",
        help="Address (host:port) of the Triton HTTP endpoint",
    )

    parser.add_argument(
        "--reference-audio",
        type=str,
        default="../../example/prompt_audio.wav",
        help="Path to the reference (prompt) audio file",
    )

    parser.add_argument(
        "--reference-text",
        type=str,
        default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。",
        help="Transcript of the reference audio",
    )

    parser.add_argument(
        "--target-text",
        type=str,
        default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
        help="Text to synthesize in the voice of the reference audio",
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="spark_tts",
        choices=["f5_tts", "spark_tts"],
        help="Name of the Triton model to request",
    )

    parser.add_argument(
        "--output-audio",
        type=str,
        default="output.wav",
        help="Path to save the output audio",
    )
    return parser.parse_args()


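# The request body built by prepare_request() follows the Triton / KServe v2
# HTTP inference protocol: a JSON object with an "inputs" list, where each
# entry names a model input tensor and carries its shape, datatype, and data.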
def prepare_request(
    waveform,
    reference_text,
    target_text,
    sample_rate=16000,
    padding_duration: int = None,
    audio_save_dir: str = "./",
):
    assert len(waveform.shape) == 1, "waveform should be 1D"
    lengths = np.array([[len(waveform)]], dtype=np.int32)
    if padding_duration:
        # Zero-pad the waveform up to the next multiple of `padding_duration`
        # seconds so that requests of similar length share the same shape.
        duration = len(waveform) / sample_rate
        samples = np.zeros(
            (
                1,
                padding_duration
                * sample_rate
                * ((int(duration) // padding_duration) + 1),
            ),
            dtype=np.float32,
        )
        samples[0, : len(waveform)] = waveform
    else:
        samples = waveform

    samples = samples.reshape(1, -1).astype(np.float32)

    data = {
        "inputs": [
            {
                "name": "reference_wav",
                "shape": samples.shape,
                "datatype": "FP32",
                "data": samples.tolist(),
            },
            {
                "name": "reference_wav_len",
                "shape": lengths.shape,
                "datatype": "INT32",
                "data": lengths.tolist(),
            },
            {
                "name": "reference_text",
                "shape": [1, 1],
                "datatype": "BYTES",
                "data": [reference_text],
            },
            {
                "name": "target_text",
                "shape": [1, 1],
                "datatype": "BYTES",
                "data": [target_text],
            },
        ]
    }
    return data


if __name__ == "__main__":
    args = get_args()
    server_url = args.server_url
    if not server_url.startswith(("http://", "https://")):
        server_url = f"http://{server_url}"

    # KServe v2 inference endpoint exposed by the Triton HTTP server.
    url = f"{server_url}/v2/models/{args.model_name}/infer"

    waveform, sr = sf.read(args.reference_audio)
    assert sr == 16000, "reference audio must be 16 kHz (sample rate is hardcoded in the server)"
    samples = np.array(waveform, dtype=np.float32)

    data = prepare_request(samples, args.reference_text, args.target_text)

    rsp = requests.post(
        url,
        headers={"Content-Type": "application/json"},
        json=data,
        verify=False,
        params={"request_id": "0"},
    )
    result = rsp.json()
    # The first output tensor holds the synthesized waveform samples (FP32).
    audio = np.array(result["outputs"][0]["data"], dtype=np.float32)
    sf.write(args.output_audio, audio, 16000, "PCM_16")
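
# Example invocation (assuming this script is saved as client_http.py and a
# Triton server exposing the spark_tts model is listening on localhost:8000):
#
#   python3 client_http.py --server-url localhost:8000 --model-name spark_tts \
#       --reference-audio ../../example/prompt_audio.wav \
#       --reference-text "<transcript of the prompt audio>" \
#       --target-text "<text to synthesize>" \
#       --output-audio output.wav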