File size: 12,903 Bytes
0b85fb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
"""
Style-Bert-VITS2 の学習・推論に必要な各言語ごとの BERT モデルをロード/取得するためのモジュール。

オリジナルの Bert-VITS2 では各言語ごとの BERT モデルが初回インポート時にハードコードされたパスから「暗黙的に」ロードされているが、
場合によっては多重にロードされて非効率なほか、BERT モデルのロード元のパスがハードコードされているためライブラリ化ができない。

そこで、ライブラリの利用前に、音声合成に利用する言語の BERT モデルだけを「明示的に」ロードできるようにした。
一度 load_model/tokenizer() で当該言語の BERT モデルがロードされていれば、ライブラリ内部のどこからでもロード済みのモデル/トークナイザーを取得できる。
"""

from __future__ import annotations

import gc
import time
from typing import TYPE_CHECKING, cast

from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DebertaV2Model,
    DebertaV2TokenizerFast,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

from style_bert_vits2.constants import DEFAULT_BERT_MODEL_PATHS, Languages
from style_bert_vits2.logging import logger
from style_bert_vits2.nlp import onnx_bert_models


if TYPE_CHECKING:
    import torch


# 各言語ごとのロード済みの BERT モデルを格納する辞書
__loaded_models: dict[Languages, PreTrainedModel | DebertaV2Model] = {}

# 各言語ごとのロード済みの BERT トークナイザーを格納する辞書
__loaded_tokenizers: dict[
    Languages,
    PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2TokenizerFast,
] = {}


def load_model(
    language: Languages,
    pretrained_model_name_or_path: str | None = None,
    device_map: str
    | dict[str, int | str | torch.device]
    | int
    | torch.device
    | None = None,
    cache_dir: str | None = None,
    revision: str = "main",
) -> PreTrainedModel | DebertaV2Model:
    """
    指定された言語の BERT モデルをロードし、ロード済みの BERT モデルを返す。
    一度ロードされていれば、ロード済みの BERT モデルを即座に返す。
    ライブラリ利用時は常に必ず pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある。
    ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき。
    device_map は既に指定された言語の BERT モデルがロードされている場合は効果がない。
    cache_dir と revision は pretrain_model_name_or_path がリポジトリ名の場合のみ有効。

    Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている。
    これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い。
    - 日本語: ku-nlp/deberta-v2-large-japanese-char-wwm
    - 英語: microsoft/deberta-v3-large
    - 中国語: hfl/chinese-roberta-wwm-ext-large

    Args:
        language (Languages): ロードする学習済みモデルの対象言語
        pretrained_model_name_or_path (str | None): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None)
        device_map (str | None): accelerate を使用して高速にデバイスにモデルをロードするためのデバイスマップ。
            指定しない場合は通常のモデルロード処理になる (デフォルト: None)
            ref: https://huggingface.co/docs/accelerate/usage_guides/big_modeling
        cache_dir (str | None): モデルのキャッシュディレクトリ。指定しない場合はデフォルトのキャッシュディレクトリが利用される (デフォルト: None)
        revision (str): モデルの Hugging Face 上の Git リビジョン。指定しない場合は最新の main ブランチの内容が利用される (デフォルト: None)

    Returns:
        PreTrainedModel | DebertaV2Model: ロード済みの BERT モデル
    """

    # すでにロード済みの場合はそのまま返す
    if language in __loaded_models:
        return __loaded_models[language]

    # pretrained_model_name_or_path が指定されていない場合はデフォルトのパスを利用
    if pretrained_model_name_or_path is None:
        assert DEFAULT_BERT_MODEL_PATHS[language].exists(), \
            f"The default {language.name} BERT model does not exist on the file system. Please specify the path to the pre-trained model."  # fmt: skip
        pretrained_model_name_or_path = str(DEFAULT_BERT_MODEL_PATHS[language])

    # BERT モデルをロードし、辞書に格納して返す
    ## 英語のみ DebertaV2Model でロードする必要がある
    start_time = time.time()
    if language == Languages.EN:
        __loaded_models[language] = cast(
            DebertaV2Model,
            DebertaV2Model.from_pretrained(
                pretrained_model_name_or_path,
                device_map=device_map,
                cache_dir=cache_dir,
                revision=revision,
            ),
        )
    else:
        __loaded_models[language] = AutoModelForMaskedLM.from_pretrained(
            pretrained_model_name_or_path,
            device_map=device_map,
            cache_dir=cache_dir,
            revision=revision,
        )
    logger.info(
        f"Loaded the {language.name} BERT model from {pretrained_model_name_or_path} ({time.time() - start_time:.2f}s)"
    )

    return __loaded_models[language]


