kunci115 committed
Commit 0fe34a8 (unverified) · 2 parents: e12fe35, b0f4824

Merge branch 'SWivid:main' into main

src/f5_tts/infer/utils_infer.py CHANGED
@@ -156,6 +156,7 @@ def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
             if k not in ["initted", "step"]
         }
 
+        # patch for backward compatibility, 305e3ea
         for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
             if key in checkpoint["model_state_dict"]:
                 del checkpoint["model_state_dict"][key]
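For readers applying the same fix by hand, here is a minimal standalone sketch of what the patched load_checkpoint does before calling load_state_dict. The checkpoint path is a placeholder, and the rationale (older checkpoints persisted mel-spectrogram STFT buffers that newer model definitions no longer expect) is an assumption based on the referenced commit 305e3ea:

import torch

# Placeholder path; point this at your own F5-TTS checkpoint file.
checkpoint = torch.load("model_last.pt", weights_only=True, map_location="cpu")

# Drop the mel-spectrogram STFT buffers carried by older checkpoints
# (assumed to be unexpected keys for the current model definition),
# mirroring the backward-compatibility patch above.
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
    if key in checkpoint["model_state_dict"]:
        del checkpoint["model_state_dict"][key]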
src/f5_tts/model/trainer.py CHANGED
@@ -163,10 +163,20 @@ class Trainer:
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
+        # patch for backward compatibility, 305e3ea
+        for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
+            if key in checkpoint["ema_model_state_dict"]:
+                del checkpoint["ema_model_state_dict"][key]
+
         if self.is_main:
             self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
 
         if "step" in checkpoint:
+            # patch for backward compatibility, 305e3ea
+            for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
+                if key in checkpoint["model_state_dict"]:
+                    del checkpoint["model_state_dict"][key]
+
             self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
             self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
             if self.scheduler:
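The trainer applies the same cleanup twice: once for the EMA weights, whose keys carry an "ema_model." prefix, and once for the raw model weights kept in full training checkpoints. A minimal sketch of that prefix handling outside the Trainer class follows; the checkpoint path is a placeholder, and the structure of the dicts is taken from the diff above rather than guaranteed by any API:

import torch

# Hypothetical resume checkpoint; adjust to your own training run.
checkpoint = torch.load("ckpts/model_last.pt", weights_only=True, map_location="cpu")

# Stale mel-spectrogram buffer names, as in the patch above.
stale = ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]

# EMA weights are stored with an "ema_model." prefix, raw weights without one.
for key in ["ema_model." + k for k in stale]:
    if key in checkpoint["ema_model_state_dict"]:
        del checkpoint["ema_model_state_dict"][key]

# Full training checkpoints (those that also record the step) carry the raw model weights.
if "step" in checkpoint:
    for key in stale:
        if key in checkpoint["model_state_dict"]:
            del checkpoint["model_state_dict"][key]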