File size: 5,270 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import importlib
import logging
import pkgutil
from typing import Union, Dict, Type, Callable  # noqa: F401

from chromadb.auth import (
    ClientAuthConfigurationProvider,
    ClientAuthCredentialsProvider,
    ClientAuthProtocolAdapter,
    ServerAuthProvider,
    ServerAuthConfigurationProvider,
    ServerAuthCredentialsProvider,
    ClientAuthProvider,
    ServerAuthorizationConfigurationProvider,
    ServerAuthorizationProvider,
)
from chromadb.utils import get_class

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Union of every provider base class this module knows how to register.
# Members are spelled as strings, i.e. typing forward references.
ProviderTypes = Union[
    "ClientAuthProvider",
    "ClientAuthConfigurationProvider",
    "ClientAuthCredentialsProvider",
    "ServerAuthProvider",
    "ServerAuthConfigurationProvider",
    "ServerAuthCredentialsProvider",
    "ClientAuthProtocolAdapter",
    "ServerAuthorizationProvider",
    "ServerAuthorizationConfigurationProvider",
]

# Registry of provider classes: bucket name -> {short-hand name -> class}.
# Every bucket starts out as its own empty dict; entries are added by the
# ``register_provider`` decorator and looked up by ``resolve_provider``.
_provider_registry: Dict[str, Dict[str, Type[ProviderTypes]]] = {
    bucket_name: {}
    for bucket_name in (
        "client_auth_providers",
        "client_auth_config_providers",
        "client_auth_credentials_providers",
        "client_auth_protocol_adapters",
        "server_auth_providers",
        "server_auth_config_providers",
        "server_auth_credentials_providers",
        "server_authz_providers",
        "server_authz_config_providers",
    )
}


def register_classes_from_package(package_name: str) -> None:
    """Import every module found directly under ``package_name``.

    The package itself is imported first, then each module or subpackage
    that ``pkgutil.iter_modules`` reports for the package's ``__path__``.
    The imported module objects are discarded; the call exists purely for
    the side effects of importing. ``resolve_provider`` invokes it before
    consulting the registry, presumably so that any ``@register_provider``
    decorators in those modules have already run.

    Args:
        package_name: Dotted import path of the package to scan.

    Raises:
        ImportError: If the package or one of its modules cannot be imported.
        AttributeError: If ``package_name`` names a plain module, which has
            no ``__path__``.
    """
    parent = importlib.import_module(package_name)
    for module_info in pkgutil.iter_modules(parent.__path__):
        importlib.import_module(f"{package_name}.{module_info.name}")


def _provider_bucket_for(cls: Type[ProviderTypes], action: str) -> str:
    """Return the ``_provider_registry`` bucket name that ``cls`` belongs to.

    Single source of truth for the class-to-bucket mapping shared by
    ``register_provider`` and ``resolve_provider``.

    Args:
        cls: Class to classify.
        action: Past participle used in the error message, e.g.
            ``"registered"`` or ``"resolved"``.

    Returns:
        The key of the matching bucket in ``_provider_registry``.

    Raises:
        ValueError: If ``cls`` is not a subclass of any supported base.
    """
    # Order matters: the first matching base wins, so a class deriving from
    # several bases lands in the same bucket on registration and resolution.
    buckets = (
        (ClientAuthProvider, "client_auth_providers"),
        (ClientAuthConfigurationProvider, "client_auth_config_providers"),
        (ClientAuthCredentialsProvider, "client_auth_credentials_providers"),
        (ClientAuthProtocolAdapter, "client_auth_protocol_adapters"),
        (ServerAuthProvider, "server_auth_providers"),
        (ServerAuthConfigurationProvider, "server_auth_config_providers"),
        (ServerAuthCredentialsProvider, "server_auth_credentials_providers"),
        (ServerAuthorizationProvider, "server_authz_providers"),
        (ServerAuthorizationConfigurationProvider, "server_authz_config_providers"),
    )
    for base, bucket in buckets:
        if issubclass(cls, base):
            return bucket

    names = [base.__name__ for base, _ in buckets]
    raise ValueError(
        f"Only {', '.join(names[:-1])}, and {names[-1]} can be {action}."
    )


def register_provider(
    short_hand: str,
) -> Callable[[Type[ProviderTypes]], Type[ProviderTypes]]:
    """Build a class decorator that registers a provider under ``short_hand``.

    The decorated class is stored in the ``_provider_registry`` bucket that
    matches its base class and is returned unchanged. Registering a second
    class under the same short-hand in the same bucket replaces the first.

    Args:
        short_hand: Name under which the class can later be looked up by
            ``resolve_provider``.

    Returns:
        A decorator that registers the class and returns it as-is.

    Raises:
        ValueError: Raised by the decorator when the class does not derive
            from one of the supported provider base classes.
    """

    def decorator(cls: Type[ProviderTypes]) -> Type[ProviderTypes]:
        logger.debug("Registering provider: %s", short_hand)
        bucket = _provider_bucket_for(cls, "registered")
        # The registry is mutated in place, never rebound: no ``global``.
        _provider_registry[bucket][short_hand] = cls
        return cls

    return decorator


def resolve_provider(
    class_or_name: str, cls: Type[ProviderTypes]
) -> Type[ProviderTypes]:
    """Resolve ``class_or_name`` to a provider class of kind ``cls``.

    The modules under ``chromadb.auth`` are imported first, then the bucket
    matching ``cls`` is searched for a short-hand equal to ``class_or_name``.
    If none is registered, the lookup is delegated to
    ``get_class(class_or_name, cls)``, which presumably treats the string as
    a fully qualified class name (defined in ``chromadb.utils``; confirm).

    Args:
        class_or_name: Registered short-hand, or a name ``get_class`` can
            resolve.
        cls: Provider base class that selects which bucket to search.

    Returns:
        The registered class, or whatever ``get_class`` returns.

    Raises:
        ValueError: If ``cls`` does not derive from one of the supported
            provider base classes.
    """
    # Import before validating ``cls`` to keep the original evaluation order.
    register_classes_from_package("chromadb.auth")
    registered = _provider_registry[_provider_bucket_for(cls, "resolved")]
    if class_or_name in registered:
        return registered[class_or_name]
    return get_class(class_or_name, cls)  # type: ignore