SWivid committed on
Commit
46d391a
·
1 Parent(s): 0d7b47b

Fix replacement of checkpoint keys when doing fine-tune training

Browse files
Files changed (1) hide show
  1. model/trainer.py +2 -2
model/trainer.py CHANGED
@@ -138,7 +138,7 @@ class Trainer:
138
  if "model_last.pt" in os.listdir(self.checkpoint_path):
139
  latest_checkpoint = "model_last.pt"
140
  else:
141
- latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
 
@@ -152,7 +152,7 @@ class Trainer:
152
  self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
153
  step = checkpoint['step']
154
  else:
155
- checkpoint['model_state_dict'] = {k.replace("ema_model", "model"): v for k, v in checkpoint['ema_model_state_dict'].items()}
156
  self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
157
  step = 0
158
 
 
138
  if "model_last.pt" in os.listdir(self.checkpoint_path):
139
  latest_checkpoint = "model_last.pt"
140
  else:
141
+ latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
 
 
152
  self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
153
  step = checkpoint['step']
154
  else:
155
+ checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
156
  self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
157
  step = 0
158