import os
from typing import Dict, List

import chromadb
from langchain.schema import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import SentenceTransformer


class RAGAgent:
    """Retrieval-augmented QA agent over a persistent AG News ChromaDB index.

    Embeds the user query with a SentenceTransformer, retrieves the nearest
    chunks from the "ag_news" collection, and asks Gemini to answer the
    question from the retrieved context.
    """

    def __init__(self):
        # Query embedder; must match the model used when the collection was
        # indexed for the distances to be meaningful — TODO confirm against
        # the indexing script.
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.3,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )
        persist_directory = "./chroma_agnews/"
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)
        # get_collection raises if "ag_news" does not exist: the index must
        # have been built beforehand.
        self.collection = self.chroma_client.get_collection(name="ag_news")
        print(f"Connected to ChromaDB with {self.collection.count()} documents")

    def search(self, query: str, top_k: int = 5) -> Dict:
        """Search for relevant chunks and answer the question.

        Args:
            query: Natural-language question. Blank/empty input falls back
                to the generic query "news".
            top_k: Maximum number of chunks to retrieve.

        Returns:
            Dict with keys "answer" (str), "chunks" (list of dicts holding
            "text", "category", "score"), and "query" (the query actually
            used after the blank-input fallback).
        """
        # Handle empty query base case scenario
        if not query or query.strip() == "":
            query = "news"

        # Fix: an empty collection would yield n_results=0 below, which
        # ChromaDB rejects — return the "no results" shape early instead.
        doc_count = self.collection.count()
        if doc_count == 0:
            return {
                "answer": "No relevant information found for your question.",
                "chunks": [],
                "query": query
            }

        # Embed the query
        query_embedding = self.embedder.encode(query).tolist()

        # Query the collection; cap n_results so we never ask for more
        # documents than exist.
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, doc_count),
            include=["documents", "metadatas", "distances"]
        )

        # Format results
        formatted_results = []
        context_chunks = []
        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                # Map distance in [0, 2] onto a [0, 1] similarity score —
                # assumes the collection uses cosine space (TODO confirm
                # against the indexer's metadata).
                distance = results['distances'][0][i] if results['distances'] else 0
                similarity_score = 1 - (distance / 2)
                doc_text = results['documents'][0][i]
                formatted_results.append({
                    'text': doc_text,
                    'category': results['metadatas'][0][i].get('label_text', 'Unknown'),
                    'score': similarity_score
                })
                context_chunks.append(doc_text)

            # Generate answer based on retrieved chunks
            answer = self._generate_answer(query, context_chunks)
            return {
                "answer": answer,
                "chunks": formatted_results,
                "query": query
            }
        else:
            return {
                "answer": "No relevant information found for your question.",
                "chunks": [],
                "query": query
            }

    def _generate_answer(self, query: str, chunks: List[str]) -> str:
        """Generate answer based on retrieved chunks."""
        # Combine chunks as context, numbered so the model can refer to them
        context = "\n\n".join([f"[{i+1}] {chunk}" for i, chunk in enumerate(chunks)])

        # Create prompt
        prompt = f"""Based on the following information, answer the question.

Context:
{context}

Question: {query}

Answer:"""

        # Generate answer using Gemini
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return response.content

    def get_collection_stats(self) -> Dict:
        """Get statistics about the collection.

        Returns:
            Dict with "total_documents" (int) and "categories" — a
            category -> count histogram computed over a sample of at most
            100 documents, not the full collection.
        """
        count = self.collection.count()
        if count > 0:
            sample = self.collection.get(
                limit=min(100, count),
                include=["metadatas"]
            )
            categories = {}
            for metadata in sample['metadatas']:
                cat = metadata.get('label_text', 'Unknown')
                categories[cat] = categories.get(cat, 0) + 1
            return {
                "total_documents": count,
                "categories": categories
            }
        else:
            return {
                "total_documents": 0,
                "categories": {}
            }