anubhav77 commited on
Commit
d6aeaee
·
1 Parent(s): ddcedb6

Adding all other methods

Browse files
Files changed (2) hide show
  1. main.py +92 -9
  2. server.py +71 -9
main.py CHANGED
@@ -1,13 +1,15 @@
1
  import fastapi
2
  import json
3
  import uvicorn
4
- from fastapi import HTTPException
5
- from fastapi.responses import HTMLResponse
6
  from fastapi.middleware.cors import CORSMiddleware
 
7
  from sse_starlette.sse import EventSourceResponse
8
  from starlette.responses import StreamingResponse
 
9
  from pydantic import BaseModel
10
- from typing import List, Dict, Any, Generator, Optional, cast
11
  from server import client
12
  from chromadb.api.types import (
13
  Documents,
@@ -50,8 +52,29 @@ from uuid import UUID
50
  from chromadb.telemetry import Telemetry
51
  from overrides import override
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  app = fastapi.FastAPI(title="ChromaDB")
 
55
  app.add_middleware(
56
  CORSMiddleware,
57
  allow_origins=["*"],
@@ -62,6 +85,7 @@ app.add_middleware(
62
  api_base="/api/v1"
63
  bkend=client()
64
 
 
65
  @app.get(api_base+"")
66
  def heartbeat():
67
  print("Received heartbeat request")
@@ -98,18 +122,77 @@ def list_collections():
98
  return bkend.list_collections()
99
 
100
  @app.post(api_base+"/collections")
101
- def create_collection(
102
- collection: CreateCollection
103
- ) -> Collection:
104
  print("Received request to create_collection")
105
  return bkend.create_collection(name=collection.name,metadata=collection.metadata,get_or_create=collection.get_or_create)
106
 
107
  @app.get(api_base+"/collections/{collection_name}")
108
- def get_collection(
109
- collection_name: str,
110
- ) -> Collection:
111
  return bkend.get_collection(collection_name)
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
 
115
  if __name__ == "__main__":
 
1
  import fastapi
2
  import json
3
  import uvicorn
4
+ from fastapi import HTTPException , status
5
+ from fastapi.responses import JSONResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi import FASTAPI as _FastAPI, Response
8
  from sse_starlette.sse import EventSourceResponse
9
  from starlette.responses import StreamingResponse
10
+ from starlette.requests import Request
11
  from pydantic import BaseModel
12
+ from typing import List, Dict, Any, Generator, Optional, cast, Callable
13
  from server import client
14
  from chromadb.api.types import (
15
  Documents,
 
52
  from chromadb.telemetry import Telemetry
53
  from overrides import override
54
 
55
+ async def catch_exceptions_middleware(
56
+ request: Request, call_next: Callable[[Request], Any]
57
+ ) -> Response:
58
+ try:
59
+ return await call_next(request)
60
+ except ChromaError as e:
61
+ return JSONResponse(
62
+ content={"error": e.name(), "message": e.message()}, status_code=e.code()
63
+ )
64
+ except Exception as e:
65
+ return JSONResponse(content={"error": repr(e)}, status_code=500)
66
+
67
+
68
+
69
+ def _uuid(uuid_str: str) -> UUID:
70
+ try:
71
+ return UUID(uuid_str)
72
+ except ValueError:
73
+ raise InvalidUUIDError(f"Could not parse {uuid_str} as a UUID")
74
+
75
 
76
  app = fastapi.FastAPI(title="ChromaDB")
77
+ app.middleware("http")(catch_exceptions_middleware)
78
  app.add_middleware(
79
  CORSMiddleware,
80
  allow_origins=["*"],
 
85
  api_base="/api/v1"
86
  bkend=client()
87
 
88
+
89
  @app.get(api_base+"")
90
  def heartbeat():
91
  print("Received heartbeat request")
 
122
  return bkend.list_collections()
123
 
124
  @app.post(api_base+"/collections")
125
+ def create_collection( collection: CreateCollection ) -> Collection:
 
 
126
  print("Received request to create_collection")
127
  return bkend.create_collection(name=collection.name,metadata=collection.metadata,get_or_create=collection.get_or_create)
128
 
129
  @app.get(api_base+"/collections/{collection_name}")
130
+ def get_collection( collection_name: str) -> Collection:
131
+ print("Received get_collection request")
 
132
  return bkend.get_collection(collection_name)
133
 
134
+ @app.post(api_base+"/collections/{collection_id}/add")
135
+ def add(collection_id:str , add:AddEmbedding) -> None:
136
+ print("Received add request")
137
+ try:
138
+ result=bkend.add(collection_id=_uuid(collection_id),embeddings=add.embeddings,metadatas=add.metadatas,documents=add.documents,ids=add.ids,increment_index=add.increment_index)
139
+ except InvalidDimensionException as e:
140
+ raise HTTPException(status_code=500, detail=str(e))
141
+ return result
142
+
143
+ @app.post(api_base+"/collections/{collection_id}/update")
144
+ def update(collection_id:str , update:UpdateEmbedding) -> None:
145
+ print("Received update request")
146
+ return bkend.update(ids=update.ids, collection_id=_uuid(collection_id), embeddings=update.embeddings, documents=update.documents, metadatas=update.metadatas)
147
+
148
+ @app.post(api_base+"/collections/{collection_id}/upsert")
149
+ def upsert(collection_id:str, upsert: AddEmbedding):
150
+ print("Received upsert request")
151
+ return bkend.upsert(collection_id=_uuid(collection_id),embeddings=upsert.embeddings,metadatas=upsert.metadatas,documents=upsert.documents,ids=upsert.ids,increment_index=upsert.increment_index)
152
+
153
+ @app.post(api_base+"/collections/{collection_id}/get")
154
+ def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
155
+ print("Received get request")
156
+ return bkend.get(collection_id=_uuid(collection_id), ids=get.ids, where=get.where,
157
+ where_document=get.where_document, sort=get.sort, limit=get.limit,
158
+ offset=get.offset, include=get.include)
159
+
160
+ @app.post(api_base+"/collections/{collection_id}/delete")
161
+ def delete(collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
162
+ print("Received delete request")
163
+ return bkend.delete(where=delete.where, ids=delete.ids,
164
+ collection_id=_uuid(collection_id), where_document=delete.where_document)
165
+
166
+ @app.get(api_base+"/collections/{collection_id}/count")
167
+ def count(collection_id:str) ->int:
168
+ print("Received count request")
169
+ return bkend.count(_uuid(collection_id))
170
+
171
+ @app.post(api_base+"/collections/{collection_id}/query")
172
+ def get_nearest_neighbors(collection_id: str, query: QueryEmbedding) -> QueryResult:
173
+ print("Received get_nearest_neighbors request")
174
+ return bkend.get_nearest_neighbors(collection_id=_uuid(collection_id), where=query.where, where_document=query.where_document,
175
+ query_embeddings=query.query_embeddings, n_results=query.n_results, include=query.include)
176
+
177
+ @app.post(api_base+"/collections/{collection_name}/create_index")
178
+ def create_index(collection_name:str)-> bool:
179
+ print("Received create_index request")
180
+ return bkend.create_index(collection_name)
181
+
182
+ @app.get(api_base+"/collections/{collection_name}")
183
+ def get_collection2( collection_name: str) -> Collection:
184
+ print("Received get_collection2 request")
185
+ return bkend.get_collection(collection_name)
186
+
187
+ @app.post(api_base+"/collections/{collection_id}")
188
+ def modify(collection_id: str, collection: UpdateCollection) -> None:
189
+ print("Received modify(collection) request")
190
+ return bkend.modify(id=_uuid(collection_id), new_name=collection.new_name, new_metadata=collection.new_metadata)
191
+
192
+ @app.delete(api_base+"/collections/{collection_name}")
193
+ def delete_collection(collection_name:str) -> None:
194
+ print("Received delete_collection request")
195
+ return bkend.delete_collection(collection_name)
196
 
197
 
198
  if __name__ == "__main__":
server.py CHANGED
@@ -1,7 +1,8 @@
1
- from chromadb.config import Settings
2
  import chromadb
 
 
3
  from pydantic import BaseModel
4
- from typing import List, Dict, Any, Generator, Optional, cast
5
  from chromadb.api.types import (
6
  Documents,
7
  Embeddings,
@@ -44,8 +45,6 @@ from chromadb.telemetry import Telemetry
44
  from overrides import override
45
 
46
 
47
- import time
48
-
49
  class client():
50
  def __init__(self):
51
  self.db = chromadb.Client(Settings(
@@ -73,20 +72,83 @@ class client():
73
  def get_collection(
74
  self,
75
  name: str,
76
- embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
77
  ) -> Collection:
78
  col=self.db.get_collection(name,embedding_function=embedding_function)
79
  print(col)
80
  return col
81
 
82
- def reset():
83
  return self.db.reset()
84
 
85
- def version():
86
  return self.db.get_version()
87
 
88
- def persist():
89
  return self.db.persist()
90
 
91
- def raw_sql(raw_sql: RawSql):
92
  return self.db.raw_sql(raw_sql.raw_sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import chromadb
2
+ from chromadb.config import Settings, System
3
+ import time,json
4
  from pydantic import BaseModel
5
+ from typing import List, Dict, Any, Generator, Optional, cast, Callable
6
  from chromadb.api.types import (
7
  Documents,
8
  Embeddings,
 
45
  from overrides import override
46
 
47
 
 
 
48
  class client():
49
  def __init__(self):
50
  self.db = chromadb.Client(Settings(
 
72
  def get_collection(
73
  self,
74
  name: str,
75
+ embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction()
76
  ) -> Collection:
77
  col=self.db.get_collection(name,embedding_function=embedding_function)
78
  print(col)
79
  return col
80
 
81
+ def reset(self):
82
  return self.db.reset()
83
 
84
+ def version(self):
85
  return self.db.get_version()
86
 
87
+ def persist(self):
88
  return self.db.persist()
89
 
90
+ def raw_sql(self,raw_sql: RawSql):
91
  return self.db.raw_sql(raw_sql.raw_sql)
92
+
93
+ def add(self,ids: IDs,
94
+ collection_id: UUID,
95
+ embeddings: Embeddings,
96
+ metadatas: Optional[Metadatas] = None,
97
+ documents: Optional[Documents] = None,
98
+ increment_index: bool = True,
99
+ ) -> bool:
100
+ return self.db._add(collection_id=collection_id,embeddings=embeddings,
101
+ metadatas=metadatas,documents=documents,
102
+ ids=ids,increment_index=increment_index)
103
+
104
+ def update( self, collection_id: UUID, ids: IDs,
105
+ embeddings: Optional[Embeddings] = None,
106
+ metadatas: Optional[Metadatas] = None,
107
+ documents: Optional[Documents] = None,
108
+ ) -> bool:
109
+ return self.db._update(ids=ids, collection_id=collection_id, embeddings=embeddings, documents=documents, metadatas=metadatas)
110
+
111
+ def upsert( self, collection_id: UUID, ids: IDs,
112
+ embeddings: Embeddings,
113
+ metadatas: Optional[Metadatas] = None,
114
+ documents: Optional[Documents] = None,
115
+ increment_index: bool = True,
116
+ ) -> bool:
117
+ return self.db._upsert(collection_id=collection_id,embeddings=embeddings,metadatas=metadatas,documents=documents,ids=ids,increment_index=increment_index)
118
+
119
+ def get( self, collection_id: UUID, ids: Optional[IDs] = None, where: Optional[Where] = {},
120
+ sort: Optional[str] = None, limit: Optional[int] = None, offset: Optional[int] = None,
121
+ page: Optional[int] = None, page_size: Optional[int] = None,
122
+ where_document: Optional[WhereDocument] = {},
123
+ include: Include = ["embeddings", "metadatas", "documents"],
124
+ ) -> GetResult:
125
+ return self.db._get(collection_id=collection_id, ids=ids, where=where,
126
+ where_document=where_document, sort=sort, limit=limit,
127
+ offset=offset, include=include)
128
+ def delete( self, collection_id: UUID, ids: Optional[IDs],
129
+ where: Optional[Where] = {}, where_document: Optional[WhereDocument] = {},
130
+ ) -> IDs:
131
+ return self.db._delete(where=where, ids=ids, collection_id=collection_id, where_document=where_document)
132
+
133
+ def count(self, collection_id: UUID) -> int:
134
+ return self.db._.count(collection_id)
135
+
136
+ def get_nearest_neighbors( self, collection_id: UUID, query_embeddings: Embeddings,
137
+ n_results: int = 10, where: Where = {}, where_document: WhereDocument = {},
138
+ include: Include = ["embeddings", "metadatas", "documents", "distances"],
139
+ ) -> QueryResult:
140
+ return self.db._query(collection_id=collection_id, where=where, where_document=where_document,
141
+ query_embeddings=query_embeddings, n_results=n_results, include=include)
142
+
143
+ def create_index(self, collection_name: str) -> bool:
144
+ return self.db.create_index(collection_name)
145
+
146
+ def modify( self, id: UUID, new_name: Optional[str] = None,
147
+ new_metadata: Optional[CollectionMetadata] = None,
148
+ ) -> None:
149
+ """This is for updating the collection"""
150
+ return self.db._modify(id=id, new_name=new_name, new_metadata=new_metadata)
151
+
152
+ def delete_collection( self, name: str,) -> None:
153
+ return self.db.delete_collection(name)
154
+