"""
Authentication and authorization abstractions for Chroma.

Most classes in this module are abstract interfaces; a few small concrete
helpers (e.g. ``SimpleUserIdentity``, ``SimpleServerAuthenticationResponse``,
``BasicAuthCredentials``) are provided alongside them.
"""
import base64
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

from overrides import EnforceOverrides, override
from pydantic import SecretStr

from chromadb.config import (
    DEFAULT_DATABASE,
    DEFAULT_TENANT,
    Component,
    System,
)
from chromadb.errors import ChromaError

logger = logging.getLogger(__name__)

# Type variables shared by the generic interfaces declared in this module.
T = TypeVar("T")
S = TypeVar("S")


class AuthInfoType(Enum):
    """Where a piece of auth information travels with a request."""

    COOKIE = "cookie"
    HEADER = "header"
    URL = "url"
    METADATA = "metadata"  # gRPC


class UserIdentity(EnforceOverrides, ABC):
    """Read-only view of an authenticated user."""

    @abstractmethod
    def get_user_id(self) -> str:
        """Return the user's identifier."""

    @abstractmethod
    def get_user_tenant(self) -> Optional[str]:
        """Return the tenant associated with the user, if any."""

    @abstractmethod
    def get_user_databases(self) -> Optional[List[str]]:
        """Return the databases associated with the user, if any."""

    @abstractmethod
    def get_user_attributes(self) -> Optional[Dict[str, Any]]:
        """Return free-form user attributes, if any."""


class SimpleUserIdentity(UserIdentity):
    """``UserIdentity`` that simply stores the values it was constructed with."""

    def __init__(
        self,
        user_id: str,
        tenant: Optional[str] = None,
        databases: Optional[List[str]] = None,
        attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._user_id = user_id
        self._tenant = tenant
        self._databases = databases
        self._attributes = attributes

    @override
    def get_user_id(self) -> str:
        return self._user_id

    @override
    def get_user_tenant(self) -> Optional[str]:
        # A missing (None) or empty tenant falls back to DEFAULT_TENANT, so
        # this never actually returns None.
        return self._tenant or DEFAULT_TENANT

    @override
    def get_user_databases(self) -> Optional[List[str]]:
        return self._databases

    @override
    def get_user_attributes(self) -> Optional[Dict[str, Any]]:
        return self._attributes


class ClientAuthResponse(EnforceOverrides, ABC):
    """Auth information produced on the client side, plus how it is carried."""

    @abstractmethod
    def get_auth_info_type(self) -> AuthInfoType:
        """Return the transport (cookie, header, url, metadata) for the info."""

    @abstractmethod
    def get_auth_info(
        self,
    ) -> Union[Tuple[str, SecretStr], List[Tuple[str, SecretStr]]]:
        """Return one ``(name, secret)`` pair or a list of such pairs."""
class ClientAuthProvider(Component):
    """Client-side component that produces auth info for outgoing calls."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def authenticate(self) -> ClientAuthResponse:
        """Return a ``ClientAuthResponse`` with the auth info and its transport."""


class ClientAuthConfigurationProvider(Component):
    """Client-side source of auth configuration."""

    # NOTE(review): unlike ClientAuthCredentialsProvider, this class is not
    # Generic[T], so the ``T`` in get_configuration is not bound at class level
    # and subclasses cannot pin it via ``ClientAuthConfigurationProvider[X]``.

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def get_configuration(self) -> Optional[T]:
        """Return the configuration object (the signature permits ``None``)."""


class ClientAuthCredentialsProvider(Component, Generic[T]):
    """Client-side source of credentials of type ``T``."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def get_credentials(self) -> T:
        """Return the credentials."""


class ClientAuthProtocolAdapter(Component, Generic[T]):
    """Client-side hook that injects credentials into a context of type ``T``."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def inject_credentials(self, injection_context: T) -> None:
        """Inject credentials into ``injection_context``."""


# SERVER-SIDE Abstractions


class ServerAuthenticationRequest(EnforceOverrides, ABC, Generic[T]):
    """Server-side view of an incoming request, exposing its auth info."""

    @abstractmethod
    def get_auth_info(self, auth_info_type: AuthInfoType, auth_info_id: str) -> T:
        """
        Look up auth info on the request by transport type and identifier.

        :param auth_info_type: Which transport to read from (e.g. header,
            cookie, url).
        :param auth_info_id: The key within that transport (e.g. the header
            name, cookie name, or url parameter name).
        :return: The auth info; its concrete type is implementation-specific.
        """


class ServerAuthenticationResponse(EnforceOverrides, ABC):
    """Result of authenticating a request on the server."""

    @abstractmethod
    def success(self) -> bool:
        """Return whether authentication succeeded."""

    @abstractmethod
    def get_user_identity(self) -> Optional[UserIdentity]:
        """Return the authenticated user's identity, if available."""
class SimpleServerAuthenticationResponse(ServerAuthenticationResponse):
    """Simple implementation of ServerAuthenticationResponse.

    Stores a success flag and an optional user identity and returns them
    unchanged.
    """

    _auth_success: bool
    _user_identity: Optional[UserIdentity]

    def __init__(
        self, auth_success: bool, user_identity: Optional[UserIdentity]
    ) -> None:
        self._auth_success = auth_success
        self._user_identity = user_identity

    @override
    def success(self) -> bool:
        return self._auth_success

    @override
    def get_user_identity(self) -> Optional[UserIdentity]:
        return self._user_identity


class ServerAuthProvider(Component):
    """Server-side component that authenticates an incoming request."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def authenticate(
        self, request: ServerAuthenticationRequest[T]
    ) -> ServerAuthenticationResponse:
        """Authenticate ``request`` and report the outcome."""
        pass


class ChromaAuthMiddleware(Component):
    """Server-side middleware contract for authentication."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def authenticate(
        self, request: ServerAuthenticationRequest[T]
    ) -> ServerAuthenticationResponse:
        """Authenticate ``request`` and report the outcome."""
        ...

    @abstractmethod
    def ignore_operation(self, verb: str, path: str) -> bool:
        """Return whether the operation identified by ``verb``/``path`` is
        exempt from authentication (presumably HTTP method and URL path)."""
        ...

    @abstractmethod
    def instrument_server(self, app: T) -> None:
        """Attach this middleware to the server application ``app``."""
        ...


