0.4.0 fix gradient accumulation; change checkpointing logic to per_updates
- pyproject.toml +1 -1
- src/f5_tts/configs/E2TTS_Base_train.yaml +3 -3
- src/f5_tts/configs/E2TTS_Small_train.yaml +3 -3
- src/f5_tts/configs/F5TTS_Base_train.yaml +3 -3
- src/f5_tts/configs/F5TTS_Small_train.yaml +3 -3
- src/f5_tts/model/trainer.py +68 -50
- src/f5_tts/scripts/count_max_epoch.py +2 -2
- src/f5_tts/train/finetune_cli.py +6 -6
- src/f5_tts/train/finetune_gradio.py +25 -23
- src/f5_tts/train/train.py +1 -1
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.
+version = "0.4.0"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
src/f5_tts/configs/E2TTS_Base_train.yaml CHANGED
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000 # warmup
+  num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0 # gradient clipping
   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -39,6 +39,6 @@ model:
 
 ckpts:
   logger: wandb # wandb | tensorboard | None
-  save_per_updates: 50000 # save checkpoint per
-
+  save_per_updates: 50000 # save checkpoint per updates
+  last_per_updates: 5000 # save last checkpoint per updates
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/E2TTS_Small_train.yaml CHANGED
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000 # warmup
+  num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0
   bnb_optimizer: False
@@ -39,6 +39,6 @@ model:
 
 ckpts:
   logger: wandb # wandb | tensorboard | None
-  save_per_updates: 50000 # save checkpoint per
-
+  save_per_updates: 50000 # save checkpoint per updates
+  last_per_updates: 5000 # save last checkpoint per updates
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base_train.yaml CHANGED
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000 # warmup
+  num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0 # gradient clipping
   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -42,6 +42,6 @@ model:
 
 ckpts:
   logger: wandb # wandb | tensorboard | None
-  save_per_updates: 50000 # save checkpoint per
-
+  save_per_updates: 50000 # save checkpoint per updates
+  last_per_updates: 5000 # save last checkpoint per updates
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small_train.yaml CHANGED
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000 # warmup
+  num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0 # gradient clipping
   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -42,6 +42,6 @@ model:
 
 ckpts:
   logger: wandb # wandb | tensorboard | None
-  save_per_updates: 50000 # save checkpoint per
-
+  save_per_updates: 50000 # save checkpoint per updates
+  last_per_updates: 5000 # save last checkpoint per updates
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
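The note `updates = steps / grad_accumulation_steps`, repeated in each config above, is the relationship this release keys everything to: checkpoint intervals now count optimizer updates, not dataloader steps. A minimal sketch of that bookkeeping, with illustrative numbers rather than values taken from any config:

```python
# Minimal sketch of the steps-vs-updates bookkeeping the configs now assume.
# All numbers below are illustrative, not values from the release.
grad_accumulation_steps = 4      # micro-batches per optimizer update
save_per_updates = 50000         # write model_{update}.pt every N optimizer updates
last_per_updates = 5000          # refresh model_last.pt every N optimizer updates

dataloader_steps = 1_000_000     # micro-batches seen so far
optimizer_updates = dataloader_steps // grad_accumulation_steps  # 250_000

# Checkpointing is keyed to updates, so the cadence no longer stretches
# when more micro-batches feed each optimizer update.
print(optimizer_updates % save_per_updates == 0)  # True -> a numbered checkpoint would be written
print(optimizer_updates % last_per_updates == 0)  # True -> model_last.pt would be refreshed
```

With `grad_accumulation_steps: 1`, the value shipped in these configs, steps and updates coincide, so existing runs keep the same checkpoint cadence.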
src/f5_tts/model/trainer.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import gc
+import math
 import os
 
 import torch
@@ -42,7 +43,7 @@ class Trainer:
         wandb_run_name="test_run",
         wandb_resume_id: str = None,
         log_samples: bool = False,
-
+        last_per_updates=None,
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
@@ -57,6 +58,11 @@ class Trainer:
         print(f"Using logger: {logger}")
         self.log_samples = log_samples
 
+        if grad_accumulation_steps > 1 and self.is_main:
+            print(
+                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
+            )
+
         self.accelerator = Accelerator(
             log_with=logger if logger == "wandb" else None,
             kwargs_handlers=[ddp_kwargs],
@@ -102,7 +108,7 @@ class Trainer:
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
-        self.
+        self.last_per_updates = default(last_per_updates, save_per_updates)
         self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
 
         self.batch_size = batch_size
@@ -132,7 +138,7 @@ class Trainer:
     def is_main(self):
         return self.accelerator.is_main_process
 
-    def save_checkpoint(self,
+    def save_checkpoint(self, update, last=False):
         self.accelerator.wait_for_everyone()
         if self.is_main:
             checkpoint = dict(
@@ -140,15 +146,15 @@ class Trainer:
                 optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
                 ema_model_state_dict=self.ema_model.state_dict(),
                 scheduler_state_dict=self.scheduler.state_dict(),
-
+                update=update,
             )
             if not os.path.exists(self.checkpoint_path):
                 os.makedirs(self.checkpoint_path)
             if last:
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
-                print(f"Saved last checkpoint at
+                print(f"Saved last checkpoint at update {update}")
             else:
-                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{
+                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
 
     def load_checkpoint(self):
         if (
@@ -177,7 +183,14 @@ class Trainer:
             if self.is_main:
                 self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
 
-        if "step" in checkpoint:
+        if "update" in checkpoint or "step" in checkpoint:
+            # patch for backward compatibility, with before f992c4e
+            if "step" in checkpoint:
+                checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
+                if self.grad_accumulation_steps > 1 and self.is_main:
+                    print(
+                        "F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
+                    )
             # patch for backward compatibility, 305e3ea
             for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                 if key in checkpoint["model_state_dict"]:
@@ -187,19 +200,19 @@ class Trainer:
             self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
             if self.scheduler:
                 self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
-
+            update = checkpoint["update"]
         else:
             checkpoint["model_state_dict"] = {
                 k.replace("ema_model.", ""): v
                 for k, v in checkpoint["ema_model_state_dict"].items()
-                if k not in ["initted", "step"]
+                if k not in ["initted", "update", "step"]
             }
             self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
-
+            update = 0
 
         del checkpoint
         gc.collect()
-        return
+        return update
 
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if self.log_samples:
@@ -248,25 +261,26 @@ class Trainer:
 
         # accelerator.prepare() dispatches batches to devices;
         # which means the length of dataloader calculated before, should consider the number of devices
-
+        warmup_updates = (
             self.num_warmup_updates * self.accelerator.num_processes
         )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
         # otherwise by default with split_batches=False, warmup steps change with num_processes
-
-
-        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=
-        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=
+        total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
+        decay_updates = total_updates - warmup_updates
+        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
+        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
         self.scheduler = SequentialLR(
-            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[
+            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
         )
         train_dataloader, self.scheduler = self.accelerator.prepare(
             train_dataloader, self.scheduler
-        )  # actual
-
-
+        )  # actual multi_gpu updates = single_gpu updates / gpu nums
+        start_update = self.load_checkpoint()
+        global_update = start_update
 
         if exists(resumable_with_seed):
             orig_epoch_step = len(train_dataloader)
+            start_step = start_update * self.grad_accumulation_steps
             skipped_epoch = int(start_step // orig_epoch_step)
             skipped_batch = start_step % orig_epoch_step
             skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
@@ -276,23 +290,21 @@ class Trainer:
         for epoch in range(skipped_epoch, self.epochs):
             self.model.train()
             if exists(resumable_with_seed) and epoch == skipped_epoch:
-
-
-                    desc=f"Epoch {epoch+1}/{self.epochs}",
-                    unit="step",
-                    disable=not self.accelerator.is_local_main_process,
-                    initial=skipped_batch,
-                    total=orig_epoch_step,
-                )
+                progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
+                current_dataloader = skipped_dataloader
             else:
-
-
-
-
-
-
-
-
+                progress_bar_initial = 0
+                current_dataloader = train_dataloader
+
+            progress_bar = tqdm(
+                range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
+                desc=f"Epoch {epoch+1}/{self.epochs}",
+                unit="update",
+                disable=not self.accelerator.is_local_main_process,
+                initial=progress_bar_initial,
+            )
+
+            for batch in current_dataloader:
                 with self.accelerator.accumulate(self.model):
                     text_inputs = batch["text"]
                    mel_spec = batch["mel"].permute(0, 2, 1)
@@ -301,7 +313,7 @@ class Trainer:
                    # TODO. add duration predictor training
                    if self.duration_predictor is not None and self.accelerator.is_local_main_process:
                        dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
-                        self.accelerator.log({"duration loss": dur_loss.item()}, step=
+                        self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)
 
                    loss, cond, pred = self.model(
                        mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
@@ -318,18 +330,20 @@ class Trainer:
                if self.is_main and self.accelerator.sync_gradients:
                    self.ema_model.update()
 
-
+                global_update += 1
+                progress_bar.update(1)
+                progress_bar.set_postfix(update=str(global_update), loss=loss.item())
 
                if self.accelerator.is_local_main_process:
-                    self.accelerator.log(
+                    self.accelerator.log(
+                        {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
+                    )
                    if self.logger == "tensorboard":
-                        self.writer.add_scalar("loss", loss.item(),
-                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0],
-
-                    progress_bar.set_postfix(step=str(global_step), loss=loss.item())
+                        self.writer.add_scalar("loss", loss.item(), global_update)
+                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
-                if
-                    self.save_checkpoint(
+                if global_update % self.save_per_updates == 0:
+                    self.save_checkpoint(global_update)
 
                if self.log_samples and self.accelerator.is_local_main_process:
                    ref_audio_len = mel_lengths[0]
@@ -355,12 +369,16 @@ class Trainer:
                    gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
                    ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
 
-                    torchaudio.save(
-
+                    torchaudio.save(
+                        f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
+                    )
+                    torchaudio.save(
+                        f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
+                    )
 
-                if
-                    self.save_checkpoint(
+                if global_update % self.last_per_updates == 0:
+                    self.save_checkpoint(global_update, last=True)
 
-            self.save_checkpoint(
+        self.save_checkpoint(global_update, last=True)
 
        self.accelerator.end_training()
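The heart of the change is that `load_checkpoint` now returns an update count and converts old `step`-keyed checkpoints on the fly, so resuming an old run keeps its place in the schedule. A minimal standalone sketch of that conversion; the dicts below are hand-made stand-ins and only the `step`/`update` keys mirror the real checkpoint format:

```python
# Sketch of the resume logic introduced above, isolated from the Trainer class.
def resume_update(ckpt: dict, grad_accumulation_steps: int) -> int:
    if "update" in ckpt or "step" in ckpt:
        if "step" in ckpt:  # checkpoint written before f992c4e, counted in dataloader steps
            ckpt["update"] = ckpt["step"] // grad_accumulation_steps
        return ckpt["update"]
    return 0  # pretrained weights only, start counting from zero

print(resume_update({"step": 120_000}, grad_accumulation_steps=4))   # 30000
print(resume_update({"update": 30_000}, grad_accumulation_steps=4))  # 30000
print(resume_update({}, grad_accumulation_steps=4))                  # 0
```

The same conversion drives the resumable dataloader: the returned update count is multiplied back by `grad_accumulation_steps` to recover how many batches to skip, and the LR schedulers are now sized in updates (`total_updates = ceil(len(dataloader) / grad_accumulation_steps) * epochs`), so warmup and decay stay aligned with the checkpoints.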
src/f5_tts/scripts/count_max_epoch.py CHANGED
@@ -20,13 +20,13 @@ grad_accum = 1
 mini_batch_frames = frames_per_gpu * grad_accum * gpus
 mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
 updates_per_epoch = total_hours / mini_batch_hours
-steps_per_epoch = updates_per_epoch * grad_accum
+# steps_per_epoch = updates_per_epoch * grad_accum
 
 # result
 epochs = wanted_max_updates / updates_per_epoch
 print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
 print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
-print(f" or approx. 0/{steps_per_epoch:.0f} steps")
+# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
 
 # others
 print(f"total {total_hours:.0f} hours")
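With the step-based line commented out, the helper reports progress only in updates. A worked example of the same arithmetic with assumed inputs; none of these numbers come from the script itself:

```python
# Worked example of the estimate count_max_epoch.py prints, with assumed inputs.
total_hours = 10_000            # dataset size in hours (assumed)
wanted_max_updates = 1_000_000  # target optimizer updates (assumed)
frames_per_gpu = 38_400         # batch size per GPU in mel frames (assumed)
grad_accum = 1
gpus = 8
mel_hop_length = 256
mel_sampling_rate = 24_000

mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours   # ~10986 updates per epoch
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f}")       # ~91 with these assumptions
```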
src/f5_tts/train/finetune_cli.py CHANGED
@@ -27,7 +27,7 @@ def parse_args():
 
     # num_warmup_updates = 300 for 5000 sample about 10 hours
 
-    # change save_per_updates ,
+    # change save_per_updates , last_per_updates change this value what you need ,
 
     parser = argparse.ArgumentParser(description="Train CFM Model")
 
@@ -44,9 +44,9 @@ def parse_args():
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
     parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup
-    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X
-    parser.add_argument("--
+    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
+    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
+    parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
     parser.add_argument("--finetune", action="store_true", help="Use Finetune")
     parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
     parser.add_argument(
@@ -61,7 +61,7 @@ def parse_args():
     parser.add_argument(
         "--log_samples",
         action="store_true",
-        help="Log inferenced samples per ckpt save
+        help="Log inferenced samples per ckpt save updates",
     )
     parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
     parser.add_argument(
@@ -156,7 +156,7 @@ def main():
         wandb_run_name=args.exp_name,
         wandb_resume_id=wandb_resume_id,
         log_samples=args.log_samples,
-
+        last_per_updates=args.last_per_updates,
         bnb_optimizer=args.bnb_optimizer,
     )
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -62,7 +62,7 @@ def save_settings(
     epochs,
     num_warmup_updates,
     save_per_updates,
-
+    last_per_updates,
     finetune,
     file_checkpoint_train,
     tokenizer_type,
@@ -86,7 +86,7 @@ def save_settings(
         "epochs": epochs,
         "num_warmup_updates": num_warmup_updates,
         "save_per_updates": save_per_updates,
-        "
+        "last_per_updates": last_per_updates,
         "finetune": finetune,
         "file_checkpoint_train": file_checkpoint_train,
         "tokenizer_type": tokenizer_type,
@@ -118,7 +118,7 @@ def load_settings(project_name):
         "epochs": 100,
         "num_warmup_updates": 2,
         "save_per_updates": 300,
-        "
+        "last_per_updates": 100,
         "finetune": True,
         "file_checkpoint_train": "",
         "tokenizer_type": "pinyin",
@@ -138,7 +138,7 @@ def load_settings(project_name):
            settings["epochs"],
            settings["num_warmup_updates"],
            settings["save_per_updates"],
-            settings["
+            settings["last_per_updates"],
            settings["finetune"],
            settings["file_checkpoint_train"],
            settings["tokenizer_type"],
@@ -154,6 +154,8 @@ def load_settings(project_name):
        settings["logger"] = "wandb"
    if "bnb_optimizer" not in settings:
        settings["bnb_optimizer"] = False
+    if "last_per_updates" not in settings:  # patch for backward compatibility, with before f992c4e
+        settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
    return (
        settings["exp_name"],
        settings["learning_rate"],
@@ -165,7 +167,7 @@ def load_settings(project_name):
        settings["epochs"],
        settings["num_warmup_updates"],
        settings["save_per_updates"],
-        settings["
+        settings["last_per_updates"],
        settings["finetune"],
        settings["file_checkpoint_train"],
        settings["tokenizer_type"],
@@ -379,7 +381,7 @@ def start_training(
    epochs=11,
    num_warmup_updates=200,
    save_per_updates=400,
-
+    last_per_updates=800,
    finetune=True,
    file_checkpoint_train="",
    tokenizer_type="pinyin",
@@ -448,7 +450,7 @@ def start_training(
        f"--epochs {epochs} "
        f"--num_warmup_updates {num_warmup_updates} "
        f"--save_per_updates {save_per_updates} "
-        f"--
+        f"--last_per_updates {last_per_updates} "
        f"--dataset_name {dataset_name}"
    )
 
@@ -482,7 +484,7 @@ def start_training(
        epochs,
        num_warmup_updates,
        save_per_updates,
-
+        last_per_updates,
        finetune,
        file_checkpoint_train,
        tokenizer_type,
@@ -880,7 +882,7 @@ def calculate_train(
    learning_rate,
    num_warmup_updates,
    save_per_updates,
-
+    last_per_updates,
    finetune,
 ):
    path_project = os.path.join(path_data, name_project)
@@ -892,7 +894,7 @@ def calculate_train(
            max_samples,
            num_warmup_updates,
            save_per_updates,
-
+            last_per_updates,
            "project not found !",
            learning_rate,
        )
@@ -940,14 +942,14 @@ def calculate_train(
 
    num_warmup_updates = int(samples * 0.05)
    save_per_updates = int(samples * 0.10)
-
+    last_per_updates = int(save_per_updates * 0.25)
 
    max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
    num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
    save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
-
-    if
-
+    last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
+    if last_per_updates <= 0:
+        last_per_updates = 2
 
    total_hours = hours
    mel_hop_length = 256
@@ -978,7 +980,7 @@ def calculate_train(
        max_samples,
        num_warmup_updates,
        save_per_updates,
-
+        last_per_updates,
        samples,
        learning_rate,
        int(epochs),
@@ -1530,7 +1532,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
 
        with gr.TabItem("Train Data"):
            gr.Markdown("""```plaintext
-The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per
+The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
 If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
 ```""")
            with gr.Row():
@@ -1561,7 +1563,7 @@
 
            with gr.Row():
                save_per_updates = gr.Number(label="Save per Updates", value=300)
-
+                last_per_updates = gr.Number(label="Last per Updates", value=100)
 
            with gr.Row():
                ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1582,7 +1584,7 @@
                epochsv,
                num_warmupv_updatesv,
                save_per_updatesv,
-
+                last_per_updatesv,
                finetunev,
                file_checkpoint_trainv,
                tokenizer_typev,
@@ -1601,7 +1603,7 @@
                epochs.value = epochsv
                num_warmup_updates.value = num_warmupv_updatesv
                save_per_updates.value = save_per_updatesv
-
+                last_per_updates.value = last_per_updatesv
                ch_finetune.value = finetunev
                file_checkpoint_train.value = file_checkpoint_trainv
                tokenizer_type.value = tokenizer_typev
@@ -1659,7 +1661,7 @@
                epochs,
                num_warmup_updates,
                save_per_updates,
-
+                last_per_updates,
                ch_finetune,
                file_checkpoint_train,
                tokenizer_type,
@@ -1682,7 +1684,7 @@
                learning_rate,
                num_warmup_updates,
                save_per_updates,
-
+                last_per_updates,
                ch_finetune,
            ],
            outputs=[
@@ -1690,7 +1692,7 @@
                max_samples,
                num_warmup_updates,
                save_per_updates,
-
+                last_per_updates,
                lb_samples,
                learning_rate,
                epochs,
@@ -1713,7 +1715,7 @@
                epochs,
                num_warmup_updates,
                save_per_updates,
-
+                last_per_updates,
                ch_finetune,
                file_checkpoint_train,
                tokenizer_type,
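Older per-project settings files store a `last_per_steps` value, so `load_settings` now patches them on read using the project's own `grad_accumulation_steps`. A minimal sketch of that migration against a hand-written settings dict (the real code reads the dict from the project's JSON settings file):

```python
# Sketch of the backward-compatibility patch added to load_settings().
settings = {
    "grad_accumulation_steps": 2,
    "save_per_updates": 300,
    "last_per_steps": 200,   # legacy key written by versions before f992c4e
}

if "last_per_updates" not in settings:  # patch for backward compatibility
    settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]

print(settings["last_per_updates"])  # 100
```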
src/f5_tts/train/train.py CHANGED
@@ -55,7 +55,7 @@ def main(cfg):
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-
+        last_per_updates=cfg.ckpts.last_per_updates,
         log_samples=True,
         bnb_optimizer=cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,
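train.py simply forwards the new config key into the Trainer. A minimal sketch, assuming a Hydra/OmegaConf-style config shaped like the YAMLs above; only the `ckpts` block is reproduced here and the Trainer call itself is elided:

```python
# Minimal sketch of how the new ckpts keys resolve from the training config.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "ckpts": {
        "save_per_updates": 50000,   # save checkpoint per updates
        "last_per_updates": 5000,    # save last checkpoint per updates
    }
})

# These are the values train.py now passes as Trainer(..., save_per_updates=..., last_per_updates=...).
print(cfg.ckpts.save_per_updates, cfg.ckpts.last_per_updates)  # 50000 5000
```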