kevinwang676 commited on
Commit
8a047bc
·
verified ·
1 Parent(s): 12375ac

Update utils/__init__.py

Browse files
Files changed (1) hide show
  1. utils/__init__.py +1 -1
utils/__init__.py CHANGED
@@ -185,7 +185,7 @@ def load_ckpt(cur_model, ckpt_base_dir, prefix_in_ckpt='model', force=True, stri
185
  lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))
186
  if len(checkpoint_path) > 0:
187
  checkpoint_path = checkpoint_path[-1]
188
- state_dict = torch.load(checkpoint_path)["state_dict"]
189
  state_dict = {k[len(prefix_in_ckpt) + 1:]: v for k, v in state_dict.items()
190
  if k.startswith(f'{prefix_in_ckpt}.')}
191
  if not strict:
 
185
  lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))
186
  if len(checkpoint_path) > 0:
187
  checkpoint_path = checkpoint_path[-1]
188
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
189
  state_dict = {k[len(prefix_in_ckpt) + 1:]: v for k, v in state_dict.items()
190
  if k.startswith(f'{prefix_in_ckpt}.')}
191
  if not strict: