qitaoz's picture
Upload 57 files
4562a06 verified
raw
history blame
3.65 kB
import os.path as osp
from glob import glob
import torch
from omegaconf import OmegaConf
from diffusionsfm.model.diffuser import RayDiffuser
from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
from diffusionsfm.model.scheduler import NoiseScheduler
def _resolve_checkpoint_path(output_dir, checkpoint):
    """Return the checkpoint file selected by ``checkpoint`` or raise if absent."""
    checkpoint_dir = osp.join(output_dir, "checkpoints")

    if checkpoint is None:
        # Lexicographic sort: zero-padded names such as ``ckpt_00001000.pth``
        # (the format used below for integer checkpoints) order by step.
        candidates = sorted(glob(osp.join(checkpoint_dir, "*.pth")))
        if not candidates:
            raise FileNotFoundError(
                f"No checkpoints (*.pth) found in {checkpoint_dir}"
            )
        return candidates[-1]

    if isinstance(checkpoint, int):
        checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
    else:
        checkpoint_name = checkpoint
    checkpoint_path = osp.join(checkpoint_dir, checkpoint_name)
    # Fail before the (expensive) model construction rather than at torch.load.
    if not osp.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    return checkpoint_path


def _filter_state_dict(state_dict, ignore_keys):
    """Drop every entry whose name contains any of the ``ignore_keys`` substrings."""
    if ignore_keys is None:
        ignore_keys = ()
    elif isinstance(ignore_keys, str):
        # A bare string would otherwise be iterated character by character and
        # discard nearly every parameter.
        ignore_keys = (ignore_keys,)
    else:
        ignore_keys = tuple(ignore_keys)

    return {
        name: tensor
        for name, tensor in state_dict.items()
        if not any(ignored in name for ignored in ignore_keys)
    }


def load_model(
    output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=()
):
    """
    Loads a model and config from an output directory.

    E.g. to load with different number of images,
    ```
    custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"]
    ```

    Args:
        output_dir (str): Path to the output directory. The config is read from
            ``<output_dir>/hydra/config.yaml`` and checkpoints from
            ``<output_dir>/checkpoints``.
        checkpoint (str or int): Checkpoint to load. An int ``n`` selects
            ``ckpt_{n:08d}.pth``; a str is used as the file name inside the
            checkpoints directory. If None, loads the latest checkpoint (the
            last ``*.pth`` file in sorted order).
        device (str): Device to load the model on.
        custom_keys (dict): Dictionary of custom keys to override in the config.
        ignore_keys (str or iterable of str): Substrings; any checkpoint entry
            whose name contains one of them is not loaded into the model.

    Returns:
        tuple: ``(model, cfg)`` where ``model`` is on ``device`` and in eval
        mode, and ``cfg`` is the loaded config with ``custom_keys`` applied.

    Raises:
        FileNotFoundError: If no checkpoint is found, or the requested
            checkpoint file does not exist.
    """
    checkpoint_path = _resolve_checkpoint_path(output_dir, checkpoint)
    print("Loading checkpoint", osp.basename(checkpoint_path))

    cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
    if custom_keys is not None:
        for k, v in custom_keys.items():
            OmegaConf.update(cfg, k, v)

    noise_scheduler = NoiseScheduler(
        type=cfg.noise_scheduler.type,
        max_timesteps=cfg.noise_scheduler.max_timesteps,
        beta_start=cfg.noise_scheduler.beta_start,
        beta_end=cfg.noise_scheduler.beta_end,
    )

    # Arguments shared by both model variants.
    model_kwargs = dict(
        depth=cfg.model.depth,
        width=cfg.model.num_patches_x,
        P=1,
        max_num_images=cfg.model.num_images,
        noise_scheduler=noise_scheduler,
        feature_extractor=cfg.model.feature_extractor,
        append_ndc=cfg.model.append_ndc,
        diffuse_depths=cfg.training.get("diffuse_depths", False),
        depth_resolution=cfg.training.get("depth_resolution", 1),
        use_homogeneous=cfg.model.get("use_homogeneous", False),
        cond_depth_mask=cfg.model.get("cond_depth_mask", False),
    )
    if cfg.training.get("dpt_head", False):
        model = RayDiffuserDPT(
            encoder_features=cfg.training.get("dpt_encoder_features", False),
            **model_kwargs,
        ).to(device)
    else:
        model = RayDiffuser(**model_kwargs).to(device)

    # map_location lets a checkpoint saved on one device (e.g. a GPU) be loaded
    # on another (e.g. CPU-only host) instead of failing during deserialization.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    data = torch.load(checkpoint_path, map_location=device)
    state_dict = _filter_state_dict(data["state_dict"], ignore_keys)

    # strict=False so that ignored / resized entries do not abort loading; the
    # mismatches are reported instead.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0:
        print("Missing keys:", missing)
    if len(unexpected) > 0:
        print("Unexpected keys:", unexpected)

    model = model.eval()
    return model, cfg