Spaces:
Runtime error
Runtime error
| import os.path as osp | |
| from glob import glob | |
| import torch | |
| from omegaconf import OmegaConf | |
| from diffusionsfm.model.diffuser import RayDiffuser | |
| from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT | |
| from diffusionsfm.model.scheduler import NoiseScheduler | |
def _resolve_checkpoint_path(output_dir, checkpoint):
    """Return the checkpoint file path selected by ``checkpoint`` inside ``output_dir``.

    Args:
        output_dir (str): Training output directory containing ``checkpoints/``.
        checkpoint (None, int or str): ``None`` picks the last ``*.pth`` file in
            sorted order; an int ``n`` maps to ``ckpt_{n:08d}.pth``; a str is a
            file name under ``checkpoints/`` (an absolute path is used as-is,
            because ``os.path.join`` discards the preceding components).

    Raises:
        FileNotFoundError: if ``checkpoint`` is None and no ``*.pth`` file exists.
    """
    checkpoint_dir = osp.join(output_dir, "checkpoints")
    if checkpoint is None:
        # The zero-padded ``ckpt_{n:08d}.pth`` names make lexicographic order
        # match step order, so the last sorted entry is the latest checkpoint.
        candidates = sorted(glob(osp.join(checkpoint_dir, "*.pth")))
        if not candidates:
            raise FileNotFoundError(f"No *.pth checkpoints found in {checkpoint_dir}")
        return candidates[-1]
    if isinstance(checkpoint, int):
        checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
    else:
        checkpoint_name = checkpoint
    return osp.join(checkpoint_dir, checkpoint_name)


def _build_model(cfg, noise_scheduler, device):
    """Instantiate RayDiffuser, or RayDiffuserDPT when ``training.dpt_head`` is set, on ``device``."""
    # Arguments shared by both model variants; optional config entries fall
    # back to the same defaults the config lookups used previously.
    common_kwargs = dict(
        depth=cfg.model.depth,
        width=cfg.model.num_patches_x,
        P=1,
        max_num_images=cfg.model.num_images,
        noise_scheduler=noise_scheduler,
        feature_extractor=cfg.model.feature_extractor,
        append_ndc=cfg.model.append_ndc,
        diffuse_depths=cfg.training.get("diffuse_depths", False),
        depth_resolution=cfg.training.get("depth_resolution", 1),
        use_homogeneous=cfg.model.get("use_homogeneous", False),
        cond_depth_mask=cfg.model.get("cond_depth_mask", False),
    )
    if cfg.training.get("dpt_head", False):
        model = RayDiffuserDPT(
            encoder_features=cfg.training.get("dpt_encoder_features", False),
            **common_kwargs,
        )
    else:
        model = RayDiffuser(**common_kwargs)
    return model.to(device)


def _filter_state_dict(state_dict, ignore_keys):
    """Return a copy of ``state_dict`` without entries whose key contains any ignore pattern."""
    if ignore_keys is None:
        ignore_keys = ()
    elif isinstance(ignore_keys, str):
        # A bare string would otherwise be iterated character by character
        # and drop nearly every key.
        ignore_keys = (ignore_keys,)
    return {
        k: v
        for k, v in state_dict.items()
        if not any(ignore_key in k for ignore_key in ignore_keys)
    }


def load_model(
    output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=()
):
    """
    Loads a model and config from an output directory.

    E.g. to load with different number of images,
    ```
    custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"]
    ```

    Args:
        output_dir (str): Path to the output directory. It must contain
            ``hydra/config.yaml`` and a ``checkpoints/`` folder.
        checkpoint (str or int): Checkpoint to load. ``None`` loads the latest
            ``*.pth`` file. An int ``n`` loads ``ckpt_{n:08d}.pth``. A str is a
            file name inside ``checkpoints/``.
        device (str): Device to load the model and checkpoint tensors on.
        custom_keys (dict): Dotted config keys mapped to values that override
            the loaded config (applied via ``OmegaConf.update``).
        ignore_keys (iterable of str or str): Substrings. Any state-dict entry
            whose key contains one of them is skipped, e.g. parameters whose
            shape depends on an overridden config value.

    Returns:
        tuple: ``(model, cfg)``. ``model`` is in eval mode on ``device``;
        ``cfg`` is the loaded config with ``custom_keys`` applied.

    Raises:
        FileNotFoundError: if ``checkpoint`` is None and no checkpoint exists.
    """
    checkpoint_path = _resolve_checkpoint_path(output_dir, checkpoint)
    print("Loading checkpoint", osp.basename(checkpoint_path))

    cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
    if custom_keys is not None:
        for k, v in custom_keys.items():
            OmegaConf.update(cfg, k, v)

    noise_scheduler = NoiseScheduler(
        type=cfg.noise_scheduler.type,
        max_timesteps=cfg.noise_scheduler.max_timesteps,
        beta_start=cfg.noise_scheduler.beta_start,
        beta_end=cfg.noise_scheduler.beta_end,
    )
    model = _build_model(cfg, noise_scheduler, device)

    # map_location remaps tensors saved from a GPU run onto ``device``, so a
    # CPU-only host can load them. Without it, deserialization raises.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    # Newer torch versions default to weights_only=True. Confirm the checkpoint
    # holds only tensors/primitive types.
    data = torch.load(checkpoint_path, map_location=device)
    state_dict = _filter_state_dict(data["state_dict"], ignore_keys)

    # strict=False: mismatches (e.g. from ignore_keys) are reported, not raised.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0:
        print("Missing keys:", missing)
    if len(unexpected) > 0:
        print("Unexpected keys:", unexpected)
    model = model.eval()
    return model, cfg