import os
from typing import Dict, List

import chromadb
from langchain.schema import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import SentenceTransformer
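
# The agent below expects an existing Chroma collection on disk. A minimal
# ingestion sketch (hypothetical; the original build script is not part of
# this file) of how a compatible collection could be produced:
def build_demo_collection(texts: List[str], labels: List[str]) -> None:
    """Hypothetical helper: embed texts and store them under ./chroma_agnews/."""
    embedder = SentenceTransformer('all-mpnet-base-v2')
    client = chromadb.PersistentClient(path="./chroma_agnews/")
    # "hnsw:space": "cosine" is an assumption, chosen to match the
    # 1 - distance / 2 similarity conversion in RAGAgent.search().
    collection = client.get_or_create_collection(
        name="ag_news", metadata={"hnsw:space": "cosine"}
    )
    collection.add(
        ids=[str(i) for i in range(len(texts))],
        embeddings=[embedder.encode(t).tolist() for t in texts],
        documents=texts,
        metadatas=[{"label_text": label} for label in labels],
    )
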
class RAGAgent:
    def __init__(self):
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.3,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )
        persist_directory = "./chroma_agnews/"
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self.chroma_client.get_collection(name="ag_news")
        print(f"Connected to ChromaDB with {self.collection.count()} documents")

    def search(self, query: str, top_k: int = 5) -> Dict:
        """Search for relevant chunks and answer the question."""
        # Fall back to a generic query if the input is empty.
        if not query or not query.strip():
            query = "news"

        # Embed the query with the same model used at ingestion time.
        query_embedding = self.embedder.encode(query).tolist()

        # Query the collection; cap n_results at the collection size.
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, self.collection.count()),
            include=["documents", "metadatas", "distances"]
        )

        # Format results.
        formatted_results = []
        context_chunks = []
        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                # Map distance to a similarity in [0, 1]
                # (assumes cosine distance, which ranges over [0, 2]).
                distance = results['distances'][0][i] if results['distances'] else 0
                similarity_score = 1 - (distance / 2)
                doc_text = results['documents'][0][i]
                formatted_results.append({
                    'text': doc_text,
                    'category': results['metadatas'][0][i].get('label_text', 'Unknown'),
                    'score': similarity_score
                })
                context_chunks.append(doc_text)

            # Generate an answer grounded in the retrieved chunks.
            answer = self._generate_answer(query, context_chunks)
            return {
                "answer": answer,
                "chunks": formatted_results,
                "query": query
            }
        else:
            return {
                "answer": "No relevant information found for your question.",
                "chunks": [],
                "query": query
            }

    def _generate_answer(self, query: str, chunks: List[str]) -> str:
        """Generate answer based on retrieved chunks."""
        # Combine chunks into a numbered context block.
        context = "\n\n".join([f"[{i+1}] {chunk}" for i, chunk in enumerate(chunks)])

        # Create prompt
        prompt = f"""Based on the following information, answer the question.

Context:
{context}

Question: {query}

Answer:"""

        # Generate answer using Gemini
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return response.content

    def get_collection_stats(self) -> Dict:
        """Get statistics about the collection."""
        count = self.collection.count()
        if count > 0:
            # Sample up to 100 documents to estimate the category distribution.
            sample = self.collection.get(
                limit=min(100, count),
                include=["metadatas"]
            )
            categories = {}
            for metadata in sample['metadatas']:
                cat = metadata.get('label_text', 'Unknown')
                categories[cat] = categories.get(cat, 0) + 1
            return {
                "total_documents": count,
                "categories": categories
            }
        else:
            return {
                "total_documents": 0,
                "categories": {}
            }
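
# Usage sketch (added; not part of the original app). Assumes GOOGLE_API_KEY
# is set and ./chroma_agnews/ already holds the "ag_news" collection.
if __name__ == "__main__":
    agent = RAGAgent()
    print(agent.get_collection_stats())
    result = agent.search("What happened in the tech industry?", top_k=3)
    print("Answer:", result["answer"])
    for chunk in result["chunks"]:
        print(f"[{chunk['category']}] score={chunk['score']:.3f}")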