Spaces:
Build error
Build error
| import sys | |
| from abc import ABCMeta, abstractmethod | |
| from torch import nn | |
| from copy import copy | |
| import inspect | |
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Base class for dict-configured models that consume a dict of inputs.

    Subclasses override the class attributes ``default_conf`` and
    ``required_inputs`` and implement the ``_init`` / ``_forward`` hooks.
    This class merges the configuration, validates the input keys, and
    dispatches to those hooks.
    """

    # Default configuration. Entries are overridden (shallow merge) by the
    # ``conf`` mapping passed to ``__init__``.
    default_conf = {}
    # Keys that must be present in the ``data`` mapping passed to ``forward``.
    required_inputs = []

    def __init__(self, conf=None):
        """Merge ``conf`` over ``default_conf`` and call the child's ``_init``.

        Args:
            conf: Mapping of configuration overrides. ``None`` (the default)
                means "use ``default_conf`` unchanged".
        """
        super().__init__()
        overrides = {} if conf is None else conf
        self.conf = conf = {**self.default_conf, **overrides}
        # Per-instance copy, so that a child's ``_init`` can extend
        # ``self.required_inputs`` without mutating the list shared by every
        # instance through the class attribute.
        self.required_inputs = copy(self.required_inputs)
        self._init(conf)
        # Presumably flushes anything printed while ``_init`` ran.
        sys.stdout.flush()

    def forward(self, data):
        """Check that ``data`` has every required key, then call ``_forward``.

        Args:
            data: Mapping of named inputs (presumably tensors; confirm with
                the concrete subclasses).

        Returns:
            Whatever the child's ``_forward`` returns.

        Raises:
            AssertionError: if a key listed in ``required_inputs`` is missing.
        """
        for key in self.required_inputs:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; AssertionError is kept for existing callers.
            if key not in data:
                raise AssertionError(f"Missing key {key} in data")
        return self._forward(data)

    # NOTE(review): the class uses ABCMeta, but these hooks are not marked
    # @abstractmethod, so a missing override only surfaces as
    # NotImplementedError when the hook is actually called.
    def _init(self, conf):
        """To be implemented by the child class: build the model from ``conf``."""
        raise NotImplementedError

    def _forward(self, data):
        """To be implemented by the child class: run the model on ``data``."""
        raise NotImplementedError
def dynamic_load(root, model):
    """Import ``<root>.<model>`` and return the single BaseModel subclass it defines.

    Args:
        root: Package (module object) whose ``__name__`` prefixes the import
            path.
        model: Name of the submodule of ``root`` to import.

    Returns:
        The one class defined in that submodule that subclasses ``BaseModel``.

    Raises:
        AssertionError: if the submodule defines zero, or more than one,
            such class.
    """
    import importlib

    module_path = f"{root.__name__}.{model}"
    # import_module returns the submodule itself, unlike bare __import__,
    # which needs the ``fromlist`` workaround to do so.
    module = importlib.import_module(module_path)
    candidates = [
        (name, cls)
        for name, cls in inspect.getmembers(module, inspect.isclass)
        # Keep only classes defined in this module (not ones it merely
        # imports) that inherit from BaseModel.
        if cls.__module__ == module_path and issubclass(cls, BaseModel)
    ]
    # Explicit raise instead of ``assert`` so the check survives
    # ``python -O``; AssertionError is kept for existing callers.
    if len(candidates) != 1:
        raise AssertionError(
            f"Expected exactly one BaseModel subclass in {module_path}, "
            f"found {len(candidates)}: {candidates}"
        )
    return candidates[0][1]