Spaces:
Configuration error
Configuration error
unknown
commited on
Commit
·
3f3743e
1
Parent(s):
6871802
add finetune miss
Browse files- finetune-cli.py +11 -6
finetune-cli.py
CHANGED
@@ -28,6 +28,7 @@ def parse_args():
|
|
28 |
parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
|
29 |
parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
|
30 |
parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
|
|
|
31 |
|
32 |
return parser.parse_args()
|
33 |
|
@@ -42,17 +43,21 @@ def main():
|
|
42 |
wandb_resume_id = None
|
43 |
model_cls = DiT
|
44 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
45 |
-
|
|
|
46 |
elif args.exp_name == "E2TTS_Base":
|
47 |
wandb_resume_id = None
|
48 |
model_cls = UNetT
|
49 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
path_ckpt = os.path.join("ckpts",args.dataset_name)
|
53 |
-
if os.path.isdir(path_ckpt)==False:
|
54 |
-
os.makedirs(path_ckpt,exist_ok=True)
|
55 |
-
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
|
56 |
checkpoint_path=os.path.join("ckpts",args.dataset_name)
|
57 |
|
58 |
# Use the dataset_name provided in the command line
|
|
|
28 |
parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
|
29 |
parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
|
30 |
parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
|
31 |
+
parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
|
32 |
|
33 |
return parser.parse_args()
|
34 |
|
|
|
43 |
wandb_resume_id = None
|
44 |
model_cls = DiT
|
45 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
46 |
+
if args.finetune:
|
47 |
+
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
48 |
elif args.exp_name == "E2TTS_Base":
|
49 |
wandb_resume_id = None
|
50 |
model_cls = UNetT
|
51 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
52 |
+
if args.finetune:
|
53 |
+
ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
54 |
+
|
55 |
+
if args.finetune:
|
56 |
+
path_ckpt = os.path.join("ckpts",args.dataset_name)
|
57 |
+
if os.path.isdir(path_ckpt)==False:
|
58 |
+
os.makedirs(path_ckpt,exist_ok=True)
|
59 |
+
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
|
60 |
|
|
|
|
|
|
|
|
|
61 |
checkpoint_path=os.path.join("ckpts",args.dataset_name)
|
62 |
|
63 |
# Use the dataset_name provided in the command line
|