unknown commited on
Commit
886500a
·
1 Parent(s): 2ca1fb7
Files changed (1) hide show
  1. src/f5_tts/model/trainer.py +20 -9
src/f5_tts/model/trainer.py CHANGED
@@ -18,8 +18,6 @@ from ema_pytorch import EMA
18
  from f5_tts.model import CFM
19
  from f5_tts.model.utils import exists, default
20
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
21
- from f5_tts.infer.utils_infer import target_sample_rate, hop_length, nfe_step, cfg_strength, sway_sampling_coef, vocos
22
- from f5_tts.model.utils import get_sample
23
 
24
  # trainer
25
 
@@ -51,6 +49,11 @@ class Trainer:
51
  bnb_optimizer: bool = False,
52
  export_samples=False,
53
  ):
 
 
 
 
 
54
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
55
 
56
  self.logger = logger
@@ -102,13 +105,6 @@ class Trainer:
102
 
103
  self.writer = SummaryWriter(log_dir=folder_path)
104
 
105
- # export audio and mel
106
- self.export_samples = export_samples
107
- if self.export_samples:
108
- self.path_ckpts_project = checkpoint_path
109
- self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
110
- os.makedirs(self.file_path_samples, exist_ok=True)
111
-
112
  self.model = model
113
 
114
  if self.is_main:
@@ -213,6 +209,21 @@ class Trainer:
213
  self.writer.add_scalar(key, value, step)
214
 
215
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  if exists(resumable_with_seed):
217
  generator = torch.Generator()
218
  generator.manual_seed(resumable_with_seed)
 
18
  from f5_tts.model import CFM
19
  from f5_tts.model.utils import exists, default
20
  from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
 
 
21
 
22
  # trainer
23
 
 
49
  bnb_optimizer: bool = False,
50
  export_samples=False,
51
  ):
52
+ # export audio and mel
53
+ self.export_samples = export_samples
54
+ if export_samples:
55
+ self.path_ckpts_project = checkpoint_path
56
+
57
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
58
 
59
  self.logger = logger
 
105
 
106
  self.writer = SummaryWriter(log_dir=folder_path)
107
 
 
 
 
 
 
 
 
108
  self.model = model
109
 
110
  if self.is_main:
 
209
  self.writer.add_scalar(key, value, step)
210
 
211
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
212
+ # import only when export_sample True
213
+ if self.export_samples:
214
+ from f5_tts.infer.utils_infer import (
215
+ target_sample_rate,
216
+ hop_length,
217
+ nfe_step,
218
+ cfg_strength,
219
+ sway_sampling_coef,
220
+ vocos,
221
+ )
222
+ from f5_tts.model.utils import get_sample
223
+
224
+ self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
225
+ os.makedirs(self.file_path_samples, exist_ok=True)
226
+
227
  if exists(resumable_with_seed):
228
  generator = torch.Generator()
229
  generator.manual_seed(resumable_with_seed)