from google import genai
from google.genai import types
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
from dotenv import load_dotenv

load_dotenv()
client = genai.Client(api_key=os.getenv("api_key"))


class RAG:
    def __init__(self):
        self.CHUNK_SIZE = 1024      # characters per chunk
        self.CHUNK_OVERLAP = 75     # characters shared between adjacent chunks
        self.MAX_BATCH_SIZE = 100   # texts sent per embed_content call
        self.MODEL = "text-embedding-004"
        self.TASK_TYPE = "SEMANTIC_SIMILARITY"

    def split_text(self, text):
        # Split on progressively finer separators so chunks break at natural
        # boundaries (paragraphs, lines, sentences) whenever possible.
        try:
            return RecursiveCharacterTextSplitter(
                chunk_size=self.CHUNK_SIZE,
                chunk_overlap=self.CHUNK_OVERLAP,
                separators=["\n\n", "\n", ".", "!", "?", "。", " ", ""],
            ).split_text(text)
        except Exception as e:
            raise ValueError(f"an error occurred while splitting text: {e}")

    def _embed_batch(self, chunk_batch, task_type):
        # One API call embeds the whole batch; each chunk gets its own vector.
        response = client.models.embed_content(
            model=self.MODEL,
            contents=chunk_batch,
            config=types.EmbedContentConfig(task_type=task_type),
        )
        return [embedding.values for embedding in response.embeddings]

    def generate_embedding(self, text, task_type=None):
        try:
            if not task_type:
                task_type = self.TASK_TYPE
            chunks = self.split_text(text)
            batches = [
                chunks[i:i + self.MAX_BATCH_SIZE]
                for i in range(0, len(chunks), self.MAX_BATCH_SIZE)
            ]
            # Embed batches concurrently, keying each future by its batch
            # index: as_completed yields futures in completion order, not
            # submission order, so results must be reassembled explicitly
            # to keep embeddings[i] aligned with chunks[i].
            results = [None] * len(batches)
            with ThreadPoolExecutor(max_workers=50) as executor:
                futures = {
                    executor.submit(self._embed_batch, batch, task_type): i
                    for i, batch in enumerate(batches)
                }
                for future in as_completed(futures):
                    results[futures[future]] = future.result()
            embeddings = [vector for batch in results for vector in batch]
            return {"embeddings": embeddings, "chunks": chunks}, 200
        except Exception as e:
            return {"error": f"an error occurred: {e}"}, 500


rag = RAG()
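

# --- Usage sketch (illustrative, not part of the class) ---
# A minimal example of how this module might be driven. It assumes a .env
# file providing an `api_key` entry for the Gemini API and network access;
# the sample text and the cosine_similarity helper are hypothetical, added
# only to show that the returned vectors line up index-for-index with
# the returned chunks.
if __name__ == "__main__":
    import numpy as np

    def cosine_similarity(a, b):
        # Standard cosine similarity between two embedding vectors.
        a, b = np.asarray(a), np.asarray(b)
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    sample = (
        "Retrieval-augmented generation pairs a language model with a "
        "search step over your own documents.\n\n"
        "Embeddings map each chunk of text to a vector so that semantically "
        "similar chunks end up close together."
    )
    body, status = rag.generate_embedding(sample)
    if status == 200:
        print(f"{len(body['chunks'])} chunk(s) embedded, "
              f"dimension {len(body['embeddings'][0])}")
        # embeddings[i] corresponds to chunks[i], so pairwise comparisons
        # can be made directly on the returned lists.
        if len(body["embeddings"]) > 1:
            print("similarity(chunk 0, chunk 1) =",
                  cosine_similarity(body["embeddings"][0],
                                    body["embeddings"][1]))
    else:
        print("embedding failed:", body)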