Spaces:
Running
Running
| """ | |
| Shared (common) models: download, lazy loading and reference-counted release (放置公用模型). | |
| """ | |
| import gc | |
| import logging | |
| import os | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, MegatronBertModel | |
| from contants import config | |
| from utils.download import download_file | |
| from bert_vits2.text.chinese_bert import get_bert_feature as zh_bert | |
| from bert_vits2.text.english_bert_mock import get_bert_feature as en_bert | |
| from bert_vits2.text.japanese_bert import get_bert_feature as ja_bert | |
| from bert_vits2.text.japanese_bert_v111 import get_bert_feature as ja_bert_v111 | |
| from bert_vits2.text.japanese_bert_v200 import get_bert_feature as ja_bert_v200 | |
| from bert_vits2.text.english_bert_mock_v200 import get_bert_feature as en_bert_v200 | |
| from bert_vits2.text.chinese_bert_extra import get_bert_feature as zh_bert_extra | |
| from bert_vits2.text.japanese_bert_extra import get_bert_feature as ja_bert_extra | |
class ModelHandler:
    """Registry of the shared models (BERT variants, emotion, CLAP, G2PW, SSL).

    Model directories live under ``config.abs_path``/``config.system.data_path``.
    When loading a model fails, its weights file is fetched from the URLs in
    ``DOWNLOAD_PATHS`` (the matching ``SHA256`` entry is passed to
    ``download_file``) and loading is retried. BERT, emotion, CLAP and SSL
    models are reference-counted: each ``load_*`` call adds a reference, each
    ``release_*`` call drops one, and the model is freed when the count
    reaches zero.
    """

    def __init__(self, device=config.system.device):
        """Build the download/checksum/path tables and empty model slots.

        Args:
            device: torch device models are moved to. The default is read from
                ``config.system.device`` once, when the class is defined.
        """
        self.DOWNLOAD_PATHS = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": [
                "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
            ],
            "BERT_BASE_JAPANESE_V3": [
                "https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
            ],
            "BERT_LARGE_JAPANESE_V2": [
                "https://huggingface.co/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/cl-tohoku/bert-large-japanese-v2/resolve/main/pytorch_model.bin",
            ],
            "DEBERTA_V2_LARGE_JAPANESE": [
                "https://huggingface.co/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese/resolve/main/pytorch_model.bin",
            ],
            "DEBERTA_V3_LARGE": [
                "https://huggingface.co/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/pytorch_model.bin",
            ],
            "SPM": [
                "https://huggingface.co/microsoft/deberta-v3-large/resolve/main/spm.model",
                "https://hf-mirror.com/microsoft/deberta-v3-large/resolve/main/spm.model",
            ],
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": [
                "https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/ku-nlp/deberta-v2-large-japanese-char-wwm/resolve/main/pytorch_model.bin",
            ],
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": [
                "https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim/resolve/main/pytorch_model.bin",
            ],
            "CLAP_HTSAT_FUSED": [
                "https://huggingface.co/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
                "https://hf-mirror.com/laion/clap-htsat-fused/resolve/main/pytorch_model.bin?download=true",
            ],
            "Erlangshen_MegatronBert_1.3B_Chinese": [
                "https://huggingface.co/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/IDEA-CCNL/Erlangshen-UniMC-MegatronBERT-1.3B-Chinese/resolve/main/pytorch_model.bin",
            ],
            "G2PWModel": [
                # "https://storage.googleapis.com/esun-ai/g2pW/G2PWModel-v2-onnx.zip",
                "https://huggingface.co/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
                "https://hf-mirror.com/ADT109119/G2PWModel-v2-onnx/resolve/main/g2pw.onnx",
            ],
            "CHINESE_HUBERT_BASE": [
                "https://huggingface.co/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
                "https://hf-mirror.com/TencentGameMate/chinese-hubert-base/resolve/main/pytorch_model.bin",
            ],
        }
        # Expected checksums passed to download_file; keys mirror DOWNLOAD_PATHS.
        self.SHA256 = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": "4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd",
            "BERT_BASE_JAPANESE_V3": "e172862e0674054d65e0ba40d67df2a4687982f589db44aa27091c386e5450a4",
            "BERT_LARGE_JAPANESE_V2": "50212d714f79af45d3e47205faa356d0e5030e1c9a37138eadda544180f9e7c9",
            "DEBERTA_V2_LARGE_JAPANESE": "a6c15feac0dea77ab8835c70e1befa4cf4c2137862c6fb2443b1553f70840047",
            "DEBERTA_V3_LARGE": "dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5",
            "SPM": "c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd",
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": "bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201",
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": "176d9d1ce29a8bddbab44068b9c1c194c51624c7f1812905e01355da58b18816",
            "CLAP_HTSAT_FUSED": "1ed5d0215d887551ddd0a49ce7311b21429ebdf1e6a129d4e68f743357225253",
            "Erlangshen_MegatronBert_1.3B_Chinese": "3456bb8f2c7157985688a4cb5cecdb9e229cb1dcf785b01545c611462ffe3579",
            # "G2PWModel": "bb40c8c7b5baa755b2acd317c6bc5a65e4af7b80c40a569247fbd76989299999",
            "G2PWModel": "",  # NOTE(review): empty checksum — confirm how download_file treats it
            "CHINESE_HUBERT_BASE": "2fefccd26c2794a583b80f6f7210c721873cb7ebae2c1cde3baf9b27855e24d8",
        }

        def _data_path(relative_path):
            # Resolve a model directory configured relative to the data directory.
            return os.path.join(config.abs_path, config.system.data_path, relative_path)

        model_config = config.model_config
        self.model_path = {
            "CHINESE_ROBERTA_WWM_EXT_LARGE": _data_path(model_config.chinese_roberta_wwm_ext_large),
            "BERT_BASE_JAPANESE_V3": _data_path(model_config.bert_base_japanese_v3),
            "BERT_LARGE_JAPANESE_V2": _data_path(model_config.bert_large_japanese_v2),
            "DEBERTA_V2_LARGE_JAPANESE": _data_path(model_config.deberta_v2_large_japanese),
            "DEBERTA_V3_LARGE": _data_path(model_config.deberta_v3_large),
            "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM": _data_path(model_config.deberta_v2_large_japanese_char_wwm),
            "WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM": _data_path(
                model_config.wav2vec2_large_robust_12_ft_emotion_msp_dim),
            "CLAP_HTSAT_FUSED": _data_path(model_config.clap_htsat_fused),
            "Erlangshen_MegatronBert_1.3B_Chinese": _data_path(model_config.erlangshen_MegatronBert_1_3B_Chinese),
            "G2PWModel": _data_path(model_config.g2pw_model),
            "CHINESE_HUBERT_BASE": _data_path(model_config.chinese_hubert_base),
        }
        # Language tag -> feature-extraction function used by get_bert_feature.
        self.lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert, "ja_v111": ja_bert_v111,
                                   "ja_v200": ja_bert_v200, "en_v200": en_bert_v200, "zh_extra": zh_bert_extra,
                                   "ja_extra": ja_bert_extra}
        self.bert_models = {}  # Value: (tokenizer, model, reference_count)
        # The following slots hold None until loaded. When loaded, each is a dict
        # of resources plus "reference_count".
        self.emotion = None
        self.clap = None
        self.pinyinPlus = None
        self.device = device
        self.ssl_model = None
        # None means no explicit dtype is passed on to from_pretrained.
        if config.bert_vits2_config.torch_data_type.lower() in ["float16", "fp16"]:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = None

    def emotion_model(self):
        """Return the emotion model; requires a successful ``load_emotion`` first."""
        return self.emotion["model"]

    def emotion_processor(self):
        """Return the Wav2Vec2 processor paired with the emotion model."""
        return self.emotion["processor"]

    def clap_model(self):
        """Return the CLAP model; requires a successful ``load_clap`` first."""
        return self.clap["model"]

    def clap_processor(self):
        """Return the CLAP processor paired with the CLAP model."""
        return self.clap["processor"]

    def _download_model(self, model_name, target_path=None):
        """Download ``model_name``'s file via ``download_file`` and log the outcome.

        Args:
            model_name: key into ``DOWNLOAD_PATHS`` and ``SHA256``.
            target_path: destination file; defaults to ``pytorch_model.bin``
                inside the model's directory from ``model_path``.
        """
        urls = self.DOWNLOAD_PATHS[model_name]
        if target_path is None:
            target_path = os.path.join(self.model_path[model_name], "pytorch_model.bin")
        expected_sha256 = self.SHA256[model_name]
        success, message = download_file(urls, target_path, expected_sha256=expected_sha256)
        if not success:
            logging.error(f"Failed to download {model_name}: {message}")
        else:
            logging.info(f"{message}")

    def _load_with_retries(self, model_name, loader, max_retries, label=None, before_download=None):
        """Call ``loader(model_path)`` up to ``max_retries`` times, downloading after each failure.

        Args:
            model_name: key into ``model_path`` / ``DOWNLOAD_PATHS`` / ``SHA256``.
            loader: callable taking the model directory and returning the loaded resource.
            max_retries: maximum number of load attempts.
            label: name shown in the "Loading ..." log line; defaults to ``model_name``.
            before_download: optional callable invoked with the model directory after a
                failed attempt, before the main weights file is downloaded.

        Returns:
            The value returned by ``loader`` on the first successful attempt, or
            ``None`` when every attempt failed. Failures are logged, not raised.
        """
        model_path = self.model_path[model_name]
        label = model_name if label is None else label
        for _ in range(max_retries):
            logging.info(f"Loading {label}: {model_path}")
            try:
                loaded = loader(model_path)
            except Exception as e:  # best effort: any load failure triggers a (re-)download
                logging.error(f"Failed loading {model_path}. {e}")
                if before_download is not None:
                    before_download(model_path)
                self._download_model(model_name)
            else:
                logging.info(f"Success loading: {model_path}")
                return loaded
        # NOTE(review): the download that follows the last failed attempt is never
        # followed by another load attempt; consider one extra attempt.
        logging.error(f"Failed to load {model_path} after {max_retries} retries.")
        return None

    def load_bert(self, bert_model_name, max_retries=3):
        """Load a BERT tokenizer/model pair, or add a reference if it is already loaded.

        Args:
            bert_model_name: key into ``model_path``. "Erlangshen_MegatronBert_1.3B_Chinese"
                uses ``BertTokenizer``/``MegatronBertModel``; every other name uses
                ``AutoTokenizer``/``AutoModelForMaskedLM``.
            max_retries: maximum number of load attempts (a download follows each failure).

        On total failure nothing is stored in ``bert_models``; the failure is only logged.
        """
        if bert_model_name in self.bert_models:
            tokenizer, model, count = self.bert_models[bert_model_name]
            self.bert_models[bert_model_name] = (tokenizer, model, count + 1)
            return

        def _load(model_path):
            if bert_model_name == "Erlangshen_MegatronBert_1.3B_Chinese":
                tokenizer_cls, model_cls = BertTokenizer, MegatronBertModel
            else:
                tokenizer_cls, model_cls = AutoTokenizer, AutoModelForMaskedLM
            tokenizer = tokenizer_cls.from_pretrained(model_path, torch_dtype=self.torch_dtype)
            model = model_cls.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(self.device)
            return tokenizer, model

        def _before_download(model_path):
            logging.info("Trying to download.")
            # DeBERTa-v3 additionally needs spm.model next to its weights.
            spm_path = os.path.join(model_path, "spm.model")
            if bert_model_name == "DEBERTA_V3_LARGE" and not os.path.exists(spm_path):
                self._download_model("SPM", spm_path)

        loaded = self._load_with_retries(bert_model_name, _load, max_retries,
                                         label="BERT model", before_download=_before_download)
        if loaded is not None:
            tokenizer, model = loaded
            self.bert_models[bert_model_name] = (tokenizer, model, 1)  # first user holds reference 1

    def load_emotion(self, max_retries=3):
        """Load the Bert-VITS2 v2.1 EmotionModel (+ Wav2Vec2 processor), or add a reference.

        ``self.emotion`` is set only after the model and processor have both loaded,
        so a failed load leaves it ``None`` rather than a partial dict.
        """
        if self.emotion is not None:
            self.emotion["reference_count"] += 1
            return
        from transformers import Wav2Vec2Processor
        from bert_vits2.get_emo import EmotionModel

        def _load(model_path):
            return {
                "model": EmotionModel.from_pretrained(model_path).to(self.device),
                "processor": Wav2Vec2Processor.from_pretrained(model_path),
                "reference_count": 1,
            }

        self.emotion = self._load_with_retries("WAV2VEC2_LARGE_ROBUST_12_FT_EMOTION_MSP_DIM", _load, max_retries)

    def release_emotion(self):
        """Drop one reference to the emotion model; free it when no references remain."""
        if self.emotion is not None:
            self.emotion["reference_count"] -= 1
            if self.emotion["reference_count"] <= 0:
                self.emotion = None
                gc.collect()
                torch.cuda.empty_cache()
                logging.info(f"Emotion model has been released.")

    def load_clap(self, max_retries=3):
        """Load the Bert-VITS2 v2.2 ClapModel (+ ClapProcessor), or add a reference.

        ``self.clap`` is set only after both parts have loaded, so a failed load
        leaves it ``None`` rather than a partial dict.
        """
        if self.clap is not None:
            self.clap["reference_count"] += 1
            return
        from transformers import ClapModel, ClapProcessor

        def _load(model_path):
            return {
                "model": ClapModel.from_pretrained(model_path, torch_dtype=self.torch_dtype).to(self.device),
                "processor": ClapProcessor.from_pretrained(model_path, torch_dtype=self.torch_dtype),
                "reference_count": 1,
            }

        self.clap = self._load_with_retries("CLAP_HTSAT_FUSED", _load, max_retries)

    def release_clap(self):
        """Drop one reference to the CLAP model; free it when no references remain."""
        if self.clap is not None:
            self.clap["reference_count"] -= 1
            if self.clap["reference_count"] <= 0:
                self.clap = None
                gc.collect()
                torch.cuda.empty_cache()
                logging.info(f"Clap model has been released.")

    def get_bert_model(self, bert_model_name):
        """Return ``(tokenizer, model)`` for ``bert_model_name``, loading it on first use.

        Raises:
            KeyError: if the model could not be loaded.
        """
        if bert_model_name not in self.bert_models:
            self.load_bert(bert_model_name)
        tokenizer, model, _ = self.bert_models[bert_model_name]
        return tokenizer, model

    def get_bert_feature(self, norm_text, word2ph, language, bert_model_name, style_text=None, style_weight=0.7):
        """Compute BERT features with the language-specific function from ``lang_bert_func_map``.

        Args:
            norm_text: normalized input text.
            word2ph: passed through unchanged to the language function.
            language: key into ``lang_bert_func_map`` (e.g. "zh", "ja_v200").
            bert_model_name: which loaded BERT model/tokenizer to use.
            style_text, style_weight: forwarded to the language function.
        """
        tokenizer, model = self.get_bert_model(bert_model_name)
        bert_feature = self.lang_bert_func_map[language](norm_text, word2ph, tokenizer, model, self.device,
                                                         style_text=style_text, style_weight=style_weight)
        return bert_feature

    def get_pinyinPlus(self):
        """Return the G2PW pinyin converter, creating it on first call (not reference-counted)."""
        if self.pinyinPlus is None:
            from bert_vits2.g2pW.pypinyin_G2pW_bv2 import G2PWPinyin
            logging.info(f"Loading G2PWModel: {self.model_path['G2PWModel']}")
            self.pinyinPlus = G2PWPinyin(
                model_dir=self.model_path["G2PWModel"],
                model_source=self.model_path["Erlangshen_MegatronBert_1.3B_Chinese"],
                v_to_u=False,
                neutral_tone_with_five=True,
            )
            logging.info("Success loading G2PWModel")
        return self.pinyinPlus

    def release_bert(self, bert_model_name):
        """Drop one reference to a BERT model; free it when no references remain."""
        if bert_model_name in self.bert_models:
            tokenizer, model, count = self.bert_models[bert_model_name]
            count -= 1
            if count <= 0:
                # Reference count reached zero: drop the model and free its memory.
                del self.bert_models[bert_model_name]
                gc.collect()
                torch.cuda.empty_cache()
                logging.info(f"BERT model {bert_model_name} has been released.")
            else:
                self.bert_models[bert_model_name] = (tokenizer, model, count)

    def load_ssl(self, max_retries=3):
        """Load the GPT-SoVITS CNHubert SSL model, or add a reference if already loaded.

        When ``config.gpt_sovits_config.is_half`` is set, the model is converted with
        ``.half()`` before being moved to ``self.device``. A failed load leaves
        ``self.ssl_model`` as ``None``.
        """
        if self.ssl_model is not None:
            self.ssl_model["reference_count"] += 1
            return

        def _load(model_path):
            # Imported inside the attempt so an import failure is logged/retried like any load error.
            from gpt_sovits.feature_extractor.cnhubert import CNHubert
            model = CNHubert(model_path)
            model.eval()
            if config.gpt_sovits_config.is_half:
                model = model.half()
            model = model.to(self.device)
            return {"model": model, "reference_count": 1}

        self.ssl_model = self._load_with_retries("CHINESE_HUBERT_BASE", _load, max_retries)

    def get_ssl_model(self):
        """Return the SSL model, loading it on first use; ``None`` if loading failed."""
        if self.ssl_model is None:
            self.load_ssl()
        if self.ssl_model is None:
            return None
        return self.ssl_model.get("model")

    def release_ssl_model(self):
        """Drop one reference to the SSL model; free it when no references remain."""
        if self.ssl_model is not None:
            self.ssl_model["reference_count"] -= 1
            if self.ssl_model["reference_count"] <= 0:
                self.ssl_model = None
                gc.collect()
                torch.cuda.empty_cache()
                logging.info(f"SSL model has been released.")

    def is_model_loaded(self, bert_model_name):
        """Return True if ``bert_model_name`` is currently held in ``bert_models``."""
        return bert_model_name in self.bert_models

    def reference_count(self, bert_model_name):
        """Return the reference count of a BERT model (0 if it is not loaded)."""
        return self.bert_models[bert_model_name][2] if bert_model_name in self.bert_models else 0