hcsolakoglu
commited on
Commit
·
728bcf7
1
Parent(s):
437353c
Add --bnb_optimizer argument to CLI and pass it to Trainer initialization
Browse filesAdd `--bnb_optimizer` argument to CLI and pass it to Trainer initialization.
* Add `--bnb_optimizer` argument to `parse_args()` function in `src/f5_tts/train/finetune_cli.py`.
* Pass `bnb_optimizer` argument to `Trainer` initialization in the `main()` function of `src/f5_tts/train/finetune_cli.py`.
src/f5_tts/train/finetune_cli.py
CHANGED
|
@@ -55,7 +55,6 @@ def parse_args():
|
|
| 55 |
default=None,
|
| 56 |
help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
|
| 57 |
)
|
| 58 |
-
|
| 59 |
parser.add_argument(
|
| 60 |
"--log_samples",
|
| 61 |
type=bool,
|
|
@@ -63,6 +62,12 @@ def parse_args():
|
|
| 63 |
help="Log inferenced samples per ckpt save steps",
|
| 64 |
)
|
| 65 |
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
return parser.parse_args()
|
| 68 |
|
|
@@ -147,6 +152,7 @@ def main():
|
|
| 147 |
wandb_resume_id=wandb_resume_id,
|
| 148 |
log_samples=args.log_samples,
|
| 149 |
last_per_steps=args.last_per_steps,
|
|
|
|
| 150 |
)
|
| 151 |
|
| 152 |
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
|
|
|
| 55 |
default=None,
|
| 56 |
help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
|
| 57 |
)
|
|
|
|
| 58 |
parser.add_argument(
|
| 59 |
"--log_samples",
|
| 60 |
type=bool,
|
|
|
|
| 62 |
help="Log inferenced samples per ckpt save steps",
|
| 63 |
)
|
| 64 |
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--bnb_optimizer",
|
| 67 |
+
type=bool,
|
| 68 |
+
default=False,
|
| 69 |
+
help="Use 8-bit Adam optimizer from bitsandbytes"
|
| 70 |
+
)
|
| 71 |
|
| 72 |
return parser.parse_args()
|
| 73 |
|
|
|
|
| 152 |
wandb_resume_id=wandb_resume_id,
|
| 153 |
log_samples=args.log_samples,
|
| 154 |
last_per_steps=args.last_per_steps,
|
| 155 |
+
bnb_optimizer=args.bnb_optimizer,
|
| 156 |
)
|
| 157 |
|
| 158 |
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|