# chroma / main.py
# File-viewer metadata captured with this source: uploaded by anubhav77,
# commit d6aeaee ("Adding all other methods"), 7.06 kB.
import fastapi
import json
import uvicorn
from fastapi import HTTPException , status
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FASTAPI as _FastAPI, Response
from sse_starlette.sse import EventSourceResponse
from starlette.responses import StreamingResponse
from starlette.requests import Request
from pydantic import BaseModel
from typing import List, Dict, Any, Generator, Optional, cast, Callable
from server import client
from chromadb.api.types import (
Documents,
Embeddings,
EmbeddingFunction,
IDs,
Include,
Metadatas,
Where,
WhereDocument,
GetResult,
QueryResult,
CollectionMetadata,
)
from chromadb.errors import (
ChromaError,
InvalidUUIDError,
InvalidDimensionException,
)
from chromadb.server.fastapi.types import (
AddEmbedding,
DeleteEmbedding,
GetEmbedding,
QueryEmbedding,
RawSql, # Results,
CreateCollection,
UpdateCollection,
UpdateEmbedding,
)
from chromadb.api import API
from chromadb.config import System
import chromadb.utils.embedding_functions as ef
import pandas as pd
import requests
import json
from typing import Sequence
from chromadb.api.models.Collection import Collection
import chromadb.errors as errors
from uuid import UUID
from chromadb.telemetry import Telemetry
from overrides import override
async def catch_exceptions_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """Invoke the next handler and turn any raised exception into a JSON response.

    Args:
        request: The incoming HTTP request.
        call_next: Awaitable callable that produces the downstream response.

    Returns:
        The downstream response when nothing is raised. For a ``ChromaError``,
        a JSON body with ``error`` and ``message`` keys and the status code
        reported by the error. For any other exception, a JSON body with the
        exception's ``repr`` under ``error`` and status 500.
    """
    try:
        downstream_response = await call_next(request)
    except ChromaError as chroma_err:
        # Chroma errors carry their own name, message and HTTP status code.
        payload = {"error": chroma_err.name(), "message": chroma_err.message()}
        return JSONResponse(content=payload, status_code=chroma_err.code())
    except Exception as unexpected:
        # Last-resort boundary: report the failure instead of propagating it.
        return JSONResponse(content={"error": repr(unexpected)}, status_code=500)
    return downstream_response
def _uuid(uuid_str: str) -> UUID:
try:
return UUID(uuid_str)
except ValueError:
raise InvalidUUIDError(f"Could not parse {uuid_str} as a UUID")
# The FastAPI application that every route below is registered on.
app = fastapi.FastAPI(title="ChromaDB")

# Wrap every HTTP request in the exception-to-JSON middleware.
app.middleware("http")(catch_exceptions_middleware)

# CORS is fully permissive: any origin, method and header, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Common prefix shared by all route paths.
api_base = "/api/v1"

# Backend object that the route handlers delegate to; created once at import.
bkend = client()
# GET on the bare API base path: logs the call and returns the backend's
# heartbeat value. A second function with the same name is defined further
# down for the "/heartbeat" path; this route is registered by the decorator
# at definition time, before that later definition rebinds the name.
@app.get(api_base)
def heartbeat():
    print("Received heartbeat request")
    beat = bkend.heartbeat()
    return beat
# POST <api_base>/reset: logs the call and returns whatever the backend's
# reset() returns.
@app.post(api_base + "/reset")
def reset():
    print("Received reset request")
    outcome = bkend.reset()
    return outcome
# GET <api_base>/version: logs the call and returns the backend's version().
@app.get(api_base + "/version")
def version():
    print("Received version request")
    backend_version = bkend.version()
    return backend_version
# POST <api_base>/persist: logs the call and returns the backend's persist()
# result.
@app.post(api_base + "/persist")
def persist():
    print("Received persist request")
    outcome = bkend.persist()
    return outcome
