Spaces:
Running
Running
File size: 5,270 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import importlib
import logging
import pkgutil
from typing import Union, Dict, Type, Callable # noqa: F401
from chromadb.auth import (
ClientAuthConfigurationProvider,
ClientAuthCredentialsProvider,
ClientAuthProtocolAdapter,
ServerAuthProvider,
ServerAuthConfigurationProvider,
ServerAuthCredentialsProvider,
ClientAuthProvider,
ServerAuthorizationConfigurationProvider,
ServerAuthorizationProvider,
)
from chromadb.utils import get_class
# Module-level logger; register_provider emits a debug record through it for
# every class it registers.
logger = logging.getLogger(__name__)

# Union of every provider base class this module can register or resolve.
# The members are string forward references, so typing wraps them in
# ForwardRef objects instead of evaluating them when the alias is built.
# The alias is used in this module's Type[...] annotations.
ProviderTypes = Union[
    "ClientAuthProvider",
    "ClientAuthConfigurationProvider",
    "ClientAuthCredentialsProvider",
    "ServerAuthProvider",
    "ServerAuthConfigurationProvider",
    "ServerAuthCredentialsProvider",
    "ClientAuthProtocolAdapter",
    "ServerAuthorizationProvider",
    "ServerAuthorizationConfigurationProvider",
]
_provider_registry = {
"client_auth_providers": {},
"client_auth_config_providers": {},
"client_auth_credentials_providers": {},
"client_auth_protocol_adapters": {},
"server_auth_providers": {},
"server_auth_config_providers": {},
"server_auth_credentials_providers": {},
"server_authz_providers": {},
"server_authz_config_providers": {},
} # type: Dict[str, Dict[str, Type[ProviderTypes]]]
def register_classes_from_package(package_name: str) -> None:
    """Import every immediate submodule/subpackage of *package_name*.

    The imported modules are discarded; the import is done only for its
    side effects. Presumably the intent is that modules applying
    ``register_provider`` at import time populate the provider registry.
    Only the direct children reported by ``pkgutil.iter_modules`` are
    imported; submodules of nested packages are not walked.

    Raises:
        AttributeError: if *package_name* is a plain module rather than a
            package (it has no ``__path__``).
    """
    package = importlib.import_module(package_name)
    # ``prefix`` makes iter_modules yield fully-qualified dotted names.
    children = pkgutil.iter_modules(package.__path__, prefix=f"{package_name}.")
    for child in children:
        importlib.import_module(child.name)
def register_provider(
    short_hand: str,
) -> Callable[[Type[ProviderTypes]], Type[ProviderTypes]]:
    """Return a class decorator that registers a provider under *short_hand*.

    The decorated class is stored in ``_provider_registry`` under the category
    of the first base class it subclasses, checked in the order of the table
    below. The class is returned unchanged.

    Args:
        short_hand: Name under which the class is stored, so that
            ``resolve_provider`` can find it by that name.

    Returns:
        A decorator that takes a provider class, registers it and returns it.

    Raises:
        ValueError: raised by the returned decorator when the class does not
            subclass any of the supported provider base classes.
    """
    # Ordered (base class, registry category) pairs. The first match wins,
    # which mirrors the lookup order used by resolve_provider.
    categories = (
        (ClientAuthProvider, "client_auth_providers"),
        (ClientAuthConfigurationProvider, "client_auth_config_providers"),
        (ClientAuthCredentialsProvider, "client_auth_credentials_providers"),
        (ClientAuthProtocolAdapter, "client_auth_protocol_adapters"),
        (ServerAuthProvider, "server_auth_providers"),
        (ServerAuthConfigurationProvider, "server_auth_config_providers"),
        (ServerAuthCredentialsProvider, "server_auth_credentials_providers"),
        (ServerAuthorizationProvider, "server_authz_providers"),
        (ServerAuthorizationConfigurationProvider, "server_authz_config_providers"),
    )

    def decorator(cls: Type[ProviderTypes]) -> Type[ProviderTypes]:
        logger.debug("Registering provider: %s", short_hand)
        for base, category in categories:
            if issubclass(cls, base):
                # The module-level dict is mutated in place, so no `global`
                # statement is needed.
                _provider_registry[category][short_hand] = cls
                return cls
        raise ValueError(
            "Only ClientAuthProvider, ClientAuthConfigurationProvider, "
            "ClientAuthCredentialsProvider, ClientAuthProtocolAdapter, "
            "ServerAuthProvider, ServerAuthConfigurationProvider, "
            "ServerAuthCredentialsProvider, ServerAuthorizationProvider, and "
            "ServerAuthorizationConfigurationProvider can be registered."
        )

    return decorator
def resolve_provider(
    class_or_name: str, cls: Type[ProviderTypes]
) -> Type[ProviderTypes]:
    """Resolve *class_or_name* to a provider class of the kind given by *cls*.

    First, every immediate submodule of the ``chromadb.auth`` package is
    imported via ``register_classes_from_package``, so that registrations
    made at import time are present. The registry category is then chosen
    by the first provider base class that *cls* subclasses. If
    *class_or_name* is a short-hand registered in that category, the
    registered class is returned. Otherwise the lookup is delegated to
    ``chromadb.utils.get_class(class_or_name, cls)``, which presumably treats
    the string as a fully-qualified class path (verify in chromadb.utils).

    Args:
        class_or_name: Registered short-hand, or a string that ``get_class``
            understands.
        cls: Provider base class that selects the registry category.

    Returns:
        The resolved provider class.

    Raises:
        ValueError: if *cls* does not subclass any supported provider base.
    """
    register_classes_from_package("chromadb.auth")
    # Same ordered (base class, category) mapping as register_provider, so
    # a class is looked up in the category it would have been registered in.
    categories = (
        (ClientAuthProvider, "client_auth_providers"),
        (ClientAuthConfigurationProvider, "client_auth_config_providers"),
        (ClientAuthCredentialsProvider, "client_auth_credentials_providers"),
        (ClientAuthProtocolAdapter, "client_auth_protocol_adapters"),
        (ServerAuthProvider, "server_auth_providers"),
        (ServerAuthConfigurationProvider, "server_auth_config_providers"),
        (ServerAuthCredentialsProvider, "server_auth_credentials_providers"),
        (ServerAuthorizationProvider, "server_authz_providers"),
        (ServerAuthorizationConfigurationProvider, "server_authz_config_providers"),
    )
    category = next(
        (key for base, key in categories if issubclass(cls, base)), None
    )
    if category is None:
        raise ValueError(
            "Only ClientAuthProvider, ClientAuthConfigurationProvider, "
            "ClientAuthCredentialsProvider, ClientAuthProtocolAdapter, "
            "ServerAuthProvider, ServerAuthConfigurationProvider, "
            "ServerAuthCredentialsProvider, ServerAuthorizationProvider, and "
            "ServerAuthorizationConfigurationProvider can be resolved."
        )
    registered = _provider_registry[category]
    if class_or_name in registered:
        return registered[class_or_name]
    return get_class(class_or_name, cls)  # type: ignore
|