Obseleting server.py, making direct calls in main
Browse files
main.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import fastapi
|
2 |
-
import json
|
3 |
import uvicorn
|
4 |
from fastapi import HTTPException , status
|
5 |
from fastapi.responses import JSONResponse
|
@@ -8,6 +8,8 @@ from fastapi import FastAPI as Response
|
|
8 |
from sse_starlette.sse import EventSourceResponse
|
9 |
from starlette.responses import StreamingResponse
|
10 |
from starlette.requests import Request
|
|
|
|
|
11 |
from pydantic import BaseModel
|
12 |
from typing import List, Dict, Any, Generator, Optional, cast, Callable
|
13 |
from server import client
|
@@ -83,13 +85,17 @@ app.add_middleware(
|
|
83 |
allow_headers=["*"],
|
84 |
)
|
85 |
api_base="/api/v1"
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
|
88 |
|
89 |
@app.get(api_base+"")
|
90 |
def heartbeat():
|
91 |
print("Received heartbeat request")
|
92 |
-
return
|
93 |
|
94 |
@app.post(api_base+"/reset")
|
95 |
def reset():
|
@@ -99,7 +105,7 @@ def reset():
|
|
99 |
@app.get(api_base+"/version")
|
100 |
def version():
|
101 |
print("Received version request")
|
102 |
-
return bkend.
|
103 |
|
104 |
@app.post(api_base+"/persist")
|
105 |
def persist():
|
@@ -109,12 +115,12 @@ def persist():
|
|
109 |
@app.post(api_base+"/raw_sql")
|
110 |
def raw_sql(raw_sql: RawSql):
|
111 |
print("Received raw_sql request")
|
112 |
-
return bkend.raw_sql(raw_sql)
|
113 |
|
114 |
@app.get(api_base+"/heartbeat")
|
115 |
-
def
|
116 |
-
print("Received
|
117 |
-
return
|
118 |
|
119 |
@app.get(api_base+"/collections")
|
120 |
def list_collections():
|
@@ -124,12 +130,12 @@ def list_collections():
|
|
124 |
@app.post(api_base+"/collections")
|
125 |
def create_collection( collection: CreateCollection ) -> Collection:
|
126 |
print("Received request to create_collection")
|
127 |
-
return bkend.create_collection(name=collection.name,metadata=collection.metadata,get_or_create=collection.get_or_create)
|
128 |
|
129 |
@app.get(api_base+"/collections/{collection_name}")
|
130 |
def get_collection( collection_name: str) -> Collection:
|
131 |
print("Received get_collection request")
|
132 |
-
return bkend.get_collection(collection_name)
|
133 |
|
134 |
@app.post(api_base+"/collections/{collection_id}/add")
|
135 |
def add(collection_id:str , add:AddEmbedding) -> None:
|
@@ -143,35 +149,35 @@ def add(collection_id:str , add:AddEmbedding) -> None:
|
|
143 |
@app.post(api_base+"/collections/{collection_id}/update")
|
144 |
def update(collection_id:str , update:UpdateEmbedding) -> None:
|
145 |
print("Received update request")
|
146 |
-
return bkend.
|
147 |
|
148 |
@app.post(api_base+"/collections/{collection_id}/upsert")
|
149 |
def upsert(collection_id:str, upsert: AddEmbedding):
|
150 |
print("Received upsert request")
|
151 |
-
return bkend.
|
152 |
|
153 |
@app.post(api_base+"/collections/{collection_id}/get")
|
154 |
def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
|
155 |
print("Received get request")
|
156 |
-
return bkend.
|
157 |
where_document=get.where_document, sort=get.sort, limit=get.limit,
|
158 |
offset=get.offset, include=get.include)
|
159 |
|
160 |
@app.post(api_base+"/collections/{collection_id}/delete")
|
161 |
def delete(collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
|
162 |
print("Received delete request")
|
163 |
-
return bkend.
|
164 |
collection_id=_uuid(collection_id), where_document=delete.where_document)
|
165 |
|
166 |
@app.get(api_base+"/collections/{collection_id}/count")
|
167 |
def count(collection_id:str) ->int:
|
168 |
print("Received count request")
|
169 |
-
return bkend.
|
170 |
|
171 |
@app.post(api_base+"/collections/{collection_id}/query")
|
172 |
def get_nearest_neighbors(collection_id: str, query: QueryEmbedding) -> QueryResult:
|
173 |
print("Received get_nearest_neighbors request")
|
174 |
-
return bkend.
|
175 |
query_embeddings=query.query_embeddings, n_results=query.n_results, include=query.include)
|
176 |
|
177 |
@app.post(api_base+"/collections/{collection_name}/create_index")
|
@@ -179,15 +185,10 @@ def create_index(collection_name:str)-> bool:
|
|
179 |
print("Received create_index request")
|
180 |
return bkend.create_index(collection_name)
|
181 |
|
182 |
-
@app.get(api_base+"/collections/{collection_name}")
|
183 |
-
def get_collection2( collection_name: str) -> Collection:
|
184 |
-
print("Received get_collection2 request")
|
185 |
-
return bkend.get_collection(collection_name)
|
186 |
-
|
187 |
@app.post(api_base+"/collections/{collection_id}")
|
188 |
def modify(collection_id: str, collection: UpdateCollection) -> None:
|
189 |
-
print("Received modify
|
190 |
-
return bkend.
|
191 |
|
192 |
@app.delete(api_base+"/collections/{collection_name}")
|
193 |
def delete_collection(collection_name:str) -> None:
|
|
|
1 |
import fastapi
|
2 |
+
import json,time
|
3 |
import uvicorn
|
4 |
from fastapi import HTTPException , status
|
5 |
from fastapi.responses import JSONResponse
|
|
|
8 |
from sse_starlette.sse import EventSourceResponse
|
9 |
from starlette.responses import StreamingResponse
|
10 |
from starlette.requests import Request
|
11 |
+
import chromadb
|
12 |
+
from chromadb.config import Settings, System
|
13 |
from pydantic import BaseModel
|
14 |
from typing import List, Dict, Any, Generator, Optional, cast, Callable
|
15 |
from server import client
|
|
|
85 |
allow_headers=["*"],
|
86 |
)
|
87 |
api_base="/api/v1"
|
88 |
+
embedding_function=ef.DefaultEmbeddingFunction()
|
89 |
+
bkend=chromadb.Client(Settings(
|
90 |
+
chroma_db_impl="duckdb+parquet",
|
91 |
+
persist_directory="./index/chroma" # Optional, defaults to .chromadb/ in the current directory
|
92 |
+
))
|
93 |
|
94 |
|
95 |
@app.get(api_base+"")
|
96 |
def heartbeat():
|
97 |
print("Received heartbeat request")
|
98 |
+
return {"nanosecond heartbeat":int(time.time_ns())}
|
99 |
|
100 |
@app.post(api_base+"/reset")
|
101 |
def reset():
|
|
|
105 |
@app.get(api_base+"/version")
|
106 |
def version():
|
107 |
print("Received version request")
|
108 |
+
return bkend.get_version()
|
109 |
|
110 |
@app.post(api_base+"/persist")
|
111 |
def persist():
|
|
|
115 |
@app.post(api_base+"/raw_sql")
|
116 |
def raw_sql(raw_sql: RawSql):
|
117 |
print("Received raw_sql request")
|
118 |
+
return bkend.raw_sql(raw_sql.raw_sql)
|
119 |
|
120 |
@app.get(api_base+"/heartbeat")
|
121 |
+
def heartbeat1():
|
122 |
+
print("Received heartbeat1 request")
|
123 |
+
return heartbeat()
|
124 |
|
125 |
@app.get(api_base+"/collections")
|
126 |
def list_collections():
|
|
|
130 |
@app.post(api_base+"/collections")
|
131 |
def create_collection( collection: CreateCollection ) -> Collection:
|
132 |
print("Received request to create_collection")
|
133 |
+
return bkend.create_collection(name=collection.name,metadata=collection.metadata,embedding_function=embedding_function,get_or_create=collection.get_or_create)
|
134 |
|
135 |
@app.get(api_base+"/collections/{collection_name}")
|
136 |
def get_collection( collection_name: str) -> Collection:
|
137 |
print("Received get_collection request")
|
138 |
+
return bkend.get_collection(collection_name,embedding_function=embedding_function)
|
139 |
|
140 |
@app.post(api_base+"/collections/{collection_id}/add")
|
141 |
def add(collection_id:str , add:AddEmbedding) -> None:
|
|
|
149 |
@app.post(api_base+"/collections/{collection_id}/update")
|
150 |
def update(collection_id:str , update:UpdateEmbedding) -> None:
|
151 |
print("Received update request")
|
152 |
+
return bkend._update(ids=update.ids, collection_id=_uuid(collection_id), embeddings=update.embeddings, documents=update.documents, metadatas=update.metadatas)
|
153 |
|
154 |
@app.post(api_base+"/collections/{collection_id}/upsert")
|
155 |
def upsert(collection_id:str, upsert: AddEmbedding):
|
156 |
print("Received upsert request")
|
157 |
+
return bkend._upsert(collection_id=_uuid(collection_id),embeddings=upsert.embeddings,metadatas=upsert.metadatas,documents=upsert.documents,ids=upsert.ids,increment_index=upsert.increment_index)
|
158 |
|
159 |
@app.post(api_base+"/collections/{collection_id}/get")
|
160 |
def get(self, collection_id: str, get: GetEmbedding) -> GetResult:
|
161 |
print("Received get request")
|
162 |
+
return bkend._get(collection_id=_uuid(collection_id), ids=get.ids, where=get.where,
|
163 |
where_document=get.where_document, sort=get.sort, limit=get.limit,
|
164 |
offset=get.offset, include=get.include)
|
165 |
|
166 |
@app.post(api_base+"/collections/{collection_id}/delete")
|
167 |
def delete(collection_id: str, delete: DeleteEmbedding) -> List[UUID]:
|
168 |
print("Received delete request")
|
169 |
+
return bkend._delete(where=delete.where, ids=delete.ids,
|
170 |
collection_id=_uuid(collection_id), where_document=delete.where_document)
|
171 |
|
172 |
@app.get(api_base+"/collections/{collection_id}/count")
|
173 |
def count(collection_id:str) ->int:
|
174 |
print("Received count request")
|
175 |
+
return bkend._count(_uuid(collection_id))
|
176 |
|
177 |
@app.post(api_base+"/collections/{collection_id}/query")
|
178 |
def get_nearest_neighbors(collection_id: str, query: QueryEmbedding) -> QueryResult:
|
179 |
print("Received get_nearest_neighbors request")
|
180 |
+
return bkend._query(collection_id=_uuid(collection_id), where=query.where, where_document=query.where_document,
|
181 |
query_embeddings=query.query_embeddings, n_results=query.n_results, include=query.include)
|
182 |
|
183 |
@app.post(api_base+"/collections/{collection_name}/create_index")
|
|
|
185 |
print("Received create_index request")
|
186 |
return bkend.create_index(collection_name)
|
187 |
|
|
|
|
|
|
|
|
|
|
|
188 |
@app.post(api_base+"/collections/{collection_id}")
|
189 |
def modify(collection_id: str, collection: UpdateCollection) -> None:
|
190 |
+
print("Received modify-collection request")
|
191 |
+
return bkend._modify(id=_uuid(collection_id), new_name=collection.new_name, new_metadata=collection.new_metadata)
|
192 |
|
193 |
@app.delete(api_base+"/collections/{collection_name}")
|
194 |
def delete_collection(collection_name:str) -> None:
|
server.py
CHANGED
@@ -131,7 +131,7 @@ class client():
|
|
131 |
return self.db._delete(where=where, ids=ids, collection_id=collection_id, where_document=where_document)
|
132 |
|
133 |
def count(self, collection_id: UUID) -> int:
|
134 |
-
return self.db.
|
135 |
|
136 |
def get_nearest_neighbors( self, collection_id: UUID, query_embeddings: Embeddings,
|
137 |
n_results: int = 10, where: Where = {}, where_document: WhereDocument = {},
|
|
|
131 |
return self.db._delete(where=where, ids=ids, collection_id=collection_id, where_document=where_document)
|
132 |
|
133 |
def count(self, collection_id: UUID) -> int:
|
134 |
+
return self.db._count(collection_id)
|
135 |
|
136 |
def get_nearest_neighbors( self, collection_id: UUID, query_embeddings: Embeddings,
|
137 |
n_results: int = 10, where: Where = {}, where_document: WhereDocument = {},
|