Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,261 Bytes
78e32cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
###
# Author: Kai Li
# Date: 2022-02-12 15:16:35
# Email: [email protected]
# LastEditTime: 2022-10-04 16:24:53
###
from .base_model import BaseModel
from .apollo import Apollo
# Names exported by ``from <package> import *``. Every entry must be bound in
# this module (imported or defined above); an unbound entry makes the
# star-import raise AttributeError.
__all__ = [
    "BaseModel",
    "Apollo",
]
def register_model(custom_model):
    """Register a custom model so it can be retrieved with ``get``.

    The model is stored in this module's global namespace under its
    ``__name__``. Lookups by name are case-insensitive, so a new model is
    rejected when its name equals any existing module-level name while
    ignoring case.

    Args:
        custom_model: Custom model to register. It must have a ``__name__``
            attribute (e.g. a class).

    Returns:
        ``custom_model`` itself, unchanged, so this function can also be used
        as a class decorator (``@register_model``).

    Raises:
        ValueError: If a module-level name equal to ``custom_model.__name__``
            (ignoring case) already exists.
    """
    name = custom_model.__name__
    lowered = name.lower()
    # Compare against *every* existing name case-insensitively: checking only
    # the exact name and its all-lowercase form let e.g. "APOLLO" through
    # next to an existing "Apollo", which then shadowed it in a
    # case-insensitive lookup.
    if any(existing.lower() == lowered for existing in globals()):
        raise ValueError(
            f"Model {custom_model.__name__} already exists. Choose another name."
        )
    globals()[name] = custom_model
    return custom_model
def get(identifier):
    """Look up a model by name, ignoring case.

    The name is matched against every name in this module's global
    namespace (imported models, models added via ``register_model``, and any
    other module-level name). If several names differ only by case, the one
    added to the namespace last wins.

    Args:
        identifier (str): the model name.

    Returns:
        The object bound to the matching name (normally a model class).

    Raises:
        ValueError: If ``identifier`` is not a string, or no name matches.
    """
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")

    by_lower_name = {}
    for name, obj in globals().items():
        by_lower_name[name.lower()] = obj

    found = by_lower_name.get(identifier.lower())
    if found is None:
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    return found
|