class ServerAuthConfigurationProvider(Component):
    """Server-side source of auth configuration."""

    # NOTE(review): this class is not Generic[T], so ``T`` below is not bound
    # at class level (compare ServerAuthorizationConfigurationProvider).

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def get_configuration(self) -> Optional[T]:
        """Return the configuration object (the signature permits ``None``)."""
        pass


class AuthenticationError(ChromaError):
    """Chroma error reported with code 401 when authentication fails."""

    @override
    def code(self) -> int:
        return 401

    @classmethod
    @override
    def name(cls) -> str:
        return "AuthenticationError"


class AbstractCredentials(EnforceOverrides, ABC, Generic[T]):
    """
    Container for credentials that the server received (for example, parsed
    from a request header). Auth providers pass these objects to a
    ServerAuthCredentialsProvider.
    """

    @abstractmethod
    def get_credentials(self) -> Dict[str, T]:
        """
        Returns the data encapsulated by the credentials object.
        """
        pass


class SecretStrAbstractCredentials(AbstractCredentials[SecretStr]):
    """Credentials whose values are ``SecretStr`` instances."""

    @abstractmethod
    @override
    def get_credentials(self) -> Dict[str, SecretStr]:
        """
        Returns the data encapsulated by the credentials object.
        """
        pass


class BasicAuthCredentials(SecretStrAbstractCredentials):
    """Username/password credentials for HTTP Basic auth."""

    def __init__(self, username: SecretStr, password: SecretStr) -> None:
        self.username = username
        self.password = password

    @override
    def get_credentials(self) -> Dict[str, SecretStr]:
        return {"username": self.username, "password": self.password}

    @staticmethod
    def from_header(header: str) -> "BasicAuthCredentials":
        """
        Parses a basic auth header and returns a BasicAuthCredentials object.

        :param header: An ``Authorization`` header value such as
            ``"Basic dXNlcjpwYXNz"``. A leading ``Basic`` scheme is optional and
            is matched case-insensitively, because HTTP auth scheme names are
            case-insensitive.
        :return: Credentials holding the decoded username and password.
        :raises ValueError: if the payload is not valid base64 or UTF-8, or if
            it does not contain a ``:`` separator.
        """
        scheme = "basic "
        header = header.strip()
        # Strip only a *leading* scheme token. The previous
        # ``replace("Basic ", "")`` removed the text anywhere in the string and
        # only in that exact casing.
        if header[: len(scheme)].lower() == scheme:
            header = header[len(scheme) :]
        header = header.strip()
        base64_decoded = base64.b64decode(header).decode("utf-8")
        # RFC 7617 forbids ':' in the user-id but allows it in the password,
        # so split on the FIRST colon only. A plain split(":") rejected valid
        # passwords that contain ':'.
        username, sep, password = base64_decoded.partition(":")
        if not sep:
            # Same exception type as the old tuple-unpacking failure. The
            # message deliberately does not echo the decoded secret.
            raise ValueError(
                "Basic auth credentials must be of the form 'username:password'"
            )
        return BasicAuthCredentials(SecretStr(username), SecretStr(password))


