qitaoz commited on
Commit
d2ecda1
·
verified ·
1 Parent(s): cff671d

Update diffusionsfm/inference/load_model.py

Browse files
Files changed (1) hide show
  1. diffusionsfm/inference/load_model.py +12 -10
diffusionsfm/inference/load_model.py CHANGED
@@ -27,15 +27,17 @@ def load_model(
27
  device (str): Device to load the model on.
28
  custom_keys (dict): Dictionary of custom keys to override in the config.
29
  """
30
- if checkpoint is None:
31
- checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1]
32
- else:
33
- if isinstance(checkpoint, int):
34
- checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
35
- else:
36
- checkpoint_name = checkpoint
37
- checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
38
- print("Loading checkpoint", osp.basename(checkpoint_path))
 
 
39
 
40
  cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
41
  if custom_keys is not None:
@@ -78,7 +80,7 @@ def load_model(
78
  cond_depth_mask=cfg.model.get("cond_depth_mask", False),
79
  ).to(device)
80
 
81
- data = torch.load(checkpoint_path)
82
  state_dict = {}
83
  for k, v in data["state_dict"].items():
84
  include = True
 
27
  device (str): Device to load the model on.
28
  custom_keys (dict): Dictionary of custom keys to override in the config.
29
  """
30
+ # if checkpoint is None:
31
+ # checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1]
32
+ # else:
33
+ # if isinstance(checkpoint, int):
34
+ # checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
35
+ # else:
36
+ # checkpoint_name = checkpoint
37
+ # checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
38
+ _URL = "https://huggingface.co/qitaoz/DiffusionSfM/resolve/main/ckpt_00800000.pth"
39
+ data = torch.hub.load_state_dict_from_url(_URL)
40
+ print("Loading checkpoint", _URL)
41
 
42
  cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
43
  if custom_keys is not None:
 
80
  cond_depth_mask=cfg.model.get("cond_depth_mask", False),
81
  ).to(device)
82
 
83
+ # data = torch.load(checkpoint_path)
84
  state_dict = {}
85
  for k, v in data["state_dict"].items():
86
  include = True