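"""RAG agent over an AG News ChromaDB index.

Embeds queries with all-mpnet-base-v2, retrieves the nearest chunks from a
persisted Chroma collection, and asks Gemini (via LangChain) to answer from
the retrieved context.
"""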
import os
from typing import List, Dict

import chromadb
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import SentenceTransformer

class RAGAgent:
    def __init__(self):
        # Embedding model; must match the one used to build the index
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        # Gemini via LangChain; expects GOOGLE_API_KEY in the environment
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.3,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )
        # Connect to the persisted ChromaDB collection built at indexing time
        persist_directory = "./chroma_agnews/"
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self.chroma_client.get_collection(name="ag_news")
        print(f"Connected to ChromaDB with {self.collection.count()} documents")
    def search(self, query: str, top_k: int = 5) -> Dict:
        """Search for relevant chunks and answer the question."""
        # Fall back to a generic query when the input is empty
        if not query or query.strip() == "":
            query = "news"
        # Embed the query with the same model used at indexing time
        query_embedding = self.embedder.encode(query).tolist()
        # Query the collection
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, self.collection.count()),
            include=["documents", "metadatas", "distances"]
        )
        # Format results
        formatted_results = []
        context_chunks = []
        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                # Map distance to a similarity score in [0, 1]
                # (assumes cosine distance, whose range is [0, 2])
                distance = results['distances'][0][i] if results['distances'] else 0
                similarity_score = 1 - (distance / 2)
                doc_text = results['documents'][0][i]
                formatted_results.append({
                    'text': doc_text,
                    'category': results['metadatas'][0][i].get('label_text', 'Unknown'),
                    'score': similarity_score
                })
                context_chunks.append(doc_text)
            # Generate answer based on retrieved chunks
            answer = self._generate_answer(query, context_chunks)
            return {
                "answer": answer,
                "chunks": formatted_results,
                "query": query
            }
        else:
            return {
                "answer": "No relevant information found for your question.",
                "chunks": [],
                "query": query
            }
    def _generate_answer(self, query: str, chunks: List[str]) -> str:
        """Generate answer based on retrieved chunks."""
        # Combine chunks as numbered context
        context = "\n\n".join([f"[{i+1}] {chunk}" for i, chunk in enumerate(chunks)])
        # Create prompt
        prompt = f"""Based on the following information, answer the question.
Context:
{context}
Question: {query}
Answer:"""
        # Generate answer using Gemini
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return response.content
    def get_collection_stats(self) -> Dict:
        """Get statistics about the collection."""
        count = self.collection.count()
        if count > 0:
            # Tally categories over a sample of up to 100 documents
            # (not the full collection) to keep the call cheap
            sample = self.collection.get(
                limit=min(100, count),
                include=["metadatas"]
            )
            categories = {}
            for metadata in sample['metadatas']:
                cat = metadata.get('label_text', 'Unknown')
                categories[cat] = categories.get(cat, 0) + 1
            return {
                "total_documents": count,
                "categories": categories
            }
        else:
            return {
                "total_documents": 0,
                "categories": {}
            }
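

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test, assuming
# GOOGLE_API_KEY is set and ./chroma_agnews/ already holds an "ag_news"
# collection. The example query is illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    agent = RAGAgent()
    print(agent.get_collection_stats())
    result = agent.search("What happened in the stock market?", top_k=3)
    print(result["answer"])
    for chunk in result["chunks"]:
        print(f"[{chunk['category']}] score={chunk['score']:.3f}")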