File size: 3,990 Bytes
42cabf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import chromadb
from sentence_transformers import SentenceTransformer
from typing import List, Dict
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import HumanMessage

class RAGAgent:
    """Retrieval-augmented QA agent over an AG News ChromaDB collection.

    Embeds queries with a sentence-transformers model, retrieves the
    nearest chunks from a persistent Chroma collection, and asks Gemini
    to answer the question from the retrieved context.
    """

    def __init__(self):
        # Embedding model must match the one used when the collection was
        # built, or distances are meaningless.
        self.embedder = SentenceTransformer('all-mpnet-base-v2')
        # Low temperature so answers stay close to the retrieved context.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.3,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )

        persist_directory = "./chroma_agnews/"
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)

        # get_collection raises if "ag_news" does not exist — the index
        # must be built by a separate ingestion step beforehand.
        self.collection = self.chroma_client.get_collection(name="ag_news")
        print(f"Connected to ChromaDB with {self.collection.count()} documents")

    def search(self, query: str, top_k: int = 5) -> Dict:
        """Search for relevant chunks and answer the question.

        Args:
            query: Natural-language question. Blank/empty queries fall
                back to the generic term "news" so retrieval still works.
            top_k: Maximum number of chunks to retrieve; clamped to the
                collection size.

        Returns:
            Dict with keys "answer" (str), "chunks" (list of dicts with
            'text', 'category', 'score'), and "query" (the query used).
        """
        # Fall back to a generic query rather than embedding an empty string.
        if not query or not query.strip():
            query = "news"

        # Bug fix: Chroma rejects non-positive n_results, so bail out early
        # when the collection is empty or top_k <= 0 instead of passing
        # min(top_k, 0) == 0 into query().
        n_results = min(top_k, self.collection.count())
        if n_results <= 0:
            return {
                "answer": "No relevant information found for your question.",
                "chunks": [],
                "query": query
            }

        # Embed the query with the same model used at indexing time.
        query_embedding = self.embedder.encode(query).tolist()

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=["documents", "metadatas", "distances"]
        )

        formatted_results = []
        context_chunks = []

        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                # Map distance to a similarity score in [0, 1]; a missing
                # distances block is treated as a perfect match (score 1.0).
                # NOTE(review): the 1 - d/2 mapping assumes cosine distance
                # (range [0, 2]) — confirm against the indexing script.
                distance = results['distances'][0][i] if results['distances'] else 0
                similarity_score = 1 - (distance / 2)

                doc_text = results['documents'][0][i]

                formatted_results.append({
                    'text': doc_text,
                    'category': results['metadatas'][0][i].get('label_text', 'Unknown'),
                    'score': similarity_score
                })

                context_chunks.append(doc_text)

            # Synthesize an answer grounded in the retrieved chunks.
            answer = self._generate_answer(query, context_chunks)

            return {
                "answer": answer,
                "chunks": formatted_results,
                "query": query
            }

        # Retrieval returned no hits at all.
        return {
            "answer": "No relevant information found for your question.",
            "chunks": [],
            "query": query
        }

    def _generate_answer(self, query: str, chunks: List[str]) -> str:
        """Generate an answer from *chunks* using the Gemini LLM.

        Args:
            query: The user's question.
            chunks: Retrieved context passages, most relevant first.

        Returns:
            The model's answer text.
        """
        # Number each chunk so the model can reference sources as [1], [2], ...
        context = "\n\n".join([f"[{i+1}] {chunk}" for i, chunk in enumerate(chunks)])

        prompt = f"""Based on the following information, answer the question.

Context:
{context}

Question: {query}

Answer:"""

        # Generate answer using Gemini
        response = self.llm.invoke([HumanMessage(content=prompt)])
        return response.content

    def get_collection_stats(self) -> Dict:
        """Get statistics about the collection.

        Returns:
            Dict with "total_documents" (int) and "categories" — a mapping
            of category label to count over a sample of at most 100 docs
            (counts are therefore approximate for larger collections).
        """
        count = self.collection.count()

        if count > 0:
            # Sample up to 100 documents; metadata is enough for category tallies.
            sample = self.collection.get(
                limit=min(100, count),
                include=["metadatas"]
            )
            categories = {}

            for metadata in sample['metadatas']:
                cat = metadata.get('label_text', 'Unknown')
                categories[cat] = categories.get(cat, 0) + 1

            return {
                "total_documents": count,
                "categories": categories
            }
        else:
            return {
                "total_documents": 0,
                "categories": {}
            }