fix replacement of ckpt keys when do finetune training
model/trainer.py  +2 -2  CHANGED
@@ -138,7 +138,7 @@ class Trainer:
         if "model_last.pt" in os.listdir(self.checkpoint_path):
             latest_checkpoint = "model_last.pt"
         else:
-            latest_checkpoint = sorted(os.listdir(self.checkpoint_path), key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
+            latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")

@@ -152,7 +152,7 @@ class Trainer:
             self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
             step = checkpoint['step']
         else:
-            checkpoint['model_state_dict'] = {k.replace("ema_model", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
+            checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
             self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
             step = 0
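For context (not part of the commit): the first change guards the step-number sort against non-checkpoint entries in the checkpoint directory. A minimal sketch with hypothetical filenames:

# Sketch: reproduce the checkpoint-selection behavior before and after the fix.
# Filenames are hypothetical; checkpoints here follow the model_<step>.pt pattern.
files = ["model_4000.pt", "model_12000.pt", "samples", "model_8000.pt"]

def step_of(name):
    return int(''.join(filter(str.isdigit, name)))

# Old line: a digit-free entry like "samples" makes int('') raise ValueError.
try:
    print(sorted(files, key=step_of)[-1])
except ValueError as e:
    print("old:", e)  # invalid literal for int() with base 10: ''

# New line: keep only .pt files, then pick the highest step.
print("new:", sorted([f for f in files if f.endswith('.pt')], key=step_of)[-1])
# new: model_12000.pt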
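The second change is the key fix the commit title refers to. EMA wrappers in the ema_pytorch style save weights under an "ema_model." prefix alongside "initted"/"step" buffers; stripping the bare substring "ema_model" leaves a leading dot on every key, so load_state_dict cannot match the model's parameter names. A sketch with placeholder values (the state-dict entries below are assumptions for illustration):

# Sketch: a hypothetical EMA state dict shaped like the one in the diff.
ema_sd = {
    "initted": True,   # EMA bookkeeping buffers, filtered out by the diff
    "step": 8000,
    "ema_model.transformer.proj.weight": "tensor(...)",  # placeholder values
    "ema_model.transformer.proj.bias": "tensor(...)",
}

def strip_prefix(prefix):
    return {k.replace(prefix, ""): v
            for k, v in ema_sd.items() if k not in ["initted", "step"]}

print(list(strip_prefix("ema_model")))
# ['.transformer.proj.weight', '.transformer.proj.bias']  <- leading dots, no match
print(list(strip_prefix("ema_model.")))
# ['transformer.proj.weight', 'transformer.proj.bias']    <- clean parameter names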