SWivid commited on
Commit
c325889
·
1 Parent(s): 45da22d

fix #1111 #1037 remove redundant unwrap_model for AcceleratedOptimizer; which has no attribute '_modules' thus conflict with has_compiled_regions check introduced in accelerate v1.7.0

Browse files
Files changed (2) hide show
  1. pyproject.toml +1 -1
  2. src/f5_tts/model/trainer.py +1 -1
pyproject.toml CHANGED
@@ -14,7 +14,7 @@ classifiers = [
14
  "Programming Language :: Python :: 3",
15
  ]
16
  dependencies = [
17
- "accelerate>=0.33.0,!=1.7.0",
18
  "bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
19
  "cached_path",
20
  "click",
 
14
  "Programming Language :: Python :: 3",
15
  ]
16
  dependencies = [
17
+ "accelerate>=0.33.0",
18
  "bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
19
  "cached_path",
20
  "click",
src/f5_tts/model/trainer.py CHANGED
@@ -149,7 +149,7 @@ class Trainer:
149
  if self.is_main:
150
  checkpoint = dict(
151
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
152
- optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
153
  ema_model_state_dict=self.ema_model.state_dict(),
154
  scheduler_state_dict=self.scheduler.state_dict(),
155
  update=update,
 
149
  if self.is_main:
150
  checkpoint = dict(
151
  model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
152
+ optimizer_state_dict=self.optimizer.state_dict(),
153
  ema_model_state_dict=self.ema_model.state_dict(),
154
  scheduler_state_dict=self.scheduler.state_dict(),
155
  update=update,