def load_tokenizer(
    language: Languages,
    pretrained_model_name_or_path: str | None = None,
    cache_dir: str | None = None,
    revision: str = "main",
) -> PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2TokenizerFast:
    """
    指定された言語の BERT トークナイザーをロードし、ロード済みの BERT トークナイザーを返す。
    一度ロードされていれば、ロード済みの BERT トークナイザーを即座に返す。
    ライブラリ利用時は常に必ず pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある。
    ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき。
    cache_dir と revision は pretrain_model_name_or_path がリポジトリ名の場合のみ有効。

    Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている。
    これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い。
    - 日本語: ku-nlp/deberta-v2-large-japanese-char-wwm
    - 英語: microsoft/deberta-v3-large
    - 中国語: hfl/chinese-roberta-wwm-ext-large

    Args:
        language (Languages): ロードする学習済みモデルの対象言語
        pretrained_model_name_or_path (str | None): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None)
        cache_dir (str | None): モデルのキャッシュディレクトリ。指定しない場合はデフォルトのキャッシュディレクトリが利用される (デフォルト: None)
        revision (str): モデルの Hugging Face 上の Git リビジョン。指定しない場合は最新の main ブランチの内容が利用される (デフォルト: None)

    Returns:
        PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2TokenizerFast: ロード済みの BERT トークナイザー
    """

    # すでにロード済みの場合はそのまま返す
    if language in __loaded_tokenizers:
        return __loaded_tokenizers[language]

    # pretrained_model_name_or_path が指定されていない場合はデフォルトのパスを利用
    if pretrained_model_name_or_path is None:
        # ライブラリ利用時、特例的にこの状況で ONNX 版 BERT トークナイザーがロードされている場合はそのまま返す
        ## ONNX 版 BERT トークナイザー単独で g2p 処理を行うために必要 (各言語の g2p.py はこの関数に依存している)
        ## 設計的には微妙だがこの方が差異を吸収できて手っ取り早い
        if DEFAULT_BERT_MODEL_PATHS[language].exists() is False and onnx_bert_models.is_tokenizer_loaded(language):  # fmt: skip
            return onnx_bert_models.load_tokenizer(language)
        assert DEFAULT_BERT_MODEL_PATHS[language].exists(), \
            f"The default {language.name} BERT tokenizer does not exist on the file system. Please specify the path to the pre-trained model."  # fmt: skip
        pretrained_model_name_or_path = str(DEFAULT_BERT_MODEL_PATHS[language])

    # BERT トークナイザーをロードし、辞書に格納して返す
    ## 英語のみ DebertaV2TokenizerFast でロードする必要がある
    if language == Languages.EN:
        __loaded_tokenizers[language] = DebertaV2TokenizerFast.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            revision=revision,
        )
    else:
        __loaded_tokenizers[language] = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            revision=revision,
            use_fast=True,  # デフォルトで True だが念のため明示的に指定
        )
    logger.info(
        f"Loaded the {language.name} BERT tokenizer from {pretrained_model_name_or_path}"
    )

    return __loaded_tokenizers[language]


def transfer_model(language: Languages, device: str) -> None:
    """
    指定された言語の BERT モデルを、指定されたデバイスに移動する。
    モデルのロード後に推論デバイスを変更したい場合に利用する。
    既に指定されたデバイスにモデルがロードされている場合は何も行われない。

    Args:
        language (Languages): モデルを移動する言語
        device (str): モデルを移動するデバイス
    """

    if language not in __loaded_models:
        raise ValueError(f"BERT model for {language.name} is not loaded.")

    # 既に指定されたデバイスにモデルがロードされている場合は何もしない
    # ex: current_device="cuda:0", device="cuda" → 何もしない
    # ex: current_device="cuda:0", device="cpu" → モデルを CPU に移動
    current_device = str(__loaded_models[language].device)
    if current_device.startswith(device):
        return

    __loaded_models[language].to(device)  # type: ignore
    logger.info(
        f"Transferred the {language.name} BERT model from {current_device} to {device}"
    )


def is_model_loaded(language: Languages) -> bool:
    """
    指定された言語の BERT モデルがロード済みかどうかを返す。
    """

    return language in __loaded_models


def is_tokenizer_loaded(language: Languages) -> bool:
    """
    指定された言語の BERT トークナイザーがロード済みかどうかを返す。
    """

    return language in __loaded_tokenizers


def unload_model(language: Languages) -> None:
    """
    指定された言語の BERT モデルをアンロードする。

    Args:
        language (Languages): アンロードする BERT モデルの言語
    """

    import torch

    if language in __loaded_models:
        del __loaded_models[language]
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.info(f"Unloaded the {language.name} BERT model")


def unload_tokenizer(language: Languages) -> None:
    """
    指定された言語の BERT トークナイザーをアンロードする。

    Args:
        language (Languages): アンロードする BERT トークナイザーの言語
    """

    if language in __loaded_tokenizers:
        del __loaded_tokenizers[language]
        gc.collect()
        logger.info(f"Unloaded the {language.name} BERT tokenizer")


def unload_all_models() -> None:
    """
    すべての BERT モデルをアンロードする。
    """

    for language in list(__loaded_models.keys()):
        unload_model(language)
    logger.info("Unloaded all BERT models")


def unload_all_tokenizers() -> None:
    """
    すべての BERT トークナイザーをアンロードする。
    """

    for language in list(__loaded_tokenizers.keys()):
        unload_tokenizer(language)
    logger.info("Unloaded all BERT tokenizers")