Hasan Can Solakoğlu
committed
Commit 2d70bfe · 1 Parent(s): b09897b
Keep Last N Checkpoints (#718)
* Add checkpoint management feature
- Introduced `keep_last_n_checkpoints` parameter in configuration and training scripts to manage the number of recent checkpoints retained.
- Updated `finetune_cli.py`, `finetune_gradio.py`, and `trainer.py` to support this new parameter.
- Implemented logic to remove older checkpoints beyond the specified limit during training.
- Adjusted settings loading and saving to include the new checkpoint management option.
This enhancement improves the training process by preventing excessive storage usage from old checkpoints.
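For quick reference, the retention policy can be summarized with the following minimal sketch. This helper is illustrative only and not part of the commit; it assumes intermediate checkpoints are named model_<update>.pt, as in trainer.py below.

def checkpoints_to_delete(filenames: list[str], keep_last_n: int) -> list[str]:
    """Return the intermediate checkpoints that would be pruned, oldest first."""
    # -1 keeps everything; 0 never writes intermediate checkpoints, so there is nothing to prune.
    if keep_last_n <= 0:
        return []
    intermediates = sorted(
        (f for f in filenames if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"),
        key=lambda f: int(f.split("_")[1].split(".")[0]),  # order by update number
    )
    return intermediates[:-keep_last_n] if len(intermediates) > keep_last_n else []

print(checkpoints_to_delete(["model_1000.pt", "model_2000.pt", "model_3000.pt", "model_last.pt"], keep_last_n=2))
# ['model_1000.pt']  # only the two most recent intermediate checkpoints survive; model_last.pt is never pruned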
- src/f5_tts/configs/E2TTS_Base_train.yaml +1 -0
- src/f5_tts/configs/E2TTS_Small_train.yaml +1 -0
- src/f5_tts/configs/F5TTS_Base_train.yaml +1 -0
- src/f5_tts/configs/F5TTS_Small_train.yaml +1 -0
- src/f5_tts/model/trainer.py +31 -0
- src/f5_tts/train/finetune_cli.py +7 -0
- src/f5_tts/train/finetune_gradio.py +21 -1
- src/f5_tts/train/train.py +1 -0
src/f5_tts/configs/E2TTS_Base_train.yaml
CHANGED
@@ -41,4 +41,5 @@ ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
   last_per_updates: 5000 # save last checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/E2TTS_Small_train.yaml
CHANGED
@@ -41,4 +41,5 @@ ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
   last_per_updates: 5000 # save last checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base_train.yaml
CHANGED
@@ -44,4 +44,5 @@ ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
   last_per_updates: 5000 # save last checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small_train.yaml
CHANGED
@@ -44,4 +44,5 @@ ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
   last_per_updates: 5000 # save last checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/model/trainer.py
CHANGED
@@ -50,7 +50,17 @@ class Trainer:
         mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
         is_local_vocoder: bool = False, # use local path vocoder
         local_vocoder_path: str = "", # local vocoder path
+        keep_last_n_checkpoints: int
+        | None = -1, # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints
     ):
+        # Validate keep_last_n_checkpoints
+        if not isinstance(keep_last_n_checkpoints, int):
+            raise ValueError("keep_last_n_checkpoints must be an integer")
+        if keep_last_n_checkpoints < -1:
+            raise ValueError(
+                "keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer"
+            )
+
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
         if logger == "wandb" and not wandb.api.api_key:
@@ -134,6 +144,8 @@ class Trainer:
         self.optimizer = AdamW(model.parameters(), lr=learning_rate)
         self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
 
+        self.keep_last_n_checkpoints = keep_last_n_checkpoints if keep_last_n_checkpoints is not None else None
+
     @property
     def is_main(self):
         return self.accelerator.is_main_process
@@ -154,7 +166,26 @@ class Trainer:
             self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
             print(f"Saved last checkpoint at update {update}")
         else:
+            # Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0
+            if self.keep_last_n_checkpoints == 0:
+                return
+
             self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
+            # Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive
+            if self.keep_last_n_checkpoints > 0:
+                # Get all checkpoint files except model_last.pt
+                checkpoints = [
+                    f
+                    for f in os.listdir(self.checkpoint_path)
+                    if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
+                ]
+                # Sort by step number
+                checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
+                # Remove old checkpoints if we have more than keep_last_n_checkpoints
+                while len(checkpoints) > self.keep_last_n_checkpoints:
+                    oldest_checkpoint = checkpoints.pop(0)
+                    os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
+                    print(f"Removed old checkpoint: {oldest_checkpoint}")
 
     def load_checkpoint(self):
         if (
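One detail worth noting in the pruning logic above: the sort key parses the update number out of each filename before comparing. A plain lexicographic sort would order the names by string value and could delete the wrong checkpoint. A small standalone check (filenames are illustrative):

names = ["model_10000.pt", "model_9000.pt", "model_50000.pt"]

# Lexicographic order puts model_9000.pt last, as if it were the newest:
print(sorted(names))
# ['model_10000.pt', 'model_50000.pt', 'model_9000.pt']

# The numeric key used in trainer.py orders checkpoints by training update:
print(sorted(names, key=lambda x: int(x.split("_")[1].split(".")[0])))
# ['model_9000.pt', 'model_10000.pt', 'model_50000.pt']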
src/f5_tts/train/finetune_cli.py
CHANGED
@@ -69,6 +69,12 @@ def parse_args():
         action="store_true",
         help="Use 8-bit Adam optimizer from bitsandbytes",
     )
+    parser.add_argument(
+        "--keep_last_n_checkpoints",
+        type=int,
+        default=-1,
+        help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
+    )
 
     return parser.parse_args()
 
@@ -158,6 +164,7 @@ def main():
         log_samples=args.log_samples,
         last_per_updates=args.last_per_updates,
         bnb_optimizer=args.bnb_optimizer,
+        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
     )
 
     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
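A minimal, standalone sketch of the new CLI option (argument name, type, default, and help text are taken from the diff above; the rest of the parser is omitted):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--keep_last_n_checkpoints",
    type=int,
    default=-1,
    help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
)

# e.g. passing 3 keeps only the three newest intermediate checkpoints
print(parser.parse_args(["--keep_last_n_checkpoints", "3"]).keep_last_n_checkpoints)  # 3
# Omitting the flag falls back to the default and preserves the previous keep-everything behaviour
print(parser.parse_args([]).keep_last_n_checkpoints)  # -1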
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -70,6 +70,7 @@ def save_settings(
     mixed_precision,
     logger,
     ch_8bit_adam,
+    keep_last_n_checkpoints,
 ):
     path_project = os.path.join(path_project_ckpts, project_name)
     os.makedirs(path_project, exist_ok=True)
@@ -94,6 +95,7 @@ def save_settings(
         "mixed_precision": mixed_precision,
         "logger": logger,
         "bnb_optimizer": ch_8bit_adam,
+        "keep_last_n_checkpoints": keep_last_n_checkpoints,
     }
     with open(file_setting, "w") as f:
         json.dump(settings, f, indent=4)
@@ -126,6 +128,7 @@ def load_settings(project_name):
             "mixed_precision": "none",
             "logger": "wandb",
             "bnb_optimizer": False,
+            "keep_last_n_checkpoints": -1,  # Default to keep all checkpoints
         }
         return (
             settings["exp_name"],
@@ -146,6 +149,7 @@ def load_settings(project_name):
             settings["mixed_precision"],
             settings["logger"],
             settings["bnb_optimizer"],
+            settings["keep_last_n_checkpoints"],
         )
 
     with open(file_setting, "r") as f:
@@ -154,6 +158,8 @@ def load_settings(project_name):
         settings["logger"] = "wandb"
     if "bnb_optimizer" not in settings:
         settings["bnb_optimizer"] = False
+    if "keep_last_n_checkpoints" not in settings:
+        settings["keep_last_n_checkpoints"] = -1  # Default to keep all checkpoints
     if "last_per_updates" not in settings:  # patch for backward compatibility, with before f992c4e
         settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
     return (
@@ -175,6 +181,7 @@ def load_settings(project_name):
         settings["mixed_precision"],
         settings["logger"],
         settings["bnb_optimizer"],
+        settings["keep_last_n_checkpoints"],
     )
 
 
@@ -390,6 +397,7 @@ def start_training(
     stream=False,
     logger="wandb",
     ch_8bit_adam=False,
+    keep_last_n_checkpoints=-1,
 ):
     global training_process, tts_api, stop_signal
 
@@ -451,7 +459,8 @@ def start_training(
         f"--num_warmup_updates {num_warmup_updates} "
         f"--save_per_updates {save_per_updates} "
         f"--last_per_updates {last_per_updates} "
-        f"--dataset_name {dataset_name}"
+        f"--dataset_name {dataset_name} "
+        f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
     )
 
     if finetune:
@@ -492,6 +501,7 @@ def start_training(
         mixed_precision,
         logger,
         ch_8bit_adam,
+        keep_last_n_checkpoints,
     )
 
     try:
@@ -1564,6 +1574,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
             with gr.Row():
                 save_per_updates = gr.Number(label="Save per Updates", value=300)
                 last_per_updates = gr.Number(label="Last per Updates", value=100)
+                keep_last_n_checkpoints = gr.Number(
+                    label="Keep Last N Checkpoints",
+                    value=-1,
+                    step=1,
+                    precision=0,
+                    info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
+                )
 
             with gr.Row():
                 ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1592,6 +1609,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 mixed_precisionv,
                 cd_loggerv,
                 ch_8bit_adamv,
+                keep_last_n_checkpointsv,
             ) = load_settings(projects_selelect)
             exp_name.value = exp_namev
             learning_rate.value = learning_ratev
@@ -1611,6 +1629,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
             mixed_precision.value = mixed_precisionv
             cd_logger.value = cd_loggerv
             ch_8bit_adam.value = ch_8bit_adamv
+            keep_last_n_checkpoints.value = keep_last_n_checkpointsv
 
             ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
             txt_info_train = gr.Text(label="Info", value="")
@@ -1670,6 +1689,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 ch_stream,
                 cd_logger,
                 ch_8bit_adam,
+                keep_last_n_checkpoints,
             ],
             outputs=[txt_info_train, start_button, stop_button],
         )
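Beyond threading the new setting through save_settings, load_settings, and start_training, the diff above also fixes the command-string assembly: the f-string fragment for --dataset_name previously lacked a trailing space, so appending another flag would have glued the two arguments together. A reduced illustration (the values are placeholders and the real command contains many more options):

dataset_name = "my_dataset"          # placeholder value
keep_last_n_checkpoints = 3          # placeholder value

cmd = (
    f"--dataset_name {dataset_name} "  # trailing space added in this commit
    f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
)
print(cmd)  # --dataset_name my_dataset --keep_last_n_checkpoints 3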
src/f5_tts/train/train.py
CHANGED
@@ -61,6 +61,7 @@ def main(cfg):
         mel_spec_type=mel_spec_type,
         is_local_vocoder=cfg.model.vocoder.is_local,
         local_vocoder_path=cfg.model.vocoder.local_path,
+        keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None),
     )
 
     train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
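For completeness, a small sketch of the getattr fallback used above, run against illustrative configs with and without the new key. The omegaconf import and the example config contents are assumptions for the sake of the example; only the access pattern comes from the diff.

from omegaconf import OmegaConf

new_cfg = OmegaConf.create({"ckpts": {"save_per_updates": 50000, "keep_last_n_checkpoints": 3}})
old_cfg = OmegaConf.create({"ckpts": {"save_per_updates": 50000}})  # config written before this commit

print(getattr(new_cfg.ckpts, "keep_last_n_checkpoints", None))  # 3
print(getattr(old_cfg.ckpts, "keep_last_n_checkpoints", None))  # None (the key is absent)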