ahmed-eisa committed on
Commit
0040dff
·
1 Parent(s): 43e97e3

added Qdrant inside the container

Browse files
Files changed (7) hide show
  1. Dockerfile +7 -1
  2. dependencies.py +11 -1
  3. main.py +35 -8
  4. rag/__init__.py +4 -0
  5. rag/repository.py +84 -0
  6. rag/service.py +31 -0
  7. rag/transform.py +2 -2
Dockerfile CHANGED
@@ -4,12 +4,18 @@ RUN useradd -m -u 1000 user
4
  USER user
5
  ENV PATH="/home/user/.local/bin:$PATH"
6
 
 
 
 
 
7
  WORKDIR /app
8
 
9
  COPY --chown=user * ./
10
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
 
12
  COPY --chown=user . /app
13
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
14
 
15
 
 
4
USER user
ENV PATH="/home/user/.local/bin:$PATH"

# Install the Qdrant client library and the Qdrant server binary.
# NOTE(review): the previous `apt-get update && apt-get install -y qdrant`
# could not work — qdrant is not an apt package, and apt-get cannot run as
# the non-root `user` this stage already switched to. Fetch the official
# release binary into ~/.local/bin (already on PATH) instead.
# NOTE(review): assumes curl and tar exist in the base image and that the
# release tarball contains a top-level `qdrant` binary — confirm for the
# chosen base image / qdrant version.
RUN pip install --no-cache-dir qdrant-client && \
    curl -fsSL https://github.com/qdrant/qdrant/releases/download/v1.9.2/qdrant-x86_64-unknown-linux-gnu.tar.gz \
    | tar -xz -C /home/user/.local/bin qdrant

WORKDIR /app

COPY --chown=user * ./
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app

# Start the Qdrant server in the background, then run the API server in the
# foreground so it stays PID 1's child and keeps the container alive.
# CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
CMD ["sh", "-c", "qdrant & uvicorn main:app --host 0.0.0.0 --port 7860"]
21
 
dependencies.py CHANGED
@@ -12,4 +12,14 @@ async def get_urls_content(body: TextModelRequest ) -> str:
12
  return urls_content
13
  except Exception as e:
14
  logger.warning(f"Failed to fetch one or several URls - Error: {e}")
15
- return ""
 
 
 
 
 
 
 
 
 
 
 
12
  return urls_content
13
  except Exception as e:
14
  logger.warning(f"Failed to fetch one or several URls - Error: {e}")
15
+ return ""
16
+
17
async def get_rag_content(body: TextModelRequest) -> str:
    """Fetch RAG context for *body.prompt* from the vector store.

    Embeds the prompt, queries the knowledgebase collection, and returns
    the matched chunks' original text joined by newlines (an empty string
    when nothing clears the score threshold).

    NOTE(review): this function requires `vector_service` (rag.service)
    and `embed` (rag.transform) to be imported in dependencies.py — the
    diff does not show those imports being added; verify they exist.
    """
    # Retrieval knobs named instead of unexplained inline literals.
    collection_name = "knowledgebase"
    retrieval_limit = 3
    score_threshold = 0.7
    rag_content = await vector_service.search(
        collection_name, embed(body.prompt), retrieval_limit, score_threshold
    )
    return "\n".join(c.payload["original_text"] for c in rag_content)
main.py CHANGED
@@ -1,5 +1,5 @@
1
  # main.py
2
- from fastapi import FastAPI,status,Response,Request,Depends,HTTPException,UploadFile, File
3
  from fastapi.responses import StreamingResponse,FileResponse
4
  from models import load_text_model,generate_text,load_audio_model,generate_audio,load_image_model, generate_image
5
  from schemas import VoicePresets
@@ -14,6 +14,8 @@ from dependencies import get_urls_content
14
  from schemas import TextModelResponse,TextModelRequest
15
  import shutil, uuid
16
  from upload import save_file
 
 
17
 
18
  models = {}
19
 
@@ -115,17 +117,42 @@ def serve_text_to_image_model_controller(prompt: str):
115
  output = generate_image(models["text2image"], prompt)
116
  return Response(content=img_to_bytes(output), media_type="image/png")
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  @app.post("/upload")
119
  async def file_upload_controller(
120
- file: Annotated[UploadFile, File(description="Uploaded PDF documents")]
 
121
  ):
122
- if file.content_type != "application/pdf":
123
- raise HTTPException(
124
- detail=f"Only uploading PDF documents are supported",
125
- status_code=status.HTTP_400_BAD_REQUEST,
126
- )
127
  try:
