Spaces:
Running
Running
Update utils/__init__.py
Browse files- utils/__init__.py +1 -1
utils/__init__.py
CHANGED
@@ -185,7 +185,7 @@ def load_ckpt(cur_model, ckpt_base_dir, prefix_in_ckpt='model', force=True, stri
|
|
185 |
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))
|
186 |
if len(checkpoint_path) > 0:
|
187 |
checkpoint_path = checkpoint_path[-1]
|
188 |
-
state_dict = torch.load(checkpoint_path)["state_dict"]
|
189 |
state_dict = {k[len(prefix_in_ckpt) + 1:]: v for k, v in state_dict.items()
|
190 |
if k.startswith(f'{prefix_in_ckpt}.')}
|
191 |
if not strict:
|
|
|
185 |
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))
|
186 |
if len(checkpoint_path) > 0:
|
187 |
checkpoint_path = checkpoint_path[-1]
|
188 |
+
state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
|
189 |
state_dict = {k[len(prefix_in_ckpt) + 1:]: v for k, v in state_dict.items()
|
190 |
if k.startswith(f'{prefix_in_ckpt}.')}
|
191 |
if not strict:
|