class ServerAuthCredentialsProvider(Component):
    """Server-side validator of credentials and resolver of user identities."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def validate_credentials(self, credentials: AbstractCredentials[T]) -> bool:
        """Return whether ``credentials`` are valid."""
        ...

    @abstractmethod
    def get_user_identity(
        self, credentials: AbstractCredentials[T]
    ) -> Optional[UserIdentity]:
        """Return the identity for ``credentials``, if one can be resolved."""
        ...
class AuthzResourceTypes(str, Enum):
    """Kinds of resources that authorization decisions are made about."""

    DB = "db"
    COLLECTION = "collection"
    TENANT = "tenant"


class AuthzResourceActions(str, Enum):
    """Actions that can be authorized against a resource."""

    CREATE_DATABASE = "create_database"
    GET_DATABASE = "get_database"
    CREATE_TENANT = "create_tenant"
    GET_TENANT = "get_tenant"
    LIST_COLLECTIONS = "list_collections"
    COUNT_COLLECTIONS = "count_collections"
    GET_COLLECTION = "get_collection"
    CREATE_COLLECTION = "create_collection"
    GET_OR_CREATE_COLLECTION = "get_or_create_collection"
    DELETE_COLLECTION = "delete_collection"
    UPDATE_COLLECTION = "update_collection"
    ADD = "add"
    DELETE = "delete"
    GET = "get"
    QUERY = "query"
    COUNT = "count"
    UPDATE = "update"
    UPSERT = "upsert"
    RESET = "reset"


@dataclass
class AuthzUser:
    """The subject of an authorization check."""

    id: Optional[str]
    tenant: Optional[str] = DEFAULT_TENANT
    attributes: Optional[Dict[str, Any]] = None
    claims: Optional[Dict[str, Any]] = None


@dataclass
class AuthzResource:
    """A fully-resolved resource that an authorization check targets."""

    id: Optional[str]
    type: Optional[str]
    attributes: Optional[Dict[str, Any]] = None


class DynamicAuthzResource:
    """
    Resource descriptor whose fields may be computed lazily.

    Each field is either a literal value or a callable. Callables are invoked
    with the keyword arguments given to ``to_authz_resource`` to produce the
    concrete ``AuthzResource``.
    """

    id: Optional[Union[str, Callable[..., str]]]
    type: Optional[Union[str, Callable[..., str]]]
    attributes: Optional[Union[Dict[str, Any], Callable[..., Dict[str, Any]]]]

    def __init__(
        self,
        id: Optional[Union[str, Callable[..., str]]] = None,
        attributes: Optional[
            Union[Dict[str, Any], Callable[..., Dict[str, Any]]]
        ] = lambda **kwargs: {},
        # NOTE(review): the default ``type`` is DEFAULT_DATABASE rather than an
        # AuthzResourceTypes member. This looks unintentional; confirm it
        # before relying on it.
        type: Optional[Union[str, Callable[..., str]]] = DEFAULT_DATABASE,
    ) -> None:
        self.id = id
        self.attributes = attributes
        self.type = type

    def to_authz_resource(self, **kwargs: Any) -> AuthzResource:
        """Resolve every field (calling callables with ``kwargs``) into an
        ``AuthzResource``. Fields are resolved in the order id, type,
        attributes."""

        def resolve(value: Any) -> Any:
            if callable(value):
                return value(**kwargs)
            return value

        resolved_id = resolve(self.id)
        resolved_type = resolve(self.type)
        resolved_attributes = resolve(self.attributes)
        return AuthzResource(
            id=resolved_id,
            type=resolved_type,
            attributes=resolved_attributes,
        )


class AuthzDynamicParams:
    """
    Factories for deferred extractors used as ``DynamicAuthzResource`` fields.

    Each factory binds its own keyword arguments with ``functools.partial``.
    The returned callable later receives additional keyword arguments
    (``function``, ``function_args``, ``function_kwargs``). These are
    presumably supplied by an authorization wrapper that is not visible here.
    """

    @staticmethod
    def from_function_name(**kwargs: Any) -> Callable[..., str]:
        """Extractor returning ``function.__name__``."""

        def _function_name(**ctx: Any) -> str:
            return ctx["function"].__name__

        return partial(_function_name, **kwargs)

    @staticmethod
    def from_function_args(**kwargs: Any) -> Callable[..., str]:
        """Extractor returning ``function_args[arg_num]``."""

        def _positional_arg(**ctx: Any) -> str:
            return ctx["function_args"][ctx["arg_num"]]

        return partial(_positional_arg, **kwargs)

    @staticmethod
    def from_function_kwargs(**kwargs: Any) -> Callable[..., str]:
        """Extractor returning ``function_kwargs[arg_name]``."""

        def _keyword_arg(**ctx: Any) -> str:
            return ctx["function_kwargs"][ctx["arg_name"]]

        return partial(_keyword_arg, **kwargs)

    @staticmethod
    def dict_from_function_kwargs(**kwargs: Any) -> Callable[..., Dict[str, Any]]:
        """Extractor returning ``{name: function_kwargs[name]}`` for each name
        in ``arg_names``."""

        def _keyword_subset(**ctx: Any) -> Dict[str, Any]:
            source = ctx["function_kwargs"]
            return {name: source[name] for name in ctx["arg_names"]}

        return partial(_keyword_subset, **kwargs)


@dataclass
class AuthzAction:
    """The action being authorized, identified by ``id``."""

    id: str
    attributes: Optional[Dict[str, Any]] = None


@dataclass
class AuthorizationContext:
    """User, resource and action bundled for a single authorization check."""

    user: AuthzUser
    resource: AuthzResource
    action: AuthzAction


class ServerAuthorizationProvider(Component):
    """Server-side component that makes authorization decisions."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def authorize(self, context: AuthorizationContext) -> bool:
        """Return whether the action in ``context`` is permitted."""


class AuthorizationRequestContext(EnforceOverrides, ABC, Generic[T]):
    """Wrapper exposing the underlying request of type ``T``."""

    @abstractmethod
    def get_request(self) -> T:
        """Return the wrapped request."""


class ChromaAuthzMiddleware(Component, Generic[T, S]):
    """Server-side middleware contract for authorization.

    ``T`` is the server application type and ``S`` is the request type.
    """

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def pre_process(self, request: AuthorizationRequestContext[S]) -> None:
        """Run authorization processing for ``request``."""

    @abstractmethod
    def ignore_operation(self, verb: str, path: str) -> bool:
        """Return whether the operation identified by ``verb``/``path`` skips
        authorization (presumably HTTP method and URL path)."""

    @abstractmethod
    def instrument_server(self, app: T) -> None:
        """Attach this middleware to the server application ``app``."""


class ServerAuthorizationConfigurationProvider(Component, Generic[T]):
    """Server-side source of authorization configuration of type ``T``."""

    def __init__(self, system: System) -> None:
        super().__init__(system)

    @abstractmethod
    def get_configuration(self) -> T:
        """Return the authorization configuration."""