Spaces:
Configuration error
Configuration error
unknown
commited on
Commit
·
886500a
1
Parent(s):
2ca1fb7
update
Browse files- src/f5_tts/model/trainer.py +20 -9
src/f5_tts/model/trainer.py
CHANGED
@@ -18,8 +18,6 @@ from ema_pytorch import EMA
|
|
18 |
from f5_tts.model import CFM
|
19 |
from f5_tts.model.utils import exists, default
|
20 |
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
21 |
-
from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
|
22 |
-
from f5_tts.model.utils import get_sample
|
23 |
|
24 |
# trainer
|
25 |
|
@@ -51,6 +49,11 @@ class Trainer:
|
|
51 |
bnb_optimizer: bool = False,
|
52 |
export_samples=False,
|
53 |
):
|
|
|
|
|
|
|
|
|
|
|
54 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
55 |
|
56 |
self.logger = logger
|
@@ -102,13 +105,6 @@ class Trainer:
|
|
102 |
|
103 |
self.writer = SummaryWriter(log_dir=folder_path)
|
104 |
|
105 |
-
# export audio and mel
|
106 |
-
self.export_samples = export_samples
|
107 |
-
if self.export_samples:
|
108 |
-
self.path_ckpts_project = checkpoint_path
|
109 |
-
self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
|
110 |
-
os.makedirs(self.file_path_samples, exist_ok=True)
|
111 |
-
|
112 |
self.model = model
|
113 |
|
114 |
if self.is_main:
|
@@ -213,6 +209,21 @@ class Trainer:
|
|
213 |
self.writer.add_scalar(key, value, step)
|
214 |
|
215 |
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
if exists(resumable_with_seed):
|
217 |
generator = torch.Generator()
|
218 |
generator.manual_seed(resumable_with_seed)
|
|
|
18 |
from f5_tts.model import CFM
|
19 |
from f5_tts.model.utils import exists, default
|
20 |
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
|
|
|
|
21 |
|
22 |
# trainer
|
23 |
|
|
|
49 |
bnb_optimizer: bool = False,
|
50 |
export_samples=False,
|
51 |
):
|
52 |
+
# export audio and mel
|
53 |
+
self.export_samples = export_samples
|
54 |
+
if export_samples:
|
55 |
+
self.path_ckpts_project = checkpoint_path
|
56 |
+
|
57 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
58 |
|
59 |
self.logger = logger
|
|
|
105 |
|
106 |
self.writer = SummaryWriter(log_dir=folder_path)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
self.model = model
|
109 |
|
110 |
if self.is_main:
|
|
|
209 |
self.writer.add_scalar(key, value, step)
|
210 |
|
211 |
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
212 |
+
# import only when export_sample True
|
213 |
+
if self.export_samples:
|
214 |
+
from f5_tts.infer.utils_infer import (
|
215 |
+
target_sample_rate,
|
216 |
+
hop_length,
|
217 |
+
nfe_step,
|
218 |
+
cfg_strength,
|
219 |
+
sway_sampling_coef,
|
220 |
+
vocos,
|
221 |
+
)
|
222 |
+
from f5_tts.model.utils import get_sample
|
223 |
+
|
224 |
+
self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
|
225 |
+
os.makedirs(self.file_path_samples, exist_ok=True)
|
226 |
+
|
227 |
if exists(resumable_with_seed):
|
228 |
generator = torch.Generator()
|
229 |
generator.manual_seed(resumable_with_seed)
|