unknown commited on
Commit
3f3743e
·
1 Parent(s): 6871802

add finetune miss

Browse files
Files changed (1) hide show
  1. finetune-cli.py +11 -6
finetune-cli.py CHANGED
@@ -28,6 +28,7 @@ def parse_args():
28
  parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
29
  parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
30
  parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
 
31
 
32
  return parser.parse_args()
33
 
@@ -42,17 +43,21 @@ def main():
42
  wandb_resume_id = None
43
  model_cls = DiT
44
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
45
- ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
 
46
  elif args.exp_name == "E2TTS_Base":
47
  wandb_resume_id = None
48
  model_cls = UNetT
49
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
50
- ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
 
 
 
 
 
 
 
51
 
52
- path_ckpt = os.path.join("ckpts",args.dataset_name)
53
- if os.path.isdir(path_ckpt)==False:
54
- os.makedirs(path_ckpt,exist_ok=True)
55
- shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
56
  checkpoint_path=os.path.join("ckpts",args.dataset_name)
57
 
58
  # Use the dataset_name provided in the command line
 
28
  parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
29
  parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
30
  parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
31
+ parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
32
 
33
  return parser.parse_args()
34
 
 
43
  wandb_resume_id = None
44
  model_cls = DiT
45
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
46
+ if args.finetune:
47
+ ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
48
  elif args.exp_name == "E2TTS_Base":
49
  wandb_resume_id = None
50
  model_cls = UNetT
51
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
52
+ if args.finetune:
53
+ ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
54
+
55
+ if args.finetune:
56
+ path_ckpt = os.path.join("ckpts",args.dataset_name)
57
+ if os.path.isdir(path_ckpt)==False:
58
+ os.makedirs(path_ckpt,exist_ok=True)
59
+ shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
60
 
 
 
 
 
61
  checkpoint_path=os.path.join("ckpts",args.dataset_name)
62
 
63
  # Use the dataset_name provided in the command line