from google import genai
from google.genai import types
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
from dotenv import load_dotenv

# Load the Gemini API key from a .env file (expects an `api_key` entry).
load_dotenv()
client = genai.Client(api_key=os.getenv("api_key"))

class RAG:
    def __init__(self):
        self.CHUNK_SIZE = 1024      # characters per chunk
        self.CHUNK_OVERLAP = 75     # characters shared between consecutive chunks
        self.MAX_BATCH_SIZE = 100   # maximum texts sent per embed_content request
        self.MODEL = "text-embedding-004"
        self.TASK_TYPE = "SEMANTIC_SIMILARITY"
    def split_text(self, text):
        # Split the input into overlapping chunks, preferring paragraph and
        # sentence boundaries before falling back to word/character splits.
        try:
            return RecursiveCharacterTextSplitter(
                chunk_size=self.CHUNK_SIZE,
                chunk_overlap=self.CHUNK_OVERLAP,
                separators=["\n\n", "\n", ".", "!", "?", "。", " ", ""]
            ).split_text(text)
        except Exception as e:
            raise ValueError(f"an error occurred while splitting text: {e}")
    def _embed_batch(self, chunk_batch, task_type):
        # Embed one batch of chunks in a single API call.
        response = client.models.embed_content(
            model=self.MODEL,
            contents=chunk_batch,
            config=types.EmbedContentConfig(task_type=task_type)
        )
        return [embedding.values for embedding in response.embeddings]
    def generate_embedding(self, text, task_type=None):
        try:
            if not task_type:
                task_type = self.TASK_TYPE
            chunks = self.split_text(text)
            batches = [
                chunks[i:i + self.MAX_BATCH_SIZE]
                for i in range(0, len(chunks), self.MAX_BATCH_SIZE)
            ]
            # Embed batches concurrently, recording each batch's index so the
            # results can be reassembled in chunk order (as_completed yields
            # futures in completion order, not submission order).
            results = [None] * len(batches)
            with ThreadPoolExecutor(max_workers=50) as executor:
                futures = {
                    executor.submit(self._embed_batch, batch, task_type): i
                    for i, batch in enumerate(batches)
                }
                for future in as_completed(futures):
                    results[futures[future]] = future.result()
            embeddings = [vector for batch in results for vector in batch]
            return {"embeddings": embeddings, "chunks": chunks}, 200
        except Exception as e:
            return {"error": f"{e}"}, 500
rag = RAG()
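

# --- Minimal usage sketch (an illustration, not part of the original class) ---
# Shows one way to retrieve the most relevant chunks for a query: embed the
# document and the query with the same model/task type, then rank chunks by
# cosine similarity. `sample_text` and `query` are hypothetical placeholders.
if __name__ == "__main__":
    import numpy as np

    sample_text = (
        "Retrieval-augmented generation pairs a retriever with a language model. " * 50
    )
    result, status = rag.generate_embedding(sample_text)
    if status == 200:
        vectors = np.array(result["embeddings"])
        chunks = result["chunks"]

        query = "How does retrieval-augmented generation work?"
        query_result, query_status = rag.generate_embedding(query)
        if query_status == 200:
            # A short query produces a single chunk, so take its first vector.
            query_vector = np.array(query_result["embeddings"][0])

            # Cosine similarity: dot product divided by the vectors' L2 norms.
            scores = vectors @ query_vector / (
                np.linalg.norm(vectors, axis=1) * np.linalg.norm(query_vector)
            )
            for i in np.argsort(scores)[::-1][:3]:
                print(f"{scores[i]:.3f}  {chunks[i][:80]}")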