128
- await save_file(file)
 
 
 
 
 
 
 
 
 
129
  except Exception as e:
130
  raise HTTPException(
131
  detail=f"An error occurred while saving file - Error: {e}",
 
1
  # main.py
2
+ from fastapi import FastAPI,status,Response,Request,Depends,HTTPException,UploadFile, File,BackgroundTasks
3
  from fastapi.responses import StreamingResponse,FileResponse
4
  from models import load_text_model,generate_text,load_audio_model,generate_audio,load_image_model, generate_image
5
  from schemas import VoicePresets
 
14
  from schemas import TextModelResponse,TextModelRequest
15
  import shutil, uuid
16
  from upload import save_file
17
+ from rag import pdf_text_extractor, vector_service
18
+
19
 
20
  models = {}
21
 
 
117
  output = generate_image(models["text2image"], prompt)
118
  return Response(content=img_to_bytes(output), media_type="image/png")
119
 
120
+ # @app.post("/upload")
121
+ # async def file_upload_controller(
122
+ # file: Annotated[UploadFile, File(description="Uploaded PDF documents")]
123
+ # ):
124
+ # if file.content_type != "application/pdf":
125
+ # raise HTTPException(
126
+ # detail=f"Only uploading PDF documents are supported",
127
+ # status_code=status.HTTP_400_BAD_REQUEST,
128
+ # )
129
+ # try:
130
+ # await save_file(file)
131
+ # except Exception as e:
132
+ # raise HTTPException(
133
+ # detail=f"An error occurred while saving file - Error: {e}",
134
+ # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
135
+ # )
136
+ # return {"filename": file.filename, "message": "File uploaded successfully"}
137
+
138
+
139
@app.post("/upload")
async def file_upload_controller(
    file: Annotated[UploadFile, File(description="A file read as UploadFile")],
    bg_text_processor: BackgroundTasks,
):
    """Accept a PDF upload and index its text into the vector store.

    Saves the file, then schedules two background tasks: extract the PDF
    text to a .txt sibling, and embed/store that text in the
    "knowledgebase" collection (512-char chunks, 768-dim vectors).

    Raises 400 for non-PDF uploads and 500 when saving/scheduling fails.
    """
    # Restore the content-type validation the `...` placeholder dropped
    # (check taken from the commented-out version of this controller).
    if file.content_type != "application/pdf":
        raise HTTPException(
            detail="Only uploading PDF documents are supported",
            status_code=status.HTTP_400_BAD_REQUEST,
        )
    try:
        filepath = await save_file(file)
        bg_text_processor.add_task(pdf_text_extractor, filepath)
        bg_text_processor.add_task(
            vector_service.store_file_content_in_db,
            # Replace only the extension. The previous replace("pdf", "txt")
            # rewrote every "pdf" substring in the path (e.g. a "pdfs/"
            # directory). Assumes save_file returns a path ending in ".pdf"
            # — TODO confirm.
            filepath.replace(".pdf", ".txt"),
            512,
            "knowledgebase",
            768,
        )
    except Exception as e:
        raise HTTPException(
            detail=f"An error occurred while saving file - Error: {e}",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
    return {"filename": file.filename, "message": "File uploaded successfully"}
rag/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .extractor import *
2
+ from .repository import *
3
+ from .service import *
4
+ from .transform import *
rag/repository.py CHANGED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from loguru import logger
2
+ from qdrant_client import AsyncQdrantClient
3
+ from qdrant_client.http import models
4
+ from qdrant_client.http.models import ScoredPoint
5
+
6
class VectorRepository:
    """Async create/delete/insert/search wrapper around a Qdrant database."""

    def __init__(self, host: str = "https://ahmed-eisa-qdrant-db.hf.space", port: int = 6333) -> None:
        # NOTE(review): qdrant-client expects a bare hostname in `host`; a
        # value carrying a scheme (as the default here does) must be passed
        # as `url` instead, or the client mis-parses the address.
        if "://" in host:
            self.db_client = AsyncQdrantClient(url=host, port=port)
        else:
            self.db_client = AsyncQdrantClient(host=host, port=port)

    async def create_collection(self, collection_name: str, size: int) -> bool:
        """Create (recreating if it already exists) a cosine-distance
        collection named *collection_name* for vectors of *size* dimensions.

        Returns True when the collection was created successfully.
        """
        vectors_config = models.VectorParams(
            size=size, distance=models.Distance.COSINE
        )
        response = await self.db_client.get_collections()
        collection_exists = any(
            collection.name == collection_name
            for collection in response.collections
        )
        if collection_exists:
            logger.debug(
                f"Collection {collection_name} already exists - recreating it"
            )
            await self.db_client.delete_collection(collection_name)
        else:
            logger.debug(f"Creating collection {collection_name}")
        # Single creation path — the original duplicated this call (with a
        # second, freshly built VectorParams) in both branches.
        return await self.db_client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
        )

    async def delete_collection(self, name: str) -> bool:
        """Drop collection *name*; returns True on success."""
        logger.debug(f"Deleting collection {name}")
        return await self.db_client.delete_collection(name)

    async def create(
        self,
        collection_name: str,
        embedding_vector: list[float],
        original_text: str,
        source: str,
    ) -> None:
        """Insert one point carrying the embedding plus its original text
        and source filename as payload.

        NOTE(review): using the current point count as the new point ID
        collides if points are ever deleted individually, and races under
        concurrent inserts; a UUID would be safer. Kept count-based here to
        preserve existing behavior.
        """
        response = await self.db_client.count(collection_name=collection_name)
        logger.debug(
            f"Creating a new vector with ID {response.count} "
            f"inside the {collection_name}"
        )
        await self.db_client.upsert(
            collection_name=collection_name,
            points=[
                models.PointStruct(
                    id=response.count,
                    vector=embedding_vector,
                    payload={
                        "source": source,
                        "original_text": original_text,
                    },
                )
            ],
        )

    async def search(
        self,
        collection_name: str,
        query_vector: list[float],
        retrieval_limit: int,
        score_threshold: float,
    ) -> list[ScoredPoint]:
        """Return up to *retrieval_limit* points whose similarity score is
        at least *score_threshold*."""
        logger.debug(
            f"Searching for relevant items in the {collection_name} collection"
        )
        # Fix: AsyncQdrantClient.query_points takes `query=`, not
        # `query_vector=` — the old kwarg belongs to the deprecated
        # `search()` API and raises TypeError here.
        response = await self.db_client.query_points(
            collection_name=collection_name,
            query=query_vector,
            limit=retrieval_limit,
            score_threshold=score_threshold,
        )
        return response.points
rag/service.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from loguru import logger
4
+ from .repository import VectorRepository
5
+ from .transform import clean, embed, load
6
+
7
+
8
class VectorService(VectorRepository):
    """High-level service that turns a text file into stored embeddings."""

    def __init__(self):
        super().__init__()

    async def store_file_content_in_db(
        self,
        filepath: str,
        chunk_size: int = 512,
        collection_name: str = "knowledgebase",
        collection_size: int = 768,
    ) -> None:
        """Chunk the text file at *filepath*, embed each chunk, and store it.

        The target collection is (re)created first, so calling this again
        for the same collection replaces previously stored vectors.
        """
        await self.create_collection(collection_name, collection_size)
        logger.debug(f"Inserting {filepath} content into database")
        # Hoisted out of the loop: the payload "source" is the same for
        # every chunk of the file.
        filename = os.path.basename(filepath)
        async for chunk in load(filepath, chunk_size):
            logger.debug(f"Inserting '{chunk[0:20]}...' into database")

            embedding_vector = embed(clean(chunk))
            await self.create(
                collection_name, embedding_vector, chunk, filename
            )


# Module-level singleton shared by the FastAPI app and its dependencies.
vector_service = VectorService()
rag/transform.py CHANGED
@@ -10,9 +10,9 @@ embedder = AutoModel.from_pretrained(
10
  "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
11
  )
12
 
13
- async def load(filepath: str) -> AsyncGenerator[str, Any]:
14
  async with aiofiles.open(filepath, "r", encoding="utf-8") as f:
15
- while chunk := await f.read(DEFAULT_CHUNK_SIZE):
16
  yield chunk
17
 
18
  def clean(text: str) -> str:
 
10
  "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
11
  )
12
 
13
async def load(filepath: str, chunksize: int = DEFAULT_CHUNK_SIZE) -> AsyncGenerator[str, Any]:
    """Yield successive chunks of at most *chunksize* characters from the
    UTF-8 text file at *filepath*."""
    async with aiofiles.open(filepath, "r", encoding="utf-8") as handle:
        while True:
            chunk = await handle.read(chunksize)
            if not chunk:
                # EOF: read() returned an empty string.
                break
            yield chunk
17
 
18
  def clean(text: str) -> str: