Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,126 Bytes
2ac1c2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
from dataclasses import dataclass
import torch
import torch.nn as nn
from .config import parse_structured
from .misc import get_device, load_module_weights
from .typing import *
class Configurable:
    """Mixin that stores a parsed configuration on ``self.cfg``.

    The nested ``Config`` dataclass is the schema handed to
    ``parse_structured``; subclasses redefine it to declare their own fields.
    """

    @dataclass
    class Config:
        """Empty default schema; subclasses redefine it with their fields."""

    def __init__(self, cfg: Optional[dict] = None) -> None:
        super().__init__()
        # ``self.Config`` resolves to the most-derived nested dataclass.
        structured_cfg = parse_structured(self.Config, cfg)
        self.cfg = structured_cfg
class Updateable:
    """Mixin giving an object per-step update hooks that cascade to children.

    ``do_update_step`` / ``do_update_step_end`` first recurse into every public
    attribute that is itself an ``Updateable`` and then invoke this object's
    own ``update_step`` / ``update_step_end`` hook. Subclasses override the
    hooks, not the ``do_*`` drivers.
    """

    def _iter_updateable_children(self):
        """Yield each public attribute value that is an ``Updateable``.

        Attribute names starting with ``_`` are skipped. Attributes whose
        lookup raises (e.g. a property that fails) are ignored on a
        best-effort basis. Lookup is lazy: each attribute is read right
        before it is yielded, so the caller's work on one child happens
        before the next attribute is fetched.
        """
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                value = getattr(self, attr)
            except Exception:
                # Ignore attributes (such as properties) that cannot be
                # retrieved with getattr. Only ``Exception`` is caught so
                # that KeyboardInterrupt / SystemExit still propagate.
                continue
            if isinstance(value, Updateable):
                yield value

    def do_update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ):
        """Run ``update_step`` on all ``Updateable`` children, then on self.

        Args:
            epoch: Epoch value forwarded unchanged to every hook.
            global_step: Step value forwarded unchanged to every hook.
            on_load_weights: Flag forwarded unchanged to every hook.
        """
        for child in self._iter_updateable_children():
            child.do_update_step(
                epoch, global_step, on_load_weights=on_load_weights
            )
        self.update_step(epoch, global_step, on_load_weights=on_load_weights)

    def do_update_step_end(self, epoch: int, global_step: int):
        """Run ``update_step_end`` on all ``Updateable`` children, then on self.

        Args:
            epoch: Epoch value forwarded unchanged to every hook.
            global_step: Step value forwarded unchanged to every hook.
        """
        for child in self._iter_updateable_children():
            child.do_update_step_end(epoch, global_step)
        self.update_step_end(epoch, global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        """Hook invoked by ``do_update_step``; no-op by default.

        Override this method to implement custom update logic. If
        ``on_load_weights`` is True, be careful doing things related to model
        evaluations, as the models and tensors are not guaranteed to be on
        the same device.
        """

    def update_step_end(self, epoch: int, global_step: int):
        """Hook invoked by ``do_update_step_end``; no-op by default."""
def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Call ``module.do_update_step(epoch, global_step)`` for ``Updateable`` objects.

    Any other object is left untouched.
    """
    if not isinstance(module, Updateable):
        return
    module.do_update_step(epoch, global_step)
def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Call ``module.do_update_step_end(epoch, global_step)`` for ``Updateable`` objects.

    Any other object is left untouched.
    """
    if not isinstance(module, Updateable):
        return
    module.do_update_step_end(epoch, global_step)
class BaseObject(Updateable):
    """Base class for plain (non-``nn.Module``) objects with update hooks.

    Construction parses ``cfg`` against the nested ``Config`` dataclass,
    records the result of ``get_device()`` on ``self.device``, and then hands
    any extra arguments to ``configure``.
    """

    @dataclass
    class Config:
        """Empty default schema; subclasses redefine it with their fields."""

    # Re-declare this in every subclass of BaseObject to enable static type checking.
    cfg: Config

    def __init__(
        self,
        cfg: Optional[Union[dict, DictConfig]] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        # Subclass hook runs last, after self.cfg and self.device are set.
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        """Set-up hook called at the end of ``__init__``; no-op by default."""
class BaseModule(nn.Module, Updateable):
    """Base ``nn.Module`` with structured config, update hooks and weight loading.

    Construction parses ``cfg`` against the nested ``Config`` dataclass,
    records ``get_device()`` on ``self.device``, calls ``configure`` with any
    extra arguments and, when ``cfg.weights`` is set, loads that state dict
    and replays ``do_update_step`` with ``on_load_weights=True``.
    """

    @dataclass
    class Config:
        # Optional weights spec in the form "path/to/weights:module_name".
        weights: Optional[str] = None

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        """Initialise the module.

        Raises:
            ValueError: If ``cfg.weights`` is set but contains no ``:``
                separating the path from the module name.
        """
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        # Created before configure() so subclasses can register non-modules there.
        self._non_modules = {}
        self.configure(*args, **kwargs)
        if self.cfg.weights is not None:
            # format: path/to/weights:module_name
            # Split on the LAST colon only, so paths that themselves contain
            # colons (e.g. Windows drive letters) are still parsed.
            weights_path, sep, module_name = self.cfg.weights.rpartition(":")
            if not sep:
                raise ValueError(
                    "cfg.weights must have the form "
                    f"'path/to/weights:module_name', got {self.cfg.weights!r}"
                )
            state_dict, epoch, global_step = load_module_weights(
                weights_path, module_name=module_name, map_location="cpu"
            )
            self.load_state_dict(state_dict)
            self.do_update_step(
                epoch, global_step, on_load_weights=True
            )  # restore states

    def configure(self, *args, **kwargs) -> None:
        """Set-up hook called from ``__init__`` before weights are loaded; no-op by default."""

    def register_non_module(self, name: str, module: nn.Module) -> None:
        """Store ``module`` under ``name`` in a plain dict on this object.

        Non-modules won't be treated as model parameters: the object is kept
        in ``self._non_modules`` instead of being assigned as an attribute.
        """
        self._non_modules[name] = module

    def non_module(self, name: str):
        """Return the object registered under ``name``, or ``None`` if absent."""
        return self._non_modules.get(name)
|