Spaces:
Sleeping
Sleeping
Commit
·
d323684
1
Parent(s):
7201688
fixed qudrant search query
Browse files- main.py +76 -40
- rag/repository.py +16 -17
main.py
CHANGED
@@ -1,55 +1,79 @@
|
|
1 |
# main.py
|
2 |
-
from fastapi import
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from schemas import VoicePresets
|
6 |
-
from utils import audio_array_to_buffer,img_to_bytes
|
7 |
from contextlib import asynccontextmanager
|
8 |
-
from typing import AsyncIterator,Callable,Awaitable,Annotated
|
9 |
from uuid import uuid4
|
10 |
import time
|
11 |
from datetime import datetime, timezone
|
12 |
import csv
|
13 |
-
from dependencies import get_urls_content,get_rag_content
|
14 |
-
from schemas import TextModelResponse,TextModelRequest
|
15 |
import shutil, uuid
|
16 |
from upload import save_file
|
17 |
from rag import pdf_text_extractor, vector_service
|
|
|
18 |
|
|
|
19 |
|
20 |
-
models = {}
|
21 |
|
22 |
-
@asynccontextmanager
|
23 |
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
24 |
-
# models["text2image"] = load_image_model()
|
25 |
# models["text"]=load_text_model()
|
26 |
-
yield
|
27 |
-
models.clear()
|
|
|
28 |
|
29 |
app = FastAPI(lifespan=lifespan)
|
30 |
|
31 |
csv_header = [
|
32 |
-
"Request ID",
|
33 |
-
"
|
|
|
|
|
|
|
|
|
|
|
34 |
]
|
35 |
|
36 |
|
37 |
-
@app.middleware("http")
|
38 |
async def monitor_service(
|
39 |
req: Request, call_next: Callable[[Request], Awaitable[Response]]
|
40 |
-
) -> Response:
|
41 |
-
request_id = uuid4().hex
|
42 |
request_datetime = datetime.now(timezone.utc).isoformat()
|
43 |
start_time = time.perf_counter()
|
44 |
response: Response = await call_next(req)
|
45 |
-
response_time = round(time.perf_counter() - start_time, 4)
|
46 |
response.headers["X-Response-Time"] = str(response_time)
|
47 |
-
response.headers["X-API-Request-ID"] = request_id
|
48 |
with open("usage.csv", "a", newline="") as file:
|
49 |
writer = csv.writer(file)
|
50 |
if file.tell() == 0:
|
51 |
writer.writerow(csv_header)
|
52 |
-
writer.writerow(
|
53 |
[
|
54 |
request_id,
|
55 |
request_datetime,
|
@@ -63,21 +87,24 @@ async def monitor_service(
|
|
63 |
return response
|
64 |
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
# app = FastAPI()
|
69 |
@app.get("/")
|
70 |
def root_controller():
|
71 |
return {"status": "healthy"}
|
72 |
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
return TextModelResponse(content=output, ip=request.client.host)
|
80 |
|
|
|
81 |
@app.get("/logs")
|
82 |
def get_logs():
|
83 |
# return FileResponse("usage.csv", media_type='text/csv', filename="usage.csv")
|
@@ -89,14 +116,15 @@ def get_logs():
|
|
89 |
temp_file,
|
90 |
media_type="text/csv",
|
91 |
filename="logs.csv",
|
92 |
-
headers={"Content-Disposition": "attachment; filename=logs.csv"}
|
93 |
)
|
94 |
|
|
|
95 |
@app.get(
|
96 |
"/generate/audio",
|
97 |
responses={status.HTTP_200_OK: {"content": {"audio/wav": {}}}},
|
98 |
response_class=StreamingResponse,
|
99 |
-
)
|
100 |
def serve_text_to_audio_model_controller(
|
101 |
prompt: str,
|
102 |
preset: VoicePresets = "v2/en_speaker_1",
|
@@ -108,14 +136,22 @@ def serve_text_to_audio_model_controller(
|
|
108 |
)
|
109 |
|
110 |
|
111 |
-
@app.get(
|
112 |
-
|
113 |
-
|
|
|
|
|
114 |
def serve_text_to_image_model_controller(prompt: str):
|
115 |
# pipe = load_image_model()
|
116 |
-
# output = generate_image(pipe, prompt)
|
117 |
output = generate_image(models["text2image"], prompt)
|
118 |
-
return Response(content=img_to_bytes(output), media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
# @app.post("/upload")
|
121 |
# async def file_upload_controller(
|
@@ -139,13 +175,13 @@ def serve_text_to_image_model_controller(prompt: str):
|
|
139 |
@app.post("/upload")
|
140 |
async def file_upload_controller(
|
141 |
file: Annotated[UploadFile, File(description="A file read as UploadFile")],
|
142 |
-
bg_text_processor: BackgroundTasks,
|
143 |
):
|
144 |
-
...
|
145 |
try:
|
146 |
filepath = await save_file(file)
|
147 |
-
bg_text_processor.add_task(pdf_text_extractor, filepath)
|
148 |
-
bg_text_processor.add_task(
|
149 |
vector_service.store_file_content_in_db,
|
150 |
filepath.replace("pdf", "txt"),
|
151 |
512,
|
|
|
1 |
# main.py
|
2 |
+
from fastapi import (
|
3 |
+
FastAPI,
|
4 |
+
status,
|
5 |
+
Response,
|
6 |
+
Request,
|
7 |
+
Depends,
|
8 |
+
HTTPException,
|
9 |
+
UploadFile,
|
10 |
+
File,
|
11 |
+
BackgroundTasks,
|
12 |
+
)
|
13 |
+
from fastapi.responses import StreamingResponse, FileResponse
|
14 |
+
from models import (
|
15 |
+
load_text_model,
|
16 |
+
generate_text,
|
17 |
+
load_audio_model,
|
18 |
+
generate_audio,
|
19 |
+
load_image_model,
|
20 |
+
generate_image,
|
21 |
+
)
|
22 |
from schemas import VoicePresets
|
23 |
+
from utils import audio_array_to_buffer, img_to_bytes
|
24 |
from contextlib import asynccontextmanager
|
25 |
+
from typing import AsyncIterator, Callable, Awaitable, Annotated
|
26 |
from uuid import uuid4
|
27 |
import time
|
28 |
from datetime import datetime, timezone
|
29 |
import csv
|
30 |
+
from dependencies import get_urls_content, get_rag_content
|
31 |
+
from schemas import TextModelResponse, TextModelRequest
|
32 |
import shutil, uuid
|
33 |
from upload import save_file
|
34 |
from rag import pdf_text_extractor, vector_service
|
35 |
+
from scalar_fastapi import get_scalar_api_reference
|
36 |
|
37 |
+
models = {}
|
38 |
|
|
|
39 |
|
40 |
+
@asynccontextmanager
|
41 |
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
|
42 |
+
# models["text2image"] = load_image_model()
|
43 |
# models["text"]=load_text_model()
|
44 |
+
yield
|
45 |
+
models.clear()
|
46 |
+
|
47 |
|
48 |
app = FastAPI(lifespan=lifespan)
|
49 |
|
50 |
csv_header = [
|
51 |
+
"Request ID",
|
52 |
+
"Datetime",
|
53 |
+
"Endpoint Triggered",
|
54 |
+
"Client IP Address",
|
55 |
+
"Response Time",
|
56 |
+
"Status Code",
|
57 |
+
"Successful",
|
58 |
]
|
59 |
|
60 |
|
61 |
+
@app.middleware("http")
|
62 |
async def monitor_service(
|
63 |
req: Request, call_next: Callable[[Request], Awaitable[Response]]
|
64 |
+
) -> Response:
|
65 |
+
request_id = uuid4().hex
|
66 |
request_datetime = datetime.now(timezone.utc).isoformat()
|
67 |
start_time = time.perf_counter()
|
68 |
response: Response = await call_next(req)
|
69 |
+
response_time = round(time.perf_counter() - start_time, 4)
|
70 |
response.headers["X-Response-Time"] = str(response_time)
|
71 |
+
response.headers["X-API-Request-ID"] = request_id
|
72 |
with open("usage.csv", "a", newline="") as file:
|
73 |
writer = csv.writer(file)
|
74 |
if file.tell() == 0:
|
75 |
writer.writerow(csv_header)
|
76 |
+
writer.writerow(
|
77 |
[
|
78 |
request_id,
|
79 |
request_datetime,
|
|
|
87 |
return response
|
88 |
|
89 |
|
90 |
+
# app = FastAPI()
|
|
|
|
|
91 |
@app.get("/")
|
92 |
def root_controller():
|
93 |
return {"status": "healthy"}
|
94 |
|
95 |
+
|
96 |
+
@app.post("/generate/text")
|
97 |
+
async def serve_language_model_controller(
|
98 |
+
request: Request,
|
99 |
+
body: TextModelRequest,
|
100 |
+
urls_content: str = Depends(get_urls_content),
|
101 |
+
rag_content: str = Depends(get_rag_content),
|
102 |
+
) -> TextModelResponse:
|
103 |
+
prompt = body.prompt + " " + urls_content + rag_content
|
104 |
+
output = generate_text(models["text"], prompt, body.temperature)
|
105 |
return TextModelResponse(content=output, ip=request.client.host)
|
106 |
|
107 |
+
|
108 |
@app.get("/logs")
|
109 |
def get_logs():
|
110 |
# return FileResponse("usage.csv", media_type='text/csv', filename="usage.csv")
|
|
|
116 |
temp_file,
|
117 |
media_type="text/csv",
|
118 |
filename="logs.csv",
|
119 |
+
headers={"Content-Disposition": "attachment; filename=logs.csv"},
|
120 |
)
|
121 |
|
122 |
+
|
123 |
@app.get(
|
124 |
"/generate/audio",
|
125 |
responses={status.HTTP_200_OK: {"content": {"audio/wav": {}}}},
|
126 |
response_class=StreamingResponse,
|
127 |
+
)
|
128 |
def serve_text_to_audio_model_controller(
|
129 |
prompt: str,
|
130 |
preset: VoicePresets = "v2/en_speaker_1",
|
|
|
136 |
)
|
137 |
|
138 |
|
139 |
+
@app.get(
|
140 |
+
"/generate/image",
|
141 |
+
responses={status.HTTP_200_OK: {"content": {"image/png": {}}}},
|
142 |
+
response_class=Response,
|
143 |
+
)
|
144 |
def serve_text_to_image_model_controller(prompt: str):
|
145 |
# pipe = load_image_model()
|
146 |
+
# output = generate_image(pipe, prompt)
|
147 |
output = generate_image(models["text2image"], prompt)
|
148 |
+
return Response(content=img_to_bytes(output), media_type="image/png")
|
149 |
+
|
150 |
+
|
151 |
+
@app.get("/scalar")
|
152 |
+
def get_scalar_docs():
|
153 |
+
return get_scalar_api_reference(openapi_url=app.openapi_url, title=app.title)
|
154 |
+
|
155 |
|
156 |
# @app.post("/upload")
|
157 |
# async def file_upload_controller(
|
|
|
175 |
@app.post("/upload")
|
176 |
async def file_upload_controller(
|
177 |
file: Annotated[UploadFile, File(description="A file read as UploadFile")],
|
178 |
+
bg_text_processor: BackgroundTasks,
|
179 |
):
|
180 |
+
... # Raise an HTTPException if data upload is not a PDF file
|
181 |
try:
|
182 |
filepath = await save_file(file)
|
183 |
+
bg_text_processor.add_task(pdf_text_extractor, filepath)
|
184 |
+
bg_text_processor.add_task(
|
185 |
vector_service.store_file_content_in_db,
|
186 |
filepath.replace("pdf", "txt"),
|
187 |
512,
|
rag/repository.py
CHANGED
@@ -3,27 +3,26 @@ from qdrant_client import AsyncQdrantClient
|
|
3 |
from qdrant_client.http import models
|
4 |
from qdrant_client.http.models import ScoredPoint
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
8 |
# self.db_client = AsyncQdrantClient(host=host, port=port)
|
9 |
self.db_client = AsyncQdrantClient(
|
10 |
-
|
11 |
-
|
12 |
-
)
|
13 |
-
async def create_collection(self, collection_name: str, size: int) -> bool:
|
14 |
-
vectors_config = models.VectorParams(
|
15 |
-
size=size, distance=models.Distance.COSINE
|
16 |
)
|
|
|
|
|
|
|
17 |
response = await self.db_client.get_collections()
|
18 |
|
19 |
collection_exists = any(
|
20 |
-
collection.name == collection_name
|
21 |
-
for collection in response.collections
|
22 |
)
|
23 |
-
if collection_exists:
|
24 |
-
logger.debug(
|
25 |
-
f"Collection {collection_name} already exists - recreating it"
|
26 |
-
)
|
27 |
await self.db_client.delete_collection(collection_name)
|
28 |
return await self.db_client.create_collection(
|
29 |
collection_name,
|
@@ -73,15 +72,15 @@ class VectorRepository:
|
|
73 |
collection_name: str,
|
74 |
query_vector: list[float],
|
75 |
retrieval_limit: int,
|
76 |
-
score_threshold: float,
|
77 |
) -> list[ScoredPoint]:
|
78 |
logger.debug(
|
79 |
f"Searching for relevant items in the {collection_name} collection"
|
80 |
)
|
81 |
response = await self.db_client.query_points(
|
82 |
collection_name=collection_name,
|
83 |
-
|
84 |
limit=retrieval_limit,
|
85 |
score_threshold=score_threshold,
|
86 |
)
|
87 |
-
return response.points
|
|
|
3 |
from qdrant_client.http import models
|
4 |
from qdrant_client.http.models import ScoredPoint
|
5 |
|
6 |
+
|
7 |
+
class VectorRepository:
|
8 |
+
def __init__(
|
9 |
+
self, host: str = "https://ahmed-eisa-qdrant-db.hf.space", port: int = 6333
|
10 |
+
) -> None:
|
11 |
# self.db_client = AsyncQdrantClient(host=host, port=port)
|
12 |
self.db_client = AsyncQdrantClient(
|
13 |
+
url="https://e8342d34-1b50-48e3-95e2-d4eacd0755eb.us-east4-0.gcp.cloud.qdrant.io:6333",
|
14 |
+
api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Q6rLdYDzVyr10B4AdYJHcPp9pCqWG7yhQ-NNmfWZqg8",
|
|
|
|
|
|
|
|
|
15 |
)
|
16 |
+
|
17 |
+
async def create_collection(self, collection_name: str, size: int) -> bool:
|
18 |
+
vectors_config = models.VectorParams(size=size, distance=models.Distance.COSINE)
|
19 |
response = await self.db_client.get_collections()
|
20 |
|
21 |
collection_exists = any(
|
22 |
+
collection.name == collection_name for collection in response.collections
|
|
|
23 |
)
|
24 |
+
if collection_exists:
|
25 |
+
logger.debug(f"Collection {collection_name} already exists - recreating it")
|
|
|
|
|
26 |
await self.db_client.delete_collection(collection_name)
|
27 |
return await self.db_client.create_collection(
|
28 |
collection_name,
|
|
|
72 |
collection_name: str,
|
73 |
query_vector: list[float],
|
74 |
retrieval_limit: int,
|
75 |
+
score_threshold: float,
|
76 |
) -> list[ScoredPoint]:
|
77 |
logger.debug(
|
78 |
f"Searching for relevant items in the {collection_name} collection"
|
79 |
)
|
80 |
response = await self.db_client.query_points(
|
81 |
collection_name=collection_name,
|
82 |
+
query=query_vector,
|
83 |
limit=retrieval_limit,
|
84 |
score_threshold=score_threshold,
|
85 |
)
|
86 |
+
return response.points
|