# POST <api_base>/raw_sql: logs the call and hands the request body to the
# backend's raw_sql().
# NOTE(review): the whole RawSql model is forwarded, not a field of it --
# confirm that bkend.raw_sql expects the model object.
@app.post(api_base + "/raw_sql")
def raw_sql(raw_sql: RawSql):
    print("Received raw_sql request")
    query_result = bkend.raw_sql(raw_sql)
    return query_result
# GET <api_base>/heartbeat: logs the call and returns the backend's heartbeat
# value. This definition reuses the name of the earlier root-path handler;
# each is registered on its own path by its decorator.
@app.get(api_base + "/heartbeat")
def heartbeat():
    print("Received heartbeat request")
    beat = bkend.heartbeat()
    return beat
# GET <api_base>/collections: logs the call and returns the backend's
# list_collections() result.
@app.get(api_base + "/collections")
def list_collections():
    print("Received list_collections request")
    collections = bkend.list_collections()
    return collections
# POST <api_base>/collections: creates a collection from the request body's
# name, metadata and get_or_create fields and returns the backend's result.
@app.post(api_base + "/collections")
def create_collection(collection: CreateCollection) -> Collection:
    print("Received request to create_collection")
    created = bkend.create_collection(
        name=collection.name,
        metadata=collection.metadata,
        get_or_create=collection.get_or_create,
    )
    return created
# GET <api_base>/collections/{collection_name}: looks the collection up by
# name through the backend and returns it.
@app.get(api_base + "/collections/{collection_name}")
def get_collection(collection_name: str) -> Collection:
    print("Received get_collection request")
    found = bkend.get_collection(collection_name)
    return found
# POST <api_base>/collections/{collection_id}/add: forwards the embeddings,
# metadatas, documents, ids and increment_index from the request body to the
# backend's add() for the collection identified by the path UUID, and returns
# the backend's result.
# An InvalidDimensionException from the backend is reported as HTTP 500 with
# the exception text as detail; the original exception is chained as the cause
# so it remains visible in server-side tracebacks.
@app.post(api_base + "/collections/{collection_id}/add")
def add(collection_id: str, add: AddEmbedding) -> None:
    print("Received add request")
    try:
        result = bkend.add(
            collection_id=_uuid(collection_id),
            embeddings=add.embeddings,
            metadatas=add.metadatas,
            documents=add.documents,
            ids=add.ids,
            increment_index=add.increment_index,
        )
    except InvalidDimensionException as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return result
# POST <api_base>/collections/{collection_id}/update: forwards ids,
# embeddings, documents and metadatas from the request body to the backend's
# update() for the collection identified by the path UUID.
@app.post(api_base + "/collections/{collection_id}/update")
def update(collection_id: str, update: UpdateEmbedding) -> None:
    print("Received update request")
    outcome = bkend.update(
        ids=update.ids,
        collection_id=_uuid(collection_id),
        embeddings=update.embeddings,
        documents=update.documents,
        metadatas=update.metadatas,
    )
    return outcome
# POST <api_base>/collections/{collection_id}/upsert: forwards embeddings,
# metadatas, documents, ids and increment_index from the request body to the
# backend's upsert() for the collection identified by the path UUID.
@app.post(api_base + "/collections/{collection_id}/upsert")
def upsert(collection_id: str, upsert: AddEmbedding):
    print("Received upsert request")
    collection_uuid = _uuid(collection_id)
    outcome = bkend.upsert(
        collection_id=collection_uuid,
        embeddings=upsert.embeddings,
        metadatas=upsert.metadatas,
        documents=upsert.documents,
        ids=upsert.ids,
        increment_index=upsert.increment_index,
    )
    return outcome
