ahmed-eisa committed on
Commit d323684 · 1 Parent(s): 7201688

fixed Qdrant search query

Files changed (2)
  1. main.py +76 -40
  2. rag/repository.py +16 -17
main.py CHANGED
@@ -1,55 +1,79 @@
  # main.py
- from fastapi import FastAPI,status,Response,Request,Depends,HTTPException,UploadFile, File,BackgroundTasks
- from fastapi.responses import StreamingResponse,FileResponse
- from models import load_text_model,generate_text,load_audio_model,generate_audio,load_image_model, generate_image
+ from fastapi import (
+     FastAPI,
+     status,
+     Response,
+     Request,
+     Depends,
+     HTTPException,
+     UploadFile,
+     File,
+     BackgroundTasks,
+ )
+ from fastapi.responses import StreamingResponse, FileResponse
+ from models import (
+     load_text_model,
+     generate_text,
+     load_audio_model,
+     generate_audio,
+     load_image_model,
+     generate_image,
+ )
  from schemas import VoicePresets
- from utils import audio_array_to_buffer,img_to_bytes
+ from utils import audio_array_to_buffer, img_to_bytes
  from contextlib import asynccontextmanager
- from typing import AsyncIterator,Callable,Awaitable,Annotated
+ from typing import AsyncIterator, Callable, Awaitable, Annotated
  from uuid import uuid4
  import time
  from datetime import datetime, timezone
  import csv
- from dependencies import get_urls_content,get_rag_content
- from schemas import TextModelResponse,TextModelRequest
+ from dependencies import get_urls_content, get_rag_content
+ from schemas import TextModelResponse, TextModelRequest
  import shutil, uuid
  from upload import save_file
  from rag import pdf_text_extractor, vector_service
+ from scalar_fastapi import get_scalar_api_reference

-
- models = {}
-
- @asynccontextmanager
+ models = {}
+
+
+ @asynccontextmanager
  async def lifespan(_: FastAPI) -> AsyncIterator[None]:
-     # models["text2image"] = load_image_model()
+     # models["text2image"] = load_image_model()
      # models["text"]=load_text_model()
-     yield
-     models.clear()
+     yield
+     models.clear()
+

  app = FastAPI(lifespan=lifespan)

  csv_header = [
-     "Request ID", "Datetime", "Endpoint Triggered", "Client IP Address",
-     "Response Time", "Status Code", "Successful"
+     "Request ID",
+     "Datetime",
+     "Endpoint Triggered",
+     "Client IP Address",
+     "Response Time",
+     "Status Code",
+     "Successful",
  ]


- @app.middleware("http")
+ @app.middleware("http")
  async def monitor_service(
      req: Request, call_next: Callable[[Request], Awaitable[Response]]
- ) -> Response:
-     request_id = uuid4().hex
+ ) -> Response:
+     request_id = uuid4().hex
      request_datetime = datetime.now(timezone.utc).isoformat()
      start_time = time.perf_counter()
      response: Response = await call_next(req)
-     response_time = round(time.perf_counter() - start_time, 4)
+     response_time = round(time.perf_counter() - start_time, 4)
      response.headers["X-Response-Time"] = str(response_time)
-     response.headers["X-API-Request-ID"] = request_id
+     response.headers["X-API-Request-ID"] = request_id
      with open("usage.csv", "a", newline="") as file:
          writer = csv.writer(file)
          if file.tell() == 0:
              writer.writerow(csv_header)
-         writer.writerow(
+         writer.writerow(
              [
                  request_id,
                  request_datetime,
@@ -63,21 +87,24 @@ async def monitor_service(
      return response


-
-
- # app = FastAPI()
+ # app = FastAPI()
  @app.get("/")
  def root_controller():
      return {"status": "healthy"}

- @app.post("/generate/text")
- async def serve_language_model_controller(request: Request,
-     body: TextModelRequest ,
-     urls_content: str = Depends(get_urls_content), rag_content: str = Depends(get_rag_content)) -> TextModelResponse:
-     prompt = body.prompt + " " + urls_content +rag_content
-     output = generate_text(models["text"], prompt, body.temperature)
+
+ @app.post("/generate/text")
+ async def serve_language_model_controller(
+     request: Request,
+     body: TextModelRequest,
+     urls_content: str = Depends(get_urls_content),
+     rag_content: str = Depends(get_rag_content),
+ ) -> TextModelResponse:
+     prompt = body.prompt + " " + urls_content + rag_content
+     output = generate_text(models["text"], prompt, body.temperature)
      return TextModelResponse(content=output, ip=request.client.host)

+
  @app.get("/logs")
  def get_logs():
      # return FileResponse("usage.csv", media_type='text/csv', filename="usage.csv")
@@ -89,14 +116,15 @@ def get_logs():
          temp_file,
          media_type="text/csv",
          filename="logs.csv",
-         headers={"Content-Disposition": "attachment; filename=logs.csv"}
+         headers={"Content-Disposition": "attachment; filename=logs.csv"},
      )

+
  @app.get(
      "/generate/audio",
      responses={status.HTTP_200_OK: {"content": {"audio/wav": {}}}},
      response_class=StreamingResponse,
- )
+ )
  def serve_text_to_audio_model_controller(
      prompt: str,
      preset: VoicePresets = "v2/en_speaker_1",
@@ -108,14 +136,22 @@ def serve_text_to_audio_model_controller(
      )


- @app.get("/generate/image",
-     responses={status.HTTP_200_OK: {"content": {"image/png": {}}}},
-     response_class=Response)
+ @app.get(
+     "/generate/image",
+     responses={status.HTTP_200_OK: {"content": {"image/png": {}}}},
+     response_class=Response,
+ )
  def serve_text_to_image_model_controller(prompt: str):
      # pipe = load_image_model()
-     # output = generate_image(pipe, prompt)
+     # output = generate_image(pipe, prompt)
      output = generate_image(models["text2image"], prompt)
-     return Response(content=img_to_bytes(output), media_type="image/png")
+     return Response(content=img_to_bytes(output), media_type="image/png")
+
+
+ @app.get("/scalar")
+ def get_scalar_docs():
+     return get_scalar_api_reference(openapi_url=app.openapi_url, title=app.title)
+

  # @app.post("/upload")
  # async def file_upload_controller(
@@ -139,13 +175,13 @@ def serve_text_to_image_model_controller(prompt: str):
  @app.post("/upload")
  async def file_upload_controller(
      file: Annotated[UploadFile, File(description="A file read as UploadFile")],
-     bg_text_processor: BackgroundTasks,
+     bg_text_processor: BackgroundTasks,
  ):
-     ... # Raise an HTTPException if data upload is not a PDF file
+     ... # Raise an HTTPException if data upload is not a PDF file
      try:
          filepath = await save_file(file)
-         bg_text_processor.add_task(pdf_text_extractor, filepath)
-         bg_text_processor.add_task(
+         bg_text_processor.add_task(pdf_text_extractor, filepath)
+         bg_text_processor.add_task(
              vector_service.store_file_content_in_db,
              filepath.replace("pdf", "txt"),
              512,
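
Besides reformatting the imports and decorators, the main.py side of this commit adds a /scalar route that serves Scalar API reference docs via scalar_fastapi. A quick way to exercise the new route and the monitoring middleware is a TestClient smoke test. The sketch below is not part of the commit; the file name is hypothetical and it assumes the repo's local modules (models, schemas, utils, dependencies, upload, rag) import cleanly and that scalar-fastapi is installed.

# smoke_test.py - hedged sketch, not part of this commit.
# Assumes `main.py` from this repo and its local modules are importable.
from fastapi.testclient import TestClient

from main import app

# Using the client as a context manager also runs the lifespan handler.
with TestClient(app) as client:
    # The http middleware should stamp timing and request-id headers on every response.
    resp = client.get("/")
    assert resp.json() == {"status": "healthy"}
    assert "X-Response-Time" in resp.headers
    assert "X-API-Request-ID" in resp.headers

    # The new /scalar route renders Scalar docs over the app's OpenAPI schema.
    assert client.get("/scalar").status_code == 200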
rag/repository.py CHANGED
@@ -3,27 +3,26 @@ from qdrant_client import AsyncQdrantClient
  from qdrant_client.http import models
  from qdrant_client.http.models import ScoredPoint

- class VectorRepository:
-     def __init__(self, host: str = "https://ahmed-eisa-qdrant-db.hf.space", port: int = 6333) -> None:
+
+ class VectorRepository:
+     def __init__(
+         self, host: str = "https://ahmed-eisa-qdrant-db.hf.space", port: int = 6333
+     ) -> None:
          # self.db_client = AsyncQdrantClient(host=host, port=port)
          self.db_client = AsyncQdrantClient(
-             url="https://e8342d34-1b50-48e3-95e2-d4eacd0755eb.us-east4-0.gcp.cloud.qdrant.io:6333",
-             api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Q6rLdYDzVyr10B4AdYJHcPp9pCqWG7yhQ-NNmfWZqg8",
-         )
-     async def create_collection(self, collection_name: str, size: int) -> bool:
-         vectors_config = models.VectorParams(
-             size=size, distance=models.Distance.COSINE
+             url="https://e8342d34-1b50-48e3-95e2-d4eacd0755eb.us-east4-0.gcp.cloud.qdrant.io:6333",
+             api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Q6rLdYDzVyr10B4AdYJHcPp9pCqWG7yhQ-NNmfWZqg8",
          )
+
+     async def create_collection(self, collection_name: str, size: int) -> bool:
+         vectors_config = models.VectorParams(size=size, distance=models.Distance.COSINE)
          response = await self.db_client.get_collections()

          collection_exists = any(
-             collection.name == collection_name
-             for collection in response.collections
+             collection.name == collection_name for collection in response.collections
          )
-         if collection_exists:
-             logger.debug(
-                 f"Collection {collection_name} already exists - recreating it"
-             )
+         if collection_exists:
+             logger.debug(f"Collection {collection_name} already exists - recreating it")
              await self.db_client.delete_collection(collection_name)
          return await self.db_client.create_collection(
              collection_name,
@@ -73,15 +72,15 @@ class VectorRepository:
          collection_name: str,
          query_vector: list[float],
          retrieval_limit: int,
-         score_threshold: float,
+         score_threshold: float,
      ) -> list[ScoredPoint]:
          logger.debug(
              f"Searching for relevant items in the {collection_name} collection"
          )
          response = await self.db_client.query_points(
              collection_name=collection_name,
-             vector=query_vector,
+             query=query_vector,
              limit=retrieval_limit,
              score_threshold=score_threshold,
          )
-         return response.points
+         return response.points
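
The functional fix named in the commit message sits in VectorRepository.search: qdrant-client's query_points method takes the query vector through the query parameter rather than vector, and the result's points attribute is what the method returns. Below is a minimal, self-contained sketch of the corrected call; it is not part of the commit, the file and collection names are illustrative, and it uses a local in-memory Qdrant instance (qdrant-client 1.10 or newer assumed) instead of the hosted cluster configured in the repository.

# query_points_sketch.py - hedged illustration, not part of this commit.
# Uses an in-memory Qdrant instance; the repo points at a hosted cluster instead.
import asyncio

from qdrant_client import AsyncQdrantClient, models


async def main() -> None:
    client = AsyncQdrantClient(":memory:")

    # Mirror what VectorRepository.create_collection sets up: cosine-distance vectors.
    await client.create_collection(
        collection_name="knowledgebase",
        vectors_config=models.VectorParams(size=3, distance=models.Distance.COSINE),
    )
    await client.upsert(
        collection_name="knowledgebase",
        points=[
            models.PointStruct(id=1, vector=[0.1, 0.2, 0.3], payload={"text": "hello"})
        ],
    )

    # The fix: query_points expects the vector under `query=`, not `vector=`.
    response = await client.query_points(
        collection_name="knowledgebase",
        query=[0.1, 0.2, 0.3],
        limit=5,
        score_threshold=0.1,
    )
    print(response.points)  # list[ScoredPoint], which VectorRepository.search now returns


asyncio.run(main())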