File size: 3,903 Bytes
bf0e5e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
from pathlib import Path

import datetime
import json
import os
import sys
from typing import Optional
from huggingface_hub import hf_hub_download
def get_dataset_pytorch_model(local_dir):
	"""Download ``pytorch_model.bin`` for one model directory from the shared dataset repo.

	Args:
		local_dir: Relative model directory such as ``"bert/deberta-v3-large"``.
			It names both the folder inside the
			``kkvc-hf/Style-Bert-VITS2-Datasets`` repository and the folder on
			disk that should receive the weights.

	Returns:
		None. The file is written to ``<local_dir>/pytorch_model.bin``.
	"""
	os.makedirs(local_dir, exist_ok=True)
	# hf_hub_download recreates the repo-relative path of ``filename`` below its
	# ``local_dir`` argument. The filename already starts with ``local_dir``, so
	# the download root has to be the current directory; passing ``local_dir``
	# again would nest the file at <local_dir>/<local_dir>/pytorch_model.bin.
	hf_hub_download(
		"kkvc-hf/Style-Bert-VITS2-Datasets",
		f"{local_dir}/pytorch_model.bin",
		local_dir=".",
		repo_type="dataset",
	)

# Fetch the BERT / SLM weights from the shared dataset repository.
for local_dir in (
    "bert/chinese-roberta-wwm-ext-large",
    "bert/deberta-v2-large-japanese-char-wwm",
    "bert/deberta-v3-large",
    "slm/wavlm-base-plus",
):
    get_dataset_pytorch_model(local_dir)

# The Japanese DeBERTa weights are also fetched from a dedicated repository.
# There the file is requested by its bare name, so it is placed directly in
# local_dir.
local_dir = "bert/deberta-v2-large-japanese-char-wwm"
os.makedirs(local_dir, exist_ok=True)
hf_hub_download("kkvc-hf/Style-Bert-VITS2-bert_deberta-v2-large-japanese-char-wwm", "pytorch_model.bin", local_dir=local_dir, repo_type="dataset")

# Pretrained checkpoints, grouped by the directory they belong to.
for local_dir, weight_files in (
    ("pretrained", ("D_0.safetensors", "G_0.safetensors", "DUR_0.safetensors")),
    ("pretrained_jp_extra", ("D_0.safetensors", "G_0.safetensors", "WD_0.safetensors")),
):
    os.makedirs(local_dir, exist_ok=True)
    for weight_file in weight_files:
        # The repo-relative filename already contains the directory, and
        # hf_hub_download recreates that path under its ``local_dir`` argument.
        # Downloading relative to "." puts the file at <local_dir>/<weight_file>
        # instead of <local_dir>/<local_dir>/<weight_file>.
        hf_hub_download(
            "kkvc-hf/Style-Bert-VITS2-Datasets",
            f"{local_dir}/{weight_file}",
            local_dir=".",
            repo_type="dataset",
        )

import gradio as gr
import torch

from config import get_path_config
from gradio_tabs.dataset import create_dataset_app
from gradio_tabs.inference import create_inference_app
from gradio_tabs.merge import create_merge_app
from gradio_tabs.style_vectors import create_style_vectors_app
from gradio_tabs.train import create_train_app
from style_bert_vits2.constants import GRADIO_THEME, VERSION
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker
from style_bert_vits2.nlp.japanese.user_dict import update_dict
from style_bert_vits2.tts_model import TTSModelHolder


# Workers are started from this process so that the dictionary can be used,
# hence the initialization happens here.
pyopenjtalk_worker.initialize_worker()

# Apply the dictionary data found under dict_data/ to pyopenjtalk.
update_dict()


# Command-line options, declared as (flag, add_argument keyword arguments).
parser = argparse.ArgumentParser()
for _flag, _options in (
    ("--device", {"type": str, "default": "cuda"}),
    ("--host", {"type": str, "default": "127.0.0.1"}),
    ("--port", {"type": int, "default": None}),
    ("--no_autolaunch", {"action": "store_true"}),
    ("--share", {"action": "store_true"}),
    # A "--skip_default_models" store_true flag is currently disabled.
):
    parser.add_argument(_flag, **_options)

args = parser.parse_args()

# Use the CPU when CUDA was requested but is not available.
device = (
    "cpu"
    if args.device == "cuda" and not torch.cuda.is_available()
    else args.device
)

# The default-model download step (download_default_models, gated on
# --skip_default_models) is currently disabled.

path_config = get_path_config()
model_holder = TTSModelHolder(Path(path_config.assets_root), device)

# Build the WebUI: one tab per feature, each filled in by its own builder.
with gr.Blocks(theme=GRADIO_THEME) as app:
    gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {VERSION})")
    with gr.Tabs():
        for _tab_label, _build_tab in (
            ("音声合成", lambda: create_inference_app(model_holder=model_holder)),
            ("データセット作成", create_dataset_app),
            ("学習", create_train_app),
            ("スタイル作成", create_style_vectors_app),
            ("マージ", lambda: create_merge_app(model_holder=model_holder)),
        ):
            with gr.Tab(_tab_label):
                _build_tab()

# NOTE(review): --host and --port are parsed but not forwarded as
# server_name / server_port; presumably the hosting environment decides
# the bind address - confirm before re-enabling them.
_launch_options = {
    "inbrowser": not args.no_autolaunch,
    "share": args.share,
}
app.launch(**_launch_options)