Merge pull request #729 from hcsolakoglu/fix-ckpt-rotation
Exclude pretrained models from the checkpoint rotation logic
src/f5_tts/model/trainer.py
CHANGED
@@ -160,10 +160,14 @@ class Trainer:
                     return
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
                 if self.keep_last_n_checkpoints > 0:
+                    # Updated logic to exclude pretrained model from rotation
                     checkpoints = [
                         f
                         for f in os.listdir(self.checkpoint_path)
-                        if f.startswith("model_")
+                        if f.startswith("model_")
+                        and not f.startswith("pretrained_")  # Exclude pretrained models
+                        and f.endswith(".pt")
+                        and f != "model_last.pt"
                     ]
                     checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
                     while len(checkpoints) > self.keep_last_n_checkpoints:
@@ -183,10 +187,24 @@ class Trainer:
         if "model_last.pt" in os.listdir(self.checkpoint_path):
            latest_checkpoint = "model_last.pt"
         else:
-
-
-
-
+            # Updated to consider pretrained models for loading but prioritize training checkpoints
+            all_checkpoints = [
+                f
+                for f in os.listdir(self.checkpoint_path)
+                if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
+            ]
+
+            # First try to find regular training checkpoints
+            training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
+            if training_checkpoints:
+                latest_checkpoint = sorted(
+                    training_checkpoints,
+                    key=lambda x: int("".join(filter(str.isdigit, x))),
+                )[-1]
+            else:
+                # If no training checkpoints, use pretrained model
+                latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
+
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
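To see the effect of the two trainer.py hunks in isolation, here is a small self-contained sketch; it is not part of the PR, the helper names rotation_candidates and pick_resume_checkpoint are made up, and it only mirrors the filename filters above on a plain list of names so the behavior can be checked without a Trainer instance or real files.

    # Hypothetical helpers that reproduce the filename handling from the diff above.

    def rotation_candidates(files, keep_last_n):
        """Checkpoints that rotation would delete, oldest first (pretrained_* is never touched)."""
        checkpoints = [
            f
            for f in files
            if f.startswith("model_")
            and not f.startswith("pretrained_")  # exclude the copied base model
            and f.endswith(".pt")
            and f != "model_last.pt"
        ]
        checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
        return checkpoints[: max(len(checkpoints) - keep_last_n, 0)]

    def pick_resume_checkpoint(files):
        """Prefer model_last.pt, then the newest model_*.pt, then fall back to pretrained_*.pt."""
        if "model_last.pt" in files:
            return "model_last.pt"
        training = [f for f in files if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"]
        if training:
            return sorted(training, key=lambda x: int("".join(filter(str.isdigit, x))))[-1]
        return next(f for f in files if f.startswith("pretrained_") and f.endswith(".pt"))

    files = ["pretrained_model_1200000.pt", "model_1000.pt", "model_2000.pt", "model_3000.pt"]
    print(rotation_candidates(files, keep_last_n=2))  # ['model_1000.pt'] -- only training checkpoints rotate
    print(pick_resume_checkpoint(files))              # 'model_3000.pt'  -- newest training checkpoint wins

With the pretrained_ prefix introduced in the other two files, the copied base model no longer matches the rotation filter, and it is only used at resume when no training checkpoints exist yet.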
src/f5_tts/train/finetune_cli.py
CHANGED
@@ -111,7 +111,8 @@ def main():
         if not os.path.isdir(checkpoint_path):
             os.makedirs(checkpoint_path, exist_ok=True)

-
+        # Change: Add 'pretrained_' prefix to copied model
+        file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
         if not os.path.isfile(file_checkpoint):
             shutil.copy2(ckpt_path, file_checkpoint)
             print("copy checkpoint for finetune")
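A quick illustration of the new naming; the paths below are made-up examples, not values from the PR:

    import os

    ckpt_path = "ckpts/F5TTS_Base/model_1200000.pt"   # hypothetical path to the downloaded base model
    checkpoint_path = "ckpts/my_project"              # hypothetical fine-tune checkpoint directory

    # Same expression as the added line in finetune_cli.py:
    file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
    print(file_checkpoint)  # ckpts/my_project/pretrained_model_1200000.pt (with "/" on POSIX)

Because the copy no longer starts with model_, it is invisible to the rotation filter in trainer.py but still discoverable through the loader's pretrained_ fallback.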
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -1099,7 +1099,9 @@ def vocab_extend(project_name, symbols, model_type):
     dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
     new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
     os.makedirs(new_ckpt_path, exist_ok=True)
-
+
+    # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
+    new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")

     size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
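For context, on a hypothetical fresh project where only the vocab-extended copy exists, the loader's new else-branch resolves to the prefixed file; the listing below is an assumed example, not output from the PR:

    # Hypothetical listing of a project checkpoint directory right after vocab extension,
    # before any training updates have been saved:
    files = ["pretrained_model_1200000.pt"]

    # Same fallback expression as the added else-branch in trainer.py:
    latest_checkpoint = next(f for f in files if f.startswith("pretrained_"))
    print(latest_checkpoint)  # pretrained_model_1200000.pt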