# POST <api_base>/collections/{collection_id}/get: fetches records from the
# collection identified by the path UUID, passing the request body's ids,
# where, where_document, sort, limit, offset and include to the backend's
# get().
# This is a module-level function, not a method, so it takes no `self`: a
# stray `self` parameter here is interpreted by FastAPI as a required query
# parameter and makes ordinary client calls fail validation.
@app.post(api_base + "/collections/{collection_id}/get")
def get(collection_id: str, get: GetEmbedding) -> GetResult:
    print("Received get request")
    return bkend.get(
        collection_id=_uuid(collection_id),
        ids=get.ids,
        where=get.where,
        where_document=get.where_document,
        sort=get.sort,
        limit=get.limit,
        offset=get.offset,
        include=get.include,
    )
# POST <api_base>/collections/{collection_id}/delete: forwards where, ids and
# where_document from the request body to the backend's delete() for the
# collection identified by the path UUID, and returns the backend's result.
@app.post(api_base + "/collections/{collection_id}/delete")
def delete(collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
    print("Received delete request")
    removed = bkend.delete(
        where=delete.where,
        ids=delete.ids,
        collection_id=_uuid(collection_id),
        where_document=delete.where_document,
    )
    return removed
# GET <api_base>/collections/{collection_id}/count: returns the backend's
# count() for the collection identified by the path UUID.
@app.get(api_base + "/collections/{collection_id}/count")
def count(collection_id: str) -> int:
    print("Received count request")
    collection_uuid = _uuid(collection_id)
    return bkend.count(collection_uuid)
# POST <api_base>/collections/{collection_id}/query: forwards where,
# where_document, query_embeddings, n_results and include from the request
# body to the backend's get_nearest_neighbors() for the collection identified
# by the path UUID.
@app.post(api_base + "/collections/{collection_id}/query")
def get_nearest_neighbors(collection_id: str, query: QueryEmbedding) -> QueryResult:
    print("Received get_nearest_neighbors request")
    collection_uuid = _uuid(collection_id)
    neighbors = bkend.get_nearest_neighbors(
        collection_id=collection_uuid,
        where=query.where,
        where_document=query.where_document,
        query_embeddings=query.query_embeddings,
        n_results=query.n_results,
        include=query.include,
    )
    return neighbors
# POST <api_base>/collections/{collection_name}/create_index: passes the
# collection name from the path to the backend's create_index() and returns
# its result.
@app.post(api_base + "/collections/{collection_name}/create_index")
def create_index(collection_name: str) -> bool:
    print("Received create_index request")
    index_created = bkend.create_index(collection_name)
    return index_created
# GET <api_base>/collections/{collection_name}: looks the collection up by
# name through the backend and returns it.
# NOTE(review): this method and path are identical to the get_collection route
# registered earlier in this file; presumably the earlier registration is the
# one that serves matching requests -- confirm and consider removing this one.
@app.get(api_base + "/collections/{collection_name}")
def get_collection2(collection_name: str) -> Collection:
    print("Received get_collection2 request")
    found = bkend.get_collection(collection_name)
    return found
# POST <api_base>/collections/{collection_id}: passes new_name and
# new_metadata from the request body to the backend's modify() for the
# collection identified by the path UUID.
# NOTE(review): registered for POST; confirm that clients of this server do
# not send collection updates with PUT instead.
@app.post(api_base + "/collections/{collection_id}")
def modify(collection_id: str, collection: UpdateCollection) -> None:
    print("Received modify(collection) request")
    collection_uuid = _uuid(collection_id)
    outcome = bkend.modify(
        id=collection_uuid,
        new_name=collection.new_name,
        new_metadata=collection.new_metadata,
    )
    return outcome
# DELETE <api_base>/collections/{collection_name}: passes the collection name
# from the path to the backend's delete_collection() and returns its result.
@app.delete(api_base + "/collections/{collection_name}")
def delete_collection(collection_name: str) -> None:
    print("Received delete_collection request")
    outcome = bkend.delete_collection(collection_name)
    return outcome
# Script entry point: when run directly, serve the app with uvicorn, bound to
# every interface on port 8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )