hcsolakoglu committed on
Commit
33e8651
·
1 Parent(s): 93ae7d3

Refactor imports and improve code formatting in dataset and trainer modules

Browse files
src/f5_tts/model/dataset.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import random
3
  from importlib.resources import files
4
 
5
  import torch
 
1
  import json
 
2
  from importlib.resources import files
3
 
4
  import torch
src/f5_tts/model/trainer.py CHANGED
@@ -279,11 +279,11 @@ class Trainer:
279
  self.accelerator.even_batches = False
280
  sampler = SequentialSampler(train_dataset)
281
  batch_sampler = DynamicBatchSampler(
282
- sampler,
283
- self.batch_size,
284
- max_samples=self.max_samples,
285
  random_seed=resumable_with_seed, # This enables reproducible shuffling
286
- drop_last=False
287
  )
288
  train_dataloader = DataLoader(
289
  train_dataset,
@@ -334,7 +334,7 @@ class Trainer:
334
  current_dataloader = train_dataloader
335
 
336
  # Set epoch for the batch sampler if it exists
337
- if hasattr(train_dataloader, 'batch_sampler') and hasattr(train_dataloader.batch_sampler, 'set_epoch'):
338
  train_dataloader.batch_sampler.set_epoch(epoch)
339
 
340
  progress_bar = tqdm(
 
279
  self.accelerator.even_batches = False
280
  sampler = SequentialSampler(train_dataset)
281
  batch_sampler = DynamicBatchSampler(
282
+ sampler,
283
+ self.batch_size,
284
+ max_samples=self.max_samples,
285
  random_seed=resumable_with_seed, # This enables reproducible shuffling
286
+ drop_last=False,
287
  )
288
  train_dataloader = DataLoader(
289
  train_dataset,
 
334
  current_dataloader = train_dataloader
335
 
336
  # Set epoch for the batch sampler if it exists
337
+ if hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "set_epoch"):
338
  train_dataloader.batch_sampler.set_epoch(epoch)
339
 
340
  progress_bar = tqdm(