Spaces:
Runtime error
Runtime error
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
src/f5_tts/train/finetune_cli.py
CHANGED
|
@@ -6,6 +6,7 @@ from cached_path import cached_path
|
|
| 6 |
from f5_tts.model import CFM, UNetT, DiT, Trainer
|
| 7 |
from f5_tts.model.utils import get_tokenizer
|
| 8 |
from f5_tts.model.dataset import load_dataset
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
# -------------------------- Dataset Settings --------------------------- #
|
|
@@ -63,6 +64,7 @@ def parse_args():
|
|
| 63 |
|
| 64 |
def main():
|
| 65 |
args = parse_args()
|
|
|
|
| 66 |
|
| 67 |
# Model parameters based on experiment name
|
| 68 |
if args.exp_name == "F5TTS_Base":
|
|
@@ -85,12 +87,9 @@ def main():
|
|
| 85 |
ckpt_path = args.pretrain
|
| 86 |
|
| 87 |
if args.finetune:
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
os.
|
| 91 |
-
shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
|
| 92 |
-
|
| 93 |
-
checkpoint_path = os.path.join("ckpts", args.dataset_name)
|
| 94 |
|
| 95 |
# Use the tokenizer and tokenizer_path provided in the command line arguments
|
| 96 |
tokenizer = args.tokenizer
|
|
|
|
| 6 |
from f5_tts.model import CFM, UNetT, DiT, Trainer
|
| 7 |
from f5_tts.model.utils import get_tokenizer
|
| 8 |
from f5_tts.model.dataset import load_dataset
|
| 9 |
+
from importlib.resources import files
|
| 10 |
|
| 11 |
|
| 12 |
# -------------------------- Dataset Settings --------------------------- #
|
|
|
|
| 64 |
|
| 65 |
def main():
|
| 66 |
args = parse_args()
|
| 67 |
+
checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
|
| 68 |
|
| 69 |
# Model parameters based on experiment name
|
| 70 |
if args.exp_name == "F5TTS_Base":
|
|
|
|
| 87 |
ckpt_path = args.pretrain
|
| 88 |
|
| 89 |
if args.finetune:
|
| 90 |
+
if not os.path.isdir(checkpoint_path):
|
| 91 |
+
os.makedirs(checkpoint_path, exist_ok=True)
|
| 92 |
+
shutil.copy2(ckpt_path, os.path.join(checkpoint_path, os.path.basename(ckpt_path)))
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# Use the tokenizer and tokenizer_path provided in the command line arguments
|
| 95 |
tokenizer = args.tokenizer
|