Update diffusionsfm/inference/load_model.py
diffusionsfm/inference/load_model.py (CHANGED)
@@ -27,15 +27,17 @@ def load_model(
         device (str): Device to load the model on.
         custom_keys (dict): Dictionary of custom keys to override in the config.
     """
-    if checkpoint is None:
-        checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1]
-    else:
-        if isinstance(checkpoint, int):
-            checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
-        else:
-            checkpoint_name = checkpoint
-        checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
-
+    # if checkpoint is None:
+    #     checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1]
+    # else:
+    #     if isinstance(checkpoint, int):
+    #         checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
+    #     else:
+    #         checkpoint_name = checkpoint
+    #     checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
+    _URL = "https://huggingface.co/qitaoz/DiffusionSfM/resolve/main/ckpt_00800000.pth"
+    data = torch.hub.load_state_dict_from_url(_URL)
+    print("Loading checkpoint", _URL)

     cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
     if custom_keys is not None:
@@ -78,7 +80,7 @@ def load_model(
         cond_depth_mask=cfg.model.get("cond_depth_mask", False),
     ).to(device)

-    data = torch.load(checkpoint_path)
+    # data = torch.load(checkpoint_path)
     state_dict = {}
     for k, v in data["state_dict"].items():
         include = True
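
The functional change in the first hunk is that load_model no longer resolves a checkpoint path under output_dir/checkpoints but always downloads the released weights from the Hugging Face Hub via torch.hub.load_state_dict_from_url. Below is a minimal standalone sketch of that loading pattern; the map_location argument and the key inspection are additions for illustration only, the patch itself calls the function with its defaults.

import torch

# Released DiffusionSfM checkpoint (same URL as _URL in the patch).
_URL = "https://huggingface.co/qitaoz/DiffusionSfM/resolve/main/ckpt_00800000.pth"

# torch.hub.load_state_dict_from_url downloads the file into the local torch
# hub cache (by default something like ~/.cache/torch/hub/checkpoints) on the
# first call, reuses the cached file afterwards, and returns the torch.load
# result. map_location="cpu" is an optional extra here so the checkpoint can
# be inspected without a GPU; the patch uses the defaults.
data = torch.hub.load_state_dict_from_url(_URL, map_location="cpu")

# The checkpoint is a dict whose "state_dict" entry holds the model weights,
# which is what the rest of load_model consumes.
print(type(data), list(data.keys())[:5])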
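The second hunk only changes where data comes from; the key-filtering loop that begins with state_dict = {} and include = True is untouched, and its body is not shown in this diff. Purely as a hypothetical illustration of that kind of filtering (the loop body, the ignore_keys name, and the strict=False call are assumptions, not taken from the repository), a typical pattern looks like:

# Hypothetical sketch of a checkpoint-filtering step in the spirit of the loop
# whose first lines appear in the diff; load_model's actual loop body is not
# shown there, so everything below is illustrative only.
def filter_state_dict(raw_state_dict, ignore_keys=()):
    state_dict = {}
    for k, v in raw_state_dict.items():
        include = True
        for pattern in ignore_keys:
            if pattern in k:  # skip any weight whose name matches a pattern
                include = False
                break
        if include:
            state_dict[k] = v
    return state_dict

# Loading with strict=False then tolerates the keys that were filtered out:
# model.load_state_dict(filter_state_dict(data["state_dict"]), strict=False)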