ReubenSun's picture
1
2ac1c2d
from dataclasses import dataclass
import os
import copy
import json
from omegaconf import OmegaConf
import torch
import torch.nn as nn
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import (
extract_commit_hash,
)
from step1x3d_geometry.utils.config import parse_structured
from step1x3d_geometry.utils.misc import get_device, load_module_weights
from step1x3d_geometry.utils.typing import *
class Configurable:
    """Mixin that keeps a structured configuration on ``self.cfg``.

    The nested ``Config`` dataclass is the schema handed to
    ``parse_structured``; subclasses replace it to declare their own fields.
    """

    @dataclass
    class Config:
        # Intentionally empty: the base schema declares no fields.
        pass

    def __init__(self, cfg: Optional[dict] = None) -> None:
        """Build ``self.cfg`` from ``cfg`` using this class's ``Config`` schema.

        Args:
            cfg: Raw configuration mapping, or ``None``; forwarded unchanged
                to ``parse_structured``.
        """
        # Cooperative call so the mixin can sit anywhere in an MRO.
        super().__init__()
        schema = self.Config
        self.cfg = parse_structured(schema, cfg)
class Updateable:
    """Mixin providing per-step update hooks that cascade to child objects.

    ``do_update_step`` and ``do_update_step_end`` walk the public attributes
    of the instance, recurse into every attribute that is itself an
    ``Updateable``, and finally call this object's own ``update_step`` /
    ``update_step_end`` hook. Subclasses override the two hooks; they should
    not need to override the ``do_*`` drivers.
    """

    def _iter_updateable_children(self):
        """Yield, lazily, every public attribute value that is an ``Updateable``.

        Attribute names come from ``self.__dir__()`` in the order it returns
        them; names starting with ``_`` are skipped. The generator is lazy so
        each attribute is fetched immediately before the caller processes it.
        """
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                module = getattr(self, attr)
            except Exception:
                # Ignore attributes (e.g. properties) that cannot be retrieved
                # using getattr. Only ``Exception`` is caught so that
                # KeyboardInterrupt / SystemExit are never silently swallowed.
                continue
            if isinstance(module, Updateable):
                yield module

    def do_update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ):
        """Run ``update_step`` on all ``Updateable`` children, then on ``self``.

        Args:
            epoch: Current epoch index, forwarded to the hooks.
            global_step: Current global step, forwarded to the hooks.
            on_load_weights: Forwarded to the hooks; see ``update_step``.
        """
        for module in self._iter_updateable_children():
            module.do_update_step(
                epoch, global_step, on_load_weights=on_load_weights
            )
        self.update_step(epoch, global_step, on_load_weights=on_load_weights)

    def do_update_step_end(self, epoch: int, global_step: int):
        """Run ``update_step_end`` on all ``Updateable`` children, then on ``self``.

        Args:
            epoch: Current epoch index, forwarded to the hooks.
            global_step: Current global step, forwarded to the hooks.
        """
        for module in self._iter_updateable_children():
            module.do_update_step_end(epoch, global_step)
        self.update_step_end(epoch, global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        """Hook called by ``do_update_step``; the default does nothing.

        Override this method to implement custom update logic. If
        ``on_load_weights`` is True, be careful doing things related to model
        evaluations, as the models and tensors are not guaranteed to be on the
        same device.
        """
        pass

    def update_step_end(self, epoch: int, global_step: int):
        """Hook called by ``do_update_step_end``; the default does nothing."""
        pass
def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Trigger ``do_update_step`` on ``module`` if it is an ``Updateable``.

    Objects of any other type are ignored.

    Args:
        module: Candidate object; may be of any type.
        epoch: Epoch index forwarded to ``do_update_step``.
        global_step: Global step forwarded to ``do_update_step``.
    """
    if not isinstance(module, Updateable):
        return
    module.do_update_step(epoch, global_step)
def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Trigger ``do_update_step_end`` on ``module`` if it is an ``Updateable``.

    Objects of any other type are ignored.

    Args:
        module: Candidate object; may be of any type.
        epoch: Epoch index forwarded to ``do_update_step_end``.
        global_step: Global step forwarded to ``do_update_step_end``.
    """
    if not isinstance(module, Updateable):
        return
    module.do_update_step_end(epoch, global_step)
class BaseObject(Updateable):
    """``Updateable`` base class that parses its config and then configures itself.

    Construction stores the parsed config on ``self.cfg``, stores the result
    of ``get_device()`` on ``self.device`` and finally calls ``configure``
    with any extra constructor arguments.
    """

    @dataclass
    class Config:
        # Intentionally empty: the base schema declares no fields.
        pass

    # Add this annotation to every subclass of BaseObject to enable static
    # type checking.
    cfg: Config

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg``, record the device and run ``configure``.

        Args:
            cfg: Raw configuration (mapping, ``DictConfig`` or ``None``),
                parsed against ``self.Config`` by ``parse_structured``.
            *args: Forwarded to ``configure``.
            **kwargs: Forwarded to ``configure``.
        """
        super().__init__()
        schema = self.Config
        self.cfg = parse_structured(schema, cfg)
        self.device = get_device()
        # Subclass set-up runs last, once cfg and device are available.
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        """Set-up hook invoked at the end of ``__init__``; no-op by default."""
        pass
class BaseModule(ModelMixin, Updateable, nn.Module):
    """Configurable ``nn.Module`` base with update hooks and JSON config I/O.

    On construction the module parses its config, calls ``configure`` and,
    when ``cfg.weights`` is set, loads a state dict and replays the update
    hooks for the stored epoch / global step. The class also provides
    ``load_config`` / ``from_config`` / ``save_config`` that read and write
    the config as a local JSON file named ``config_name``.
    """

    @dataclass
    class Config:
        # Optional checkpoint spec in the form "path/to/weights:module_name".
        weights: Optional[str] = None

    # Add this annotation to every subclass of BaseModule to enable static
    # type checking.
    cfg: Config
    # File name looked up by load_config and written by save_config.
    config_name = "config.json"

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Parse ``cfg``, run ``configure`` and optionally load weights.

        Args:
            cfg: Raw configuration (mapping, ``DictConfig`` or ``None``),
                parsed against ``self.Config`` by ``parse_structured``.
            *args: Forwarded to ``configure``.
            **kwargs: Forwarded to ``configure``.

        Raises:
            ValueError: If ``cfg.weights`` is set but contains no ``:``
                separating the path from the module name.
        """
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        # self.device = get_device()
        self.configure(*args, **kwargs)
        if self.cfg.weights is not None:
            # format: path/to/weights:module_name
            # Split on the LAST colon only, so a path that itself contains a
            # colon (e.g. a Windows drive letter) is still parsed correctly.
            weights_path, module_name = self.cfg.weights.rsplit(":", 1)
            state_dict, epoch, global_step = load_module_weights(
                weights_path, module_name=module_name, map_location="cpu"
            )
            self.load_state_dict(state_dict)
            # Replay the update hooks for the checkpoint's epoch / step to
            # restore states.
            self.do_update_step(epoch, global_step, on_load_weights=True)
        # dummy tensor to indicate model state
        self._dummy: Float[Tensor, "..."]
        # Non-persistent, so it is not part of the module's state dict.
        self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False)

    def configure(self, *args, **kwargs) -> None:
        """Set-up hook invoked from ``__init__`` before weights are loaded.

        No-op by default; subclasses override it to build their layers.
        """
        pass

    @classmethod
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ):
        """Load the JSON config from a local file or directory.

        Args:
            pretrained_model_name_or_path: Path to a config file, or to a
                directory containing ``cls.config_name`` (optionally inside
                ``subfolder``). Only local paths are handled.
            return_unused_kwargs: When true, append the remaining ``kwargs``
                (after ``subfolder`` is removed) to the returned tuple.
            return_commit_hash: When true, append the value returned by
                ``extract_commit_hash`` for the config file.
            **kwargs: ``subfolder`` is consumed here; everything else is only
                passed back when ``return_unused_kwargs`` is true.

        Returns:
            A tuple whose first element is the parsed config dict, followed by
            the optional items above in that order. NOTE: a 1-tuple is
            returned even when both flags are false.

        Raises:
            EnvironmentError: If a directory is given but no config file is
                found in it.
            ValueError: If the path is neither an existing file nor an
                existing directory.
        """
        subfolder = kwargs.pop("subfolder", None)
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            # The subfolder location takes precedence over the directory root.
            candidates = []
            if subfolder is not None:
                candidates.append(
                    os.path.join(
                        pretrained_model_name_or_path, subfolder, cls.config_name
                    )
                )
            candidates.append(
                os.path.join(pretrained_model_name_or_path, cls.config_name)
            )
            config_file = next(
                (path for path in candidates if os.path.isfile(path)), None
            )
            if config_file is None:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            raise ValueError(
                f"{pretrained_model_name_or_path} is neither an existing file "
                f"nor an existing directory; only local paths are supported."
            )

        # Context manager guarantees the file handle is closed after parsing.
        with open(config_file, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        commit_hash = extract_commit_hash(config_file)

        outputs = (config_dict,)
        if return_unused_kwargs:
            outputs += (kwargs,)
        if return_commit_hash:
            outputs += (commit_hash,)
        return outputs

    @classmethod
    def from_config(cls, config: Optional[Dict[str, Any]] = None, **kwargs):
        """Instantiate the module from a config mapping.

        Args:
            config: Configuration passed to the constructor as ``cfg``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A new instance of ``cls``.
        """
        model = cls(config)
        return model

    def register_to_config(self, **kwargs):
        """No-op placeholder; accepts and ignores any keyword arguments."""
        pass

    def save_config(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save the configuration to ``save_directory`` as a JSON file named
        ``self.config_name`` so that it can be reloaded with ``load_config`` /
        ``from_config``.

        The ``weights`` entry and every key starting with ``pretrained`` are
        left out of the saved file.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file is saved (will be created if it does not exist).
            kwargs (`Dict[str, Any]`, *optional*):
                Accepted for interface compatibility; not used here.

        Raises:
            AssertionError: If ``save_directory`` is an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)

        resolved = OmegaConf.to_container(self.cfg, resolve=True)
        # Build a filtered copy instead of deep-copying the dict just to
        # iterate it; a missing "weights" key is tolerated rather than raising.
        config_dict = {
            key: value
            for key, value in resolved.items()
            if key != "weights" and not key.startswith("pretrained")
        }

        with open(output_config_file, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, ensure_ascii=False, indent=4)

        print(f"Configuration saved in {output_config_file}")