File size: 4,337 Bytes
edc06cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""キャラクターに関係したリクエストにかかる時間の測定"""

import argparse
from pathlib import Path
from test.benchmark.engine_preparation import ServerType, generate_client
from test.benchmark.speed.utility import benchmark_time


def benchmark_get_speakers(server: ServerType, root_dir: Path | None = None) -> float:
    """`GET /speakers` にかかる時間を測定する。"""

    client = generate_client(server, root_dir)

    def execute() -> None:
        """計測対象となる処理を実行する"""
        client.get("/speakers", params={})

    average_time = benchmark_time(execute, n_repeat=10)
    return average_time


def benchmark_get_speaker_info_all(
    server: ServerType, root_dir: Path | None = None
) -> float:
    """全ての喋れるキャラクターへの `GET /speaker_info` にかかる時間を測定する。"""

    client = generate_client(server, root_dir)

    # speaker_uuid 一覧を準備
    response = client.get("/speakers", params={})
    assert response.status_code == 200
    talk_characters = response.json()
    uuids = list(map(lambda c: c["speaker_uuid"], talk_characters))

    def execute() -> None:
        """計測対象となる処理を実行する"""
        for uuid in uuids:
            client.get("/speaker_info", params={"speaker_uuid": uuid})

    average_time = benchmark_time(execute, n_repeat=10)
    return average_time


def benchmark_request_time_for_all_talk_characters(
    server: ServerType, root_dir: Path | None = None
) -> float:
    """
    喋れるキャラクターの数と同じ回数の `GET /` にかかる時間を測定する。
    `GET /` はエンジン内部処理が最小であるため、全ての喋れるキャラクター分のリクエスト-レスポンス(ネットワーク処理部分)にかかる時間を擬似的に計測できる。
    """

    client = generate_client(server, root_dir)

    # speaker_uuid 一覧を準備
    response = client.get("/speakers", params={})
    assert response.status_code == 200
    talk_characters = response.json()
    uuids = list(map(lambda c: c["speaker_uuid"], talk_characters))

    def execute() -> None:
        """計測対象となる処理を実行する"""
        for _ in uuids:
            client.get("/", params={})

    average_time = benchmark_time(execute, n_repeat=10)
    return average_time


if __name__ == "__main__":
    # 実行コマンドは `python -m test.benchmark.speed.speaker` である。
    # `server="localhost"` の場合、本ベンチマーク実行に先立ってエンジン起動が必要である。
    # エンジン起動コマンドの一例として以下を示す。
    # (別プロセスで)`python run.py --voicevox_dir=VOICEVOX/vv-engine`

    parser = argparse.ArgumentParser()
    parser.add_argument("--voicevox_dir", type=Path)
    args = parser.parse_args()
    root_dir: Path | None = args.voicevox_dir

    result_speakers_fakeserve = benchmark_get_speakers("fake", root_dir)
    result_speakers_localhost = benchmark_get_speakers("localhost", root_dir)
    print("`GET /speakers` fakeserve: {:.4f} sec".format(result_speakers_fakeserve))
    print("`GET /speakers` localhost: {:.4f} sec".format(result_speakers_localhost))

    _result_spk_infos_fakeserve = benchmark_get_speaker_info_all("fake", root_dir)
    _result_spk_infos_localhost = benchmark_get_speaker_info_all("localhost", root_dir)
    result_spk_infos_fakeserve = "{:.3f}".format(_result_spk_infos_fakeserve)
    result_spk_infos_localhost = "{:.3f}".format(_result_spk_infos_localhost)
    print(
        f"全ての喋れるキャラクター `GET /speaker_info` fakeserve: {result_spk_infos_fakeserve} sec"
    )
    print(
        f"全ての喋れるキャラクター `GET /speaker_info` localhost: {result_spk_infos_localhost} sec"
    )

    req_time_all_fake = benchmark_request_time_for_all_talk_characters("fake", root_dir)
    req_time_all_local = benchmark_request_time_for_all_talk_characters(
        "localhost", root_dir
    )
    print(
        "全ての喋れるキャラクター `GET /` fakeserve: {:.3f} sec".format(
            req_time_all_fake
        )
    )
    print(
        "全ての喋れるキャラクター `GET /` localhost: {:.3f} sec".format(
            req_time_all_local
        )
    )