Spaces:
Sleeping
Sleeping
| from typing import Any, Callable, Dict, List, Sequence, Optional | |
| import fastapi | |
| from fastapi import FastAPI as _FastAPI, Response | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.routing import APIRoute | |
| from fastapi import HTTPException, status | |
| from uuid import UUID | |
| from chromadb.api.models.Collection import Collection | |
| from chromadb.api.types import GetResult, QueryResult | |
| from chromadb.auth import ( | |
| AuthzDynamicParams, | |
| AuthzResourceActions, | |
| AuthzResourceTypes, | |
| DynamicAuthzResource, | |
| ) | |
| from chromadb.auth.fastapi import ( | |
| FastAPIChromaAuthMiddleware, | |
| FastAPIChromaAuthMiddlewareWrapper, | |
| FastAPIChromaAuthzMiddleware, | |
| FastAPIChromaAuthzMiddlewareWrapper, | |
| authz_context, | |
| set_overwrite_singleton_tenant_database_access_from_auth, | |
| ) | |
| from chromadb.auth.fastapi_utils import ( | |
| attr_from_collection_lookup, | |
| attr_from_resource_object, | |
| ) | |
| from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System | |
| import chromadb.api | |
| from chromadb.api import ServerAPI | |
| from chromadb.errors import ( | |
| ChromaError, | |
| InvalidDimensionException, | |
| InvalidHTTPVersion, | |
| ) | |
| from chromadb.server.fastapi.types import ( | |
| AddEmbedding, | |
| CreateDatabase, | |
| CreateTenant, | |
| DeleteEmbedding, | |
| GetEmbedding, | |
| QueryEmbedding, | |
| CreateCollection, | |
| UpdateCollection, | |
| UpdateEmbedding, | |
| ) | |
| from starlette.requests import Request | |
| import logging | |
| from chromadb.server.fastapi.utils import fastapi_json_response, string_to_uuid as _uuid | |
| from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi | |
| from chromadb.types import Database, Tenant | |
| from chromadb.telemetry.product import ServerContext, ProductTelemetryClient | |
| from chromadb.telemetry.opentelemetry import ( | |
| OpenTelemetryClient, | |
| OpenTelemetryGranularity, | |
| trace_method, | |
| ) | |
| logger = logging.getLogger(__name__) | |
def use_route_names_as_operation_ids(app: _FastAPI) -> None:
    """
    Replace each API route's operation ID with the route's own name, so that
    generated API clients end up with simpler function names.

    Only the routes present on ``app`` at call time are touched, so this should
    be called only after all routes have been added.
    """
    # Entries of app.routes that are not APIRoute instances are left untouched.
    api_routes = (entry for entry in app.routes if isinstance(entry, APIRoute))
    for api_route in api_routes:
        api_route.operation_id = api_route.name
async def catch_exceptions_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """
    HTTP middleware that converts exceptions raised further down the stack
    into JSON responses.

    A ``ChromaError`` is rendered by ``fastapi_json_response``. Any other
    exception is logged with its traceback and reported as a 500 response whose
    body is ``{"error": repr(exception)}``.
    """
    try:
        response = await call_next(request)
    except ChromaError as chroma_error:
        return fastapi_json_response(chroma_error)
    except Exception as unexpected:
        # Last-resort boundary: record the traceback, then answer with a 500.
        logger.exception(unexpected)
        return JSONResponse(content={"error": repr(unexpected)}, status_code=500)
    return response
async def check_http_version_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """
    HTTP middleware that only lets HTTP/1.1 and HTTP/2 requests through.

    The version is read from the ``"http_version"`` entry of the request scope.
    Any other value, including a missing entry, raises ``InvalidHTTPVersion``.
    """
    http_version = request.scope.get("http_version")
    if http_version in ("1.1", "2"):
        return await call_next(request)
    raise InvalidHTTPVersion(f"HTTP version {http_version} is not supported")
class ChromaAPIRouter(fastapi.APIRouter):  # type: ignore
    """
    APIRouter subclass that serves every route both with and without a
    trailing "/". Only the spelling without the trailing "/" can appear in the
    generated docs.
    """

    def add_api_route(self, path: str, *args: Any, **kwargs: Any) -> None:
        """
        Register ``path`` and then its trailing-slash counterpart.

        ``include_in_schema`` is recomputed for each registration:
          * not passed, or truthy -> only the spelling without a trailing "/"
            is included in the schema;
          * falsy -> neither spelling is included.
        All other positional and keyword arguments are forwarded unchanged to
        ``fastapi.APIRouter.add_api_route`` for both registrations.
        """
        # A missing key counts as "include", hence the True default.
        hide_everywhere = not kwargs.get("include_in_schema", True)

        # The counterpart drops the trailing "/" if present, else appends one.
        counterpart = path[:-1] if path.endswith("/") else path + "/"

        # The path the caller supplied is registered first, then its twin.
        for variant in (path, counterpart):
            kwargs["include_in_schema"] = not hide_everywhere and not variant.endswith(
                "/"
            )
            super().add_api_route(variant, *args, **kwargs)
class FastAPI(chromadb.server.Server):
    """
    HTTP server that exposes a Chroma ``ServerAPI`` under ``/api/v1``.

    ``__init__`` builds the FastAPI application, starts the Chroma ``System``,
    installs the middleware and registers every route. The remaining public
    methods are the route handlers; each one forwards to ``self._api``.

    NOTE: the route handlers are documented with ``#`` comments rather than
    docstrings on purpose. FastAPI publishes an endpoint's docstring as the
    operation description in the generated OpenAPI schema, so adding docstrings
    there would change what the server emits.
    """

    def __init__(self, settings: Settings):
        """
        Create the application, start the system and wire middleware + routes.

        Middleware is registered in a fixed order; keep that order when editing
        (check Starlette's middleware stacking rules before reordering).
        """
        super().__init__(settings)
        ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI

        self._app = fastapi.FastAPI(debug=True)
        self._system = System(settings)
        self._api: ServerAPI = self._system.instance(ServerAPI)
        self._opentelemetry_client = self._api.require(OpenTelemetryClient)
        self._system.start()

        # Function-style HTTP middleware: version check first, then the
        # exception-to-JSON translator.
        for dispatch in (check_http_version_middleware, catch_exceptions_middleware):
            self._app.middleware("http")(dispatch)

        self._app.add_middleware(
            CORSMiddleware,
            allow_origins=settings.chroma_server_cors_allow_origins,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        self._app.on_event("shutdown")(self.shutdown)

        # Authorization and authentication middleware are opt-in via settings.
        if settings.chroma_server_authz_provider:
            authz_middleware = self._api.require(FastAPIChromaAuthzMiddleware)
            self._app.add_middleware(
                FastAPIChromaAuthzMiddlewareWrapper,
                authz_middleware=authz_middleware,
            )
        if settings.chroma_server_auth_provider:
            auth_middleware = self._api.require(FastAPIChromaAuthMiddleware)
            self._app.add_middleware(
                FastAPIChromaAuthMiddlewareWrapper,
                auth_middleware=auth_middleware,
            )
        set_overwrite_singleton_tenant_database_access_from_auth(
            settings.chroma_overwrite_singleton_tenant_database_access_from_auth
        )

        self.router = ChromaAPIRouter()
        self._register_routes()
        self._app.include_router(self.router)

        # Both calls below need the complete route list, so they come last.
        use_route_names_as_operation_ids(self._app)
        instrument_fastapi(self._app)

    def _register_routes(self) -> None:
        """
        Add every ``/api/v1`` route to ``self.router``.

        The table is walked top to bottom, so rows are registered in exactly
        the order they are listed. Each row is
        ``(path, handler, HTTP method, extra add_api_route keyword arguments)``.
        """
        # Rows using `inferred` pass no response_model argument at all; rows
        # using `no_model` pass response_model=None explicitly.
        inferred: Dict[str, Any] = {}
        no_model: Dict[str, Any] = {"response_model": None}
        created: Dict[str, Any] = {
            "status_code": status.HTTP_201_CREATED,
            "response_model": None,
        }

        route_table: List[Any] = [
            ("/api/v1", self.root, "GET", inferred),
            ("/api/v1/reset", self.reset, "POST", inferred),
            ("/api/v1/version", self.version, "GET", inferred),
            ("/api/v1/heartbeat", self.heartbeat, "GET", inferred),
            ("/api/v1/pre-flight-checks", self.pre_flight_checks, "GET", inferred),
            ("/api/v1/databases", self.create_database, "POST", no_model),
            ("/api/v1/databases/{database}", self.get_database, "GET", no_model),
            ("/api/v1/tenants", self.create_tenant, "POST", no_model),
            ("/api/v1/tenants/{tenant}", self.get_tenant, "GET", no_model),
            ("/api/v1/collections", self.list_collections, "GET", no_model),
            ("/api/v1/count_collections", self.count_collections, "GET", no_model),
            ("/api/v1/collections", self.create_collection, "POST", no_model),
            ("/api/v1/collections/{collection_id}/add", self.add, "POST", created),
            (
                "/api/v1/collections/{collection_id}/update",
                self.update,
                "POST",
                no_model,
            ),
            (
                "/api/v1/collections/{collection_id}/upsert",
                self.upsert,
                "POST",
                no_model,
            ),
            ("/api/v1/collections/{collection_id}/get", self.get, "POST", no_model),
            (
                "/api/v1/collections/{collection_id}/delete",
                self.delete,
                "POST",
                no_model,
            ),
            ("/api/v1/collections/{collection_id}/count", self.count, "GET", no_model),
            (
                "/api/v1/collections/{collection_id}/query",
                self.get_nearest_neighbors,
                "POST",
                no_model,
            ),
            (
                "/api/v1/collections/{collection_name}",
                self.get_collection,
                "GET",
                no_model,
            ),
            (
                "/api/v1/collections/{collection_id}",
                self.update_collection,
                "PUT",
                no_model,
            ),
            (
                "/api/v1/collections/{collection_name}",
                self.delete_collection,
                "DELETE",
                no_model,
            ),
        ]

        for path, handler, verb, options in route_table:
            # **options makes a fresh kwargs dict per call, so the shared
            # option dicts above are never mutated by the router.
            self.router.add_api_route(path, handler, methods=[verb], **options)

    def shutdown(self) -> None:
        """Stop the Chroma system; registered as the app's shutdown handler."""
        self._system.stop()

    def app(self) -> fastapi.FastAPI:
        """Return the underlying FastAPI application object."""
        return self._app

    # ------------------------------------------------------------------
    # Route handlers (comments only, see the class docstring for why).
    # ------------------------------------------------------------------

    # GET /api/v1 -> the API heartbeat value under "nanosecond heartbeat".
    def root(self) -> Dict[str, int]:
        beat = self._api.heartbeat()
        return {"nanosecond heartbeat": beat}

    # GET /api/v1/heartbeat -> same payload as root().
    def heartbeat(self) -> Dict[str, int]:
        return self.root()

    # GET /api/v1/version -> version string reported by the API.
    def version(self) -> str:
        return self._api.get_version()

    # POST /api/v1/databases -> create database `database.name` in `tenant`.
    def create_database(
        self, database: CreateDatabase, tenant: str = DEFAULT_TENANT
    ) -> None:
        return self._api.create_database(database.name, tenant)

    # GET /api/v1/databases/{database} -> look up a database within `tenant`.
    def get_database(self, database: str, tenant: str = DEFAULT_TENANT) -> Database:
        return self._api.get_database(database, tenant)

    # POST /api/v1/tenants -> create tenant `tenant.name`.
    def create_tenant(self, tenant: CreateTenant) -> None:
        return self._api.create_tenant(tenant.name)

    # GET /api/v1/tenants/{tenant} -> look up a tenant by name.
    def get_tenant(self, tenant: str) -> Tenant:
        return self._api.get_tenant(tenant)

    # GET /api/v1/collections -> collections of tenant/database, with optional
    # limit/offset forwarded as given.
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[Collection]:
        return self._api.list_collections(
            limit=limit,
            offset=offset,
            tenant=tenant,
            database=database,
        )

    # GET /api/v1/count_collections -> number of collections in tenant/database.
    def count_collections(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        return self._api.count_collections(tenant=tenant, database=database)

    # POST /api/v1/collections -> create a collection (or get-or-create it when
    # the request body says so).
    def create_collection(
        self,
        collection: CreateCollection,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        return self._api.create_collection(
            name=collection.name,
            metadata=collection.metadata,
            get_or_create=collection.get_or_create,
            tenant=tenant,
            database=database,
        )

    # GET /api/v1/collections/{collection_name} -> fetch a collection by name.
    def get_collection(
        self,
        collection_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Collection:
        return self._api.get_collection(
            collection_name,
            tenant=tenant,
            database=database,
        )

    # PUT /api/v1/collections/{collection_id} -> rename and/or re-metadata the
    # collection identified by the UUID string `collection_id`.
    def update_collection(
        self, collection_id: str, collection: UpdateCollection
    ) -> None:
        return self._api._modify(
            id=_uuid(collection_id),
            new_name=collection.new_name,
            new_metadata=collection.new_metadata,
        )

    # DELETE /api/v1/collections/{collection_name} -> delete a collection by name.
    def delete_collection(
        self,
        collection_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        return self._api.delete_collection(
            collection_name,
            tenant=tenant,
            database=database,
        )

    # POST /api/v1/collections/{collection_id}/add -> add records.
    # An InvalidDimensionException is reported as an HTTP 500 whose detail is
    # the exception message.
    def add(self, collection_id: str, add: AddEmbedding) -> None:
        try:
            return self._api._add(  # type: ignore
                collection_id=_uuid(collection_id),
                embeddings=add.embeddings,  # type: ignore
                metadatas=add.metadatas,  # type: ignore
                documents=add.documents,  # type: ignore
                uris=add.uris,  # type: ignore
                ids=add.ids,
            )
        except InvalidDimensionException as e:
            raise HTTPException(status_code=500, detail=str(e))

    # POST /api/v1/collections/{collection_id}/update -> update existing records.
    # The body parameter is named `add` in the public signature; keep it.
    def update(self, collection_id: str, add: UpdateEmbedding) -> None:
        self._api._update(
            collection_id=_uuid(collection_id),
            ids=add.ids,
            embeddings=add.embeddings,
            metadatas=add.metadatas,  # type: ignore
            documents=add.documents,  # type: ignore
            uris=add.uris,  # type: ignore
        )

    # POST /api/v1/collections/{collection_id}/upsert -> insert-or-update records.
    def upsert(self, collection_id: str, upsert: AddEmbedding) -> None:
        self._api._upsert(
            collection_id=_uuid(collection_id),
            ids=upsert.ids,
            embeddings=upsert.embeddings,  # type: ignore
            metadatas=upsert.metadatas,  # type: ignore
            documents=upsert.documents,  # type: ignore
            uris=upsert.uris,  # type: ignore
        )

    # POST /api/v1/collections/{collection_id}/get -> fetch records using the
    # ids / filters / paging / include options from the request body.
    def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
        return self._api._get(
            collection_id=_uuid(collection_id),
            ids=get.ids,
            where=get.where,
            where_document=get.where_document,
            sort=get.sort,
            limit=get.limit,
            offset=get.offset,
            include=get.include,
        )

    # POST /api/v1/collections/{collection_id}/delete -> delete records selected
    # by ids and/or filters; returns whatever the API's _delete returns.
    def delete(self, collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
        return self._api._delete(
            collection_id=_uuid(collection_id),
            ids=delete.ids,
            where=delete.where,  # type: ignore
            where_document=delete.where_document,
        )

    # GET /api/v1/collections/{collection_id}/count -> result of the API's _count.
    def count(self, collection_id: str) -> int:
        return self._api._count(_uuid(collection_id))

    # POST /api/v1/reset -> forward to the API's reset().
    def reset(self) -> bool:
        return self._api.reset()

    # POST /api/v1/collections/{collection_id}/query -> nearest-neighbour query
    # for the embeddings supplied in the request body.
    def get_nearest_neighbors(
        self, collection_id: str, query: QueryEmbedding
    ) -> QueryResult:
        return self._api._query(
            collection_id=_uuid(collection_id),
            query_embeddings=query.query_embeddings,
            n_results=query.n_results,
            where=query.where,  # type: ignore
            where_document=query.where_document,  # type: ignore
            include=query.include,
        )

    # GET /api/v1/pre-flight-checks -> the API's max_batch_size.
    def pre_flight_checks(self) -> Dict[str, Any]:
        return {"max_batch_size": self._api.max_batch_size}