""" |
|
This script loads a dataset from the Hugging Face Hub and sends it to the

Triton server for decoding, in parallel.
|
|
|
Usage: |
|
num_task=2 |
|
|
|
# For offline F5-TTS |
|
python3 client_grpc.py \ |
|
--server-addr localhost \ |
|
--model-name f5_tts \ |
|
--num-tasks $num_task \ |
|
--huggingface-dataset yuekai/seed_tts \ |
|
--split-name test_zh \ |
|
--log-dir ./log_concurrent_tasks_${num_task} |
|
|
|
# For offline Spark-TTS-0.5B |
|
python3 client_grpc.py \ |
|
--server-addr localhost \ |
|
--model-name spark_tts \ |
|
--num-tasks $num_task \ |
|
--huggingface-dataset yuekai/seed_tts \ |
|
--split-name wenetspeech4tts \ |
|
--log-dir ./log_concurrent_tasks_${num_task} |
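
    # For streaming Spark-TTS-0.5B (all flags below exist in this script; the
    # values are illustrative and may need adjusting for your deployment)
    python3 client_grpc.py \
        --server-addr localhost \
        --model-name spark_tts \
        --num-tasks $num_task \
        --mode streaming \
        --chunk-overlap-duration 0.1 \
        --huggingface-dataset yuekai/seed_tts \
        --split-name wenetspeech4tts \
        --log-dir ./log_concurrent_tasks_${num_task}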
|
""" |
|
|
|
import argparse |
|
import asyncio |
|
import json |
|
import queue |
|
import uuid |
|
import functools |
|
|
|
import os |
|
import time |
|
import types |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import soundfile as sf |
|
import tritonclient |
|
import tritonclient.grpc.aio as grpcclient_aio |
|
import tritonclient.grpc as grpcclient_sync |
|
from tritonclient.utils import np_to_triton_dtype, InferenceServerException |
|
|
|
|
|
|
|
class UserData: |
|
def __init__(self): |
|
self._completed_requests = queue.Queue() |
|
self._first_chunk_time = None |
|
self._start_time = None |
|
|
|
def record_start_time(self): |
|
self._start_time = time.time() |
|
|
|
def get_first_chunk_latency(self): |
|
if self._first_chunk_time and self._start_time: |
|
return self._first_chunk_time - self._start_time |
|
return None |
|
|
|
def callback(user_data, result, error): |
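    # Record the arrival time of the first successful response chunk (used for
    # first-chunk latency) and forward the result or error to the UserData queue.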
|
if user_data._first_chunk_time is None and not error: |
|
user_data._first_chunk_time = time.time() |
|
if error: |
|
user_data._completed_requests.put(error) |
|
else: |
|
user_data._completed_requests.put(result) |
|
|
|
|
|
|
|
def write_triton_stats(stats, summary_file): |
|
with open(summary_file, "w") as summary_f: |
|
model_stats = stats["model_stats"] |
|
|
|
summary_f.write( |
|
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n" |
|
) |
|
summary_f.write("To learn more about the log, please refer to: \n") |
|
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n") |
|
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n") |
|
summary_f.write( |
|
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" |
|
) |
|
summary_f.write( |
|
"However, there is a trade-off between the increased queue time and the increased batch size. \n" |
|
) |
|
summary_f.write( |
|
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n" |
|
) |
|
summary_f.write( |
|
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n" |
|
) |
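
        # For reference, a dynamic batching block in a Triton config.pbtxt typically
        # looks like the following (field names are Triton's; the values here are
        # illustrative, not taken from this repository):
        #
        #   dynamic_batching {
        #     preferred_batch_size: [ 4, 8 ]
        #     max_queue_delay_microseconds: 10000
        #   }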
|
for model_state in model_stats: |
|
if "last_inference" not in model_state: |
|
continue |
|
summary_f.write(f"model name is {model_state['name']} \n") |
|
model_inference_stats = model_state["inference_stats"] |
|
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 |
|
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 |
|
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 |
|
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 |
|
summary_f.write( |
|
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" |
|
) |
|
model_batch_stats = model_state["batch_stats"] |
|
for batch in model_batch_stats: |
|
batch_size = int(batch["batch_size"]) |
|
compute_input = batch["compute_input"] |
|
compute_output = batch["compute_output"] |
|
compute_infer = batch["compute_infer"] |
|
batch_count = int(compute_infer["count"]) |
|
assert compute_infer["count"] == compute_output["count"] == compute_input["count"] |
|
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 |
|
compute_input_time_ms = int(compute_input["ns"]) / 1e6 |
|
compute_output_time_ms = int(compute_output["ns"]) / 1e6 |
|
summary_f.write( |
|
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" |
|
) |
|
summary_f.write( |
|
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " |
|
) |
|
summary_f.write( |
|
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" |
|
) |
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
|
|
|
parser.add_argument( |
|
"--server-addr", |
|
type=str, |
|
default="localhost", |
|
help="Address of the server", |
|
) |
|
|
|
parser.add_argument( |
|
"--server-port", |
|
type=int, |
|
default=8001, |
|
help="Grpc port of the triton server, default is 8001", |
|
) |
|
|
|
parser.add_argument( |
|
"--reference-audio", |
|
type=str, |
|
default=None, |
|
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir", |
|
) |
|
|
|
parser.add_argument( |
|
"--reference-text", |
|
type=str, |
|
default="", |
|
help="", |
|
) |
|
|
|
parser.add_argument( |
|
"--target-text", |
|
type=str, |
|
default="", |
|
help="", |
|
) |
|
|
|
parser.add_argument( |
|
"--huggingface-dataset", |
|
type=str, |
|
default="yuekai/seed_tts", |
|
help="dataset name in huggingface dataset hub", |
|
) |
|
|
|
parser.add_argument( |
|
"--split-name", |
|
type=str, |
|
default="wenetspeech4tts", |
|
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], |
|
help="dataset split name, default is 'test'", |
|
) |
|
|
|
parser.add_argument( |
|
"--manifest-path", |
|
type=str, |
|
default=None, |
|
help="Path to the manifest dir which includes wav.scp trans.txt files.", |
|
) |
|
|
|
parser.add_argument( |
|
"--model-name", |
|
type=str, |
|
default="f5_tts", |
|
choices=["f5_tts", "spark_tts"], |
|
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", |
|
) |
|
|
|
parser.add_argument( |
|
"--num-tasks", |
|
type=int, |
|
default=1, |
|
help="Number of concurrent tasks for sending", |
|
) |
|
|
|
parser.add_argument( |
|
"--log-interval", |
|
type=int, |
|
default=5, |
|
help="Controls how frequently we print the log.", |
|
) |
|
|
|
parser.add_argument( |
|
"--compute-wer", |
|
action="store_true", |
|
default=False, |
|
help="""True to compute WER. |
|
""", |
|
) |
|
|
|
parser.add_argument( |
|
"--log-dir", |
|
type=str, |
|
required=False, |
|
default="./tmp", |
|
help="log directory", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--mode", |
|
type=str, |
|
default="offline", |
|
choices=["offline", "streaming"], |
|
help="Select offline or streaming benchmark mode." |
|
) |
|
parser.add_argument( |
|
"--chunk-overlap-duration", |
|
type=float, |
|
default=0.1, |
|
help="Chunk overlap duration for streaming reconstruction (in seconds)." |
|
) |
|
|
|
|
|
return parser.parse_args() |
|
|
|
|
|
def load_audio(wav_path, target_sample_rate=16000): |
|
    assert target_sample_rate == 16000, "the server expects 16 kHz input (hard-coded)"
|
if isinstance(wav_path, dict): |
|
waveform = wav_path["array"] |
|
sample_rate = wav_path["sampling_rate"] |
|
else: |
|
waveform, sample_rate = sf.read(wav_path) |
|
if sample_rate != target_sample_rate: |
|
from scipy.signal import resample |
|
|
|
num_samples = int(len(waveform) * (target_sample_rate / sample_rate)) |
|
waveform = resample(waveform, num_samples) |
|
return waveform, target_sample_rate |
|
|
|
def prepare_request_input_output( |
|
protocol_client, |
|
waveform, |
|
reference_text, |
|
target_text, |
|
sample_rate=16000, |
|
padding_duration: int = None |
|
): |
|
"""Prepares inputs for Triton inference (offline or streaming).""" |
|
assert len(waveform.shape) == 1, "waveform should be 1D" |
|
lengths = np.array([[len(waveform)]], dtype=np.int32) |
|
|
|
|
|
if padding_duration: |
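        # Zero-pad the prompt so the buffer length is a whole multiple of
        # padding_duration seconds, covering the prompt plus the synthesized audio,
        # whose duration is estimated from the target/reference text-length ratio.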
|
duration = len(waveform) / sample_rate |
|
|
|
|
|
if reference_text: |
|
estimated_target_duration = duration / len(reference_text) * len(target_text) |
|
else: |
|
estimated_target_duration = duration |
|
|
|
|
|
required_total_samples = padding_duration * sample_rate * ( |
|
(int(estimated_target_duration + duration) // padding_duration) + 1 |
|
) |
|
samples = np.zeros((1, required_total_samples), dtype=np.float32) |
|
samples[0, : len(waveform)] = waveform |
|
else: |
|
|
|
samples = waveform.reshape(1, -1).astype(np.float32) |
|
|
|
|
|
inputs = [ |
|
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)), |
|
protocol_client.InferInput( |
|
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype) |
|
), |
|
protocol_client.InferInput("reference_text", [1, 1], "BYTES"), |
|
protocol_client.InferInput("target_text", [1, 1], "BYTES"), |
|
] |
|
inputs[0].set_data_from_numpy(samples) |
|
inputs[1].set_data_from_numpy(lengths) |
|
|
|
input_data_numpy = np.array([reference_text], dtype=object) |
|
input_data_numpy = input_data_numpy.reshape((1, 1)) |
|
inputs[2].set_data_from_numpy(input_data_numpy) |
|
|
|
input_data_numpy = np.array([target_text], dtype=object) |
|
input_data_numpy = input_data_numpy.reshape((1, 1)) |
|
inputs[3].set_data_from_numpy(input_data_numpy) |
|
|
|
outputs = [protocol_client.InferRequestedOutput("waveform")] |
|
|
|
return inputs, outputs |
|
|
|
def run_sync_streaming_inference( |
|
sync_triton_client: tritonclient.grpc.InferenceServerClient, |
|
model_name: str, |
|
inputs: list, |
|
outputs: list, |
|
request_id: str, |
|
user_data: UserData, |
|
chunk_overlap_duration: float, |
|
save_sample_rate: int, |
|
audio_save_path: str, |
|
): |
|
"""Helper function to run the blocking sync streaming call.""" |
|
start_time_total = time.time() |
|
user_data.record_start_time() |
|
|
|
|
|
sync_triton_client.start_stream(callback=functools.partial(callback, user_data)) |
|
|
|
|
|
sync_triton_client.async_stream_infer( |
|
model_name, |
|
inputs, |
|
request_id=request_id, |
|
outputs=outputs, |
|
enable_empty_final_response=True, |
|
) |
|
|
|
|
|
audios = [] |
|
while True: |
|
try: |
|
            result = user_data._completed_requests.get(timeout=600)  # generous timeout so the queue.Empty handler below can fire; the value is an arbitrary guard
|
if isinstance(result, InferenceServerException): |
|
print(f"Received InferenceServerException: {result}") |
|
sync_triton_client.stop_stream() |
|
return None, None, None |
|
|
|
response = result.get_response() |
|
final = response.parameters["triton_final_response"].bool_param |
|
if final is True: |
|
break |
|
|
|
audio_chunk = result.as_numpy("waveform").reshape(-1) |
|
if audio_chunk.size > 0: |
|
audios.append(audio_chunk) |
|
else: |
|
print("Warning: received empty audio chunk.") |
|
|
|
except queue.Empty: |
|
print(f"Timeout waiting for response for request id {request_id}") |
|
sync_triton_client.stop_stream() |
|
return None, None, None |
|
|
|
sync_triton_client.stop_stream() |
|
end_time_total = time.time() |
|
total_request_latency = end_time_total - start_time_total |
|
first_chunk_latency = user_data.get_first_chunk_latency() |
|
|
|
|
|
actual_duration = 0 |
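    # Stitch the received chunks back together, cross-fading consecutive chunks over
    # chunk_overlap_duration seconds (linear fade-out on the previous chunk, fade-in
    # on the next) to avoid audible seams.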
|
if audios: |
|
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate) |
|
fade_out = np.linspace(1, 0, cross_fade_samples) |
|
fade_in = np.linspace(0, 1, cross_fade_samples) |
|
reconstructed_audio = None |
|
|
|
|
|
        if len(audios) == 1:

            reconstructed_audio = audios[0]
|
else: |
|
reconstructed_audio = audios[0][:-cross_fade_samples] |
|
for i in range(1, len(audios)): |
|
|
|
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in + |
|
audios[i - 1][-cross_fade_samples:] * fade_out) |
|
|
|
middle_part = audios[i][cross_fade_samples:-cross_fade_samples] |
|
|
|
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part]) |
|
|
|
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]]) |
|
|
|
if reconstructed_audio is not None and reconstructed_audio.size > 0: |
|
actual_duration = len(reconstructed_audio) / save_sample_rate |
|
|
|
os.makedirs(os.path.dirname(audio_save_path), exist_ok=True) |
|
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16") |
|
else: |
|
print("Warning: No audio chunks received or reconstructed.") |
|
actual_duration = 0 |
|
|
|
else: |
|
print("Warning: No audio chunks received.") |
|
actual_duration = 0 |
|
|
|
return total_request_latency, first_chunk_latency, actual_duration |
|
|
|
|
|
async def send_streaming( |
|
manifest_item_list: list, |
|
name: str, |
|
server_url: str, |
|
protocol_client: types.ModuleType, |
|
log_interval: int, |
|
model_name: str, |
|
audio_save_dir: str = "./", |
|
save_sample_rate: int = 16000, |
|
chunk_overlap_duration: float = 0.1, |
|
padding_duration: int = None, |
|
): |
|
total_duration = 0.0 |
|
latency_data = [] |
|
task_id = int(name[5:]) |
|
sync_triton_client = None |
|
|
|
try: |
|
print(f"{name}: Initializing sync client for streaming...") |
|
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) |
|
|
|
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.") |
|
for i, item in enumerate(manifest_item_list): |
|
if i % log_interval == 0: |
|
print(f"{name}: Processing item {i}/{len(manifest_item_list)}") |
|
|
|
try: |
|
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) |
|
reference_text, target_text = item["reference_text"], item["target_text"] |
|
|
|
inputs, outputs = prepare_request_input_output( |
|
protocol_client, |
|
waveform, |
|
reference_text, |
|
target_text, |
|
sample_rate, |
|
padding_duration=padding_duration |
|
) |
|
request_id = str(uuid.uuid4()) |
|
user_data = UserData() |
|
|
|
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") |
|
|
|
total_request_latency, first_chunk_latency, actual_duration = await asyncio.to_thread( |
|
run_sync_streaming_inference, |
|
sync_triton_client, |
|
model_name, |
|
inputs, |
|
outputs, |
|
request_id, |
|
user_data, |
|
chunk_overlap_duration, |
|
save_sample_rate, |
|
audio_save_path |
|
) |
|
|
|
if total_request_latency is not None: |
|
print(f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s") |
|
latency_data.append((total_request_latency, first_chunk_latency, actual_duration)) |
|
total_duration += actual_duration |
|
else: |
|
print(f"{name}: Item {i} failed.") |
|
|
|
|
|
except FileNotFoundError: |
|
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}") |
|
except Exception as e: |
|
print(f"Error processing item {i} ({item['target_audio_path']}): {e}") |
|
import traceback |
|
traceback.print_exc() |
|
|
|
|
|
finally: |
|
if sync_triton_client: |
|
try: |
|
print(f"{name}: Closing sync client...") |
|
sync_triton_client.close() |
|
except Exception as e: |
|
print(f"{name}: Error closing sync client: {e}") |
|
|
|
|
|
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s") |
|
return total_duration, latency_data |
|
|
|
async def send( |
|
manifest_item_list: list, |
|
name: str, |
|
triton_client: tritonclient.grpc.aio.InferenceServerClient, |
|
protocol_client: types.ModuleType, |
|
log_interval: int, |
|
model_name: str, |
|
padding_duration: int = None, |
|
audio_save_dir: str = "./", |
|
save_sample_rate: int = 16000, |
|
): |
|
total_duration = 0.0 |
|
latency_data = [] |
|
task_id = int(name[5:]) |
|
|
|
print(f"manifest_item_list: {manifest_item_list}") |
|
for i, item in enumerate(manifest_item_list): |
|
if i % log_interval == 0: |
|
print(f"{name}: {i}/{len(manifest_item_list)}") |
|
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) |
|
reference_text, target_text = item["reference_text"], item["target_text"] |
|
|
|
inputs, outputs = prepare_request_input_output( |
|
protocol_client, |
|
waveform, |
|
reference_text, |
|
target_text, |
|
sample_rate, |
|
padding_duration=padding_duration |
|
) |
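        # Request id derived from the task id and item index (purely for tracing;
        # note it can collide once a task processes more than 10 items).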
|
sequence_id = 100000000 + i + task_id * 10 |
|
start = time.time() |
|
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs) |
|
|
|
audio = response.as_numpy("waveform").reshape(-1) |
|
actual_duration = len(audio) / save_sample_rate |
|
|
|
end = time.time() - start |
|
|
|
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") |
|
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16") |
|
|
|
latency_data.append((end, actual_duration)) |
|
total_duration += actual_duration |
|
|
|
return total_duration, latency_data |
|
|
|
|
|
def load_manifests(manifest_path): |
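    """Parse a manifest where each line is 'utt|prompt_text|prompt_wav|gt_text'.

    Relative prompt_wav paths are resolved against the manifest's directory.
    """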
|
with open(manifest_path, "r") as f: |
|
manifest_list = [] |
|
for line in f: |
|
assert len(line.strip().split("|")) == 4 |
|
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") |
|
utt = Path(utt).stem |
|
|
|
if not os.path.isabs(prompt_wav): |
|
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav) |
|
manifest_list.append( |
|
{ |
|
"audio_filepath": prompt_wav, |
|
"reference_text": prompt_text, |
|
"target_text": gt_text, |
|
"target_audio_path": utt, |
|
} |
|
) |
|
return manifest_list |
|
|
|
|
|
def split_data(data, k): |
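    """Split `data` into k contiguous chunks whose sizes differ by at most one.

    If len(data) < k, k is reduced to len(data).
    """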
|
n = len(data) |
|
if n < k: |
|
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.") |
|
k = n |
|
|
|
quotient = n // k |
|
remainder = n % k |
|
|
|
result = [] |
|
start = 0 |
|
for i in range(k): |
|
if i < remainder: |
|
end = start + quotient + 1 |
|
else: |
|
end = start + quotient |
|
|
|
result.append(data[start:end]) |
|
start = end |
|
|
|
return result |
|
|
|
async def main(): |
|
args = get_args() |
|
url = f"{args.server_addr}:{args.server_port}" |
|
|
|
|
|
triton_client = None |
|
protocol_client = None |
|
if args.mode == "offline": |
|
print("Initializing gRPC client for offline mode...") |
|
|
|
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) |
|
protocol_client = grpcclient_aio |
|
elif args.mode == "streaming": |
|
print("Initializing gRPC client for streaming mode...") |
|
|
|
|
|
|
|
protocol_client = grpcclient_sync |
|
else: |
|
raise ValueError(f"Invalid mode: {args.mode}") |
|
|
|
|
|
if args.reference_audio: |
|
args.num_tasks = 1 |
|
args.log_interval = 1 |
|
manifest_item_list = [ |
|
{ |
|
"reference_text": args.reference_text, |
|
"target_text": args.target_text, |
|
"audio_filepath": args.reference_audio, |
|
"target_audio_path": "test", |
|
} |
|
] |
|
elif args.huggingface_dataset: |
|
import datasets |
|
|
|
dataset = datasets.load_dataset( |
|
args.huggingface_dataset, |
|
split=args.split_name, |
|
trust_remote_code=True, |
|
) |
|
manifest_item_list = [] |
|
for i in range(len(dataset)): |
|
manifest_item_list.append( |
|
{ |
|
"audio_filepath": dataset[i]["prompt_audio"], |
|
"reference_text": dataset[i]["prompt_text"], |
|
"target_audio_path": dataset[i]["id"], |
|
"target_text": dataset[i]["target_text"], |
|
} |
|
) |
|
else: |
|
manifest_item_list = load_manifests(args.manifest_path) |
|
|
|
num_tasks = min(args.num_tasks, len(manifest_item_list)) |
|
manifest_item_list = split_data(manifest_item_list, num_tasks) |
|
|
|
os.makedirs(args.log_dir, exist_ok=True) |
|
tasks = [] |
|
start_time = time.time() |
|
for i in range(num_tasks): |
|
|
|
if args.mode == "offline": |
|
task = asyncio.create_task( |
|
send( |
|
manifest_item_list[i], |
|
name=f"task-{i}", |
|
triton_client=triton_client, |
|
protocol_client=protocol_client, |
|
log_interval=args.log_interval, |
|
model_name=args.model_name, |
|
audio_save_dir=args.log_dir, |
|
padding_duration=1, |
|
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000, |
|
) |
|
) |
|
elif args.mode == "streaming": |
|
task = asyncio.create_task( |
|
send_streaming( |
|
manifest_item_list[i], |
|
name=f"task-{i}", |
|
server_url=url, |
|
protocol_client=protocol_client, |
|
log_interval=args.log_interval, |
|
model_name=args.model_name, |
|
audio_save_dir=args.log_dir, |
|
padding_duration=10, |
|
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000, |
|
chunk_overlap_duration=args.chunk_overlap_duration, |
|
) |
|
) |
|
|
|
tasks.append(task) |
|
|
|
ans_list = await asyncio.gather(*tasks) |
|
|
|
end_time = time.time() |
|
elapsed = end_time - start_time |
|
|
|
total_duration = 0.0 |
|
latency_data = [] |
|
for ans in ans_list: |
|
if ans: |
|
total_duration += ans[0] |
|
latency_data.extend(ans[1]) |
|
else: |
|
print("Warning: A task returned None, possibly due to an error.") |
|
|
|
|
|
if total_duration == 0: |
|
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.") |
|
rtf = float('inf') |
|
else: |
|
rtf = elapsed / total_duration |
|
|
|
s = f"Mode: {args.mode}\n" |
|
s += f"RTF: {rtf:.4f}\n" |
|
s += f"total_duration: {total_duration:.3f} seconds\n" |
|
s += f"({total_duration / 3600:.2f} hours)\n" |
|
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n" |
|
|
|
|
|
if latency_data: |
|
if args.mode == "offline": |
|
|
|
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] |
|
if latency_list: |
|
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 |
|
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0 |
|
s += f"latency_variance: {latency_variance:.2f}\n" |
|
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n" |
|
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n" |
|
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n" |
|
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n" |
|
s += f"average_latency_ms: {latency_ms:.2f}\n" |
|
else: |
|
s += "No latency data collected for offline mode.\n" |
|
|
|
elif args.mode == "streaming": |
|
|
|
total_latency_list = [total for (total, first, duration) in latency_data if total is not None] |
|
first_chunk_latency_list = [first for (total, first, duration) in latency_data if first is not None] |
|
|
|
s += "\n--- Total Request Latency ---\n" |
|
if total_latency_list: |
|
avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0 |
|
variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0 |
|
s += f"total_request_latency_variance: {variance_total_latency:.2f}\n" |
|
s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n" |
|
s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n" |
|
s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n" |
|
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n" |
|
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n" |
|
else: |
|
s += "No total request latency data collected.\n" |
|
|
|
s += "\n--- First Chunk Latency ---\n" |
|
if first_chunk_latency_list: |
|
avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0 |
|
variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0 |
|
s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n" |
|
s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n" |
|
s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n" |
|
s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n" |
|
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n" |
|
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n" |
|
else: |
|
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n" |
|
else: |
|
s += "No latency data collected.\n" |
|
|
|
|
|
print(s) |
|
if args.manifest_path: |
|
name = Path(args.manifest_path).stem |
|
elif args.split_name: |
|
name = args.split_name |
|
elif args.reference_audio: |
|
name = Path(args.reference_audio).stem |
|
else: |
|
name = "results" |
|
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: |
|
f.write(s) |
|
|
|
|
|
|
|
stats_client = None |
|
try: |
|
print("Initializing temporary async client for fetching stats...") |
|
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False) |
|
print("Fetching inference statistics...") |
|
|
|
stats = await stats_client.get_inference_statistics(model_name="", as_json=True) |
|
print("Fetching model config...") |
|
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True) |
|
|
|
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") |
|
|
|
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: |
|
json.dump(metadata, f, indent=4) |
|
|
|
except Exception as e: |
|
print(f"Could not retrieve statistics or config: {e}") |
|
finally: |
|
if stats_client: |
|
try: |
|
print("Closing temporary async stats client...") |
|
await stats_client.close() |
|
except Exception as e: |
|
print(f"Error closing async stats client: {e}") |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
async def run_main(): |
|
try: |
|
await main() |
|
except Exception as e: |
|
print(f"An error occurred in main: {e}") |
|
import traceback |
|
traceback.print_exc() |
|
|
|
asyncio.run(run_main()) |
|
|