fix #1111 #1037 remove redundant unwrap_model for AcceleratedOptimizer; which has no attribute '_modules' thus conflict with has_compiled_regions check introduced in accelerate v1.7.0
Browse files- pyproject.toml +1 -1
- src/f5_tts/model/trainer.py +1 -1
pyproject.toml
CHANGED
@@ -14,7 +14,7 @@ classifiers = [
|
|
14 |
"Programming Language :: Python :: 3",
|
15 |
]
|
16 |
dependencies = [
|
17 |
-
"accelerate>=0.33.0
|
18 |
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
|
19 |
"cached_path",
|
20 |
"click",
|
|
|
14 |
"Programming Language :: Python :: 3",
|
15 |
]
|
16 |
dependencies = [
|
17 |
+
"accelerate>=0.33.0",
|
18 |
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
|
19 |
"cached_path",
|
20 |
"click",
|
src/f5_tts/model/trainer.py
CHANGED
@@ -149,7 +149,7 @@ class Trainer:
|
|
149 |
if self.is_main:
|
150 |
checkpoint = dict(
|
151 |
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
152 |
-
optimizer_state_dict=self.
|
153 |
ema_model_state_dict=self.ema_model.state_dict(),
|
154 |
scheduler_state_dict=self.scheduler.state_dict(),
|
155 |
update=update,
|
|
|
149 |
if self.is_main:
|
150 |
checkpoint = dict(
|
151 |
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
152 |
+
optimizer_state_dict=self.optimizer.state_dict(),
|
153 |
ema_model_state_dict=self.ema_model.state_dict(),
|
154 |
scheduler_state_dict=self.scheduler.state_dict(),
|
155 |
update=update,
|