Spaces:
Sleeping
Sleeping
# Models | |
# from .conv_tasnet import ConvTasNet | |
# from .dccrnet import DCCRNet | |
# from .dcunet import DCUNet | |
# from .dprnn_tasnet import DPRNNTasNet | |
# from .sudormrf import SuDORMRFImprovedNet, SuDORMRFNet | |
from .dptnet import DPTNet | |
# from .lstm_tasnet import LSTMTasNet | |
# from .demask import DeMask | |
# Sharing-related | |
# from .publisher import save_publishable, upload_publishable | |
# Public API of this package. Only DPTNet is currently exported; the other
# models (ConvTasNet, DPRNNTasNet, SuDORMRFImprovedNet, SuDORMRFNet,
# LSTMTasNet, DeMask, DCUNet, DCCRNet) and the sharing helpers
# (save_publishable, upload_publishable) are disabled together with their
# commented-out imports above.
__all__ = ["DPTNet"]
def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    The model is stored in this module's namespace under
    ``custom_model.__name__``.

    Args:
        custom_model: Custom model to register. Only its ``__name__``
            attribute is read here.

    Raises:
        ValueError: If any existing module-level name equals the model's name
            when compared case-insensitively.
    """
    name = custom_model.__name__
    # `get` resolves names case-insensitively, so a collision must be checked
    # the same way. Otherwise e.g. "Dptnet" could be registered next to
    # "DPTNet", and `get("dptnet")` would silently pick one of them.
    taken = {key.lower() for key in globals()}
    if name.lower() in taken:
        raise ValueError(f"Model {custom_model.__name__} already exists. Choose another name.")
    globals().update({name: custom_model})
def get(identifier):
    """Look up a model class by name, ignoring case.

    The search covers every module-level name of this package, not only
    registered models. If two names are equal after lowercasing, the one that
    appears later in the module namespace wins.

    Args:
        identifier (str): Name of the model, e.g. ``"DPTNet"`` or ``"dptnet"``.

    Returns:
        The object bound to the matching module-level name (expected to be a
        :class:`torch.nn.Module` subclass for models).

    Raises:
        ValueError: If ``identifier`` is not a string, or if no module-level name
            matches it case-insensitively.
    """
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")

    by_lower_name = {}
    for attr_name, attr_value in globals().items():
        # Later entries overwrite earlier ones on a case-insensitive clash.
        by_lower_name[attr_name.lower()] = attr_value

    found = by_lower_name.get(identifier.lower())
    if found is None:
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    return found