# rag_system.py
from typing import List, Dict
import chromadb
from pdf_loader import load_pdf
from optimal_chunker import OptimalChunker
from embeddings import embed_texts
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import os
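
# Local helper modules (pdf_loader, optimal_chunker, embeddings) are not shown
# in this file. From their use below, the expected interfaces are:
#   - load_pdf(path)                    -> list of documents exposing .page_content
#   - OptimalChunker().fit_on_text(txt) -> summary of the chosen chunking strategy,
#     followed by .transform()          -> list[str] of chunks
#   - embed_texts(list[str])            -> list[list[float]] vectors for Chroma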

load_dotenv()


class RAGPipeline:
    def __init__(self, persist_dir: str = "./chroma_db", collection_name: str = "pdf_docs"):
        # Vector DB (Chroma 1.x new API)
        self.client = chromadb.PersistentClient(path=persist_dir)
        self.col = self.client.get_or_create_collection(name=collection_name)

        # Chunker
        self.chunker = OptimalChunker()

        # LLM (Groq)
        self.llm = ChatGroq(
            model="llama3-8b-8192",
            temperature=0.0,
            api_key=os.getenv("GROQ_API_KEY"),
        )

    # 1) Load  2) Chunk  3) Embed  4) Add to Chroma
    def index_document(self, pdf_path: str, doc_id_prefix: str = "doc") -> Dict:
        docs = load_pdf(pdf_path)
        text = "\n\n".join(d.page_content for d in docs)

        summary = self.chunker.fit_on_text(text)
        chunks = self.chunker.transform()

        # embeddings: list[list[float]]
        vectors = embed_texts(chunks)
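        # Note: IDs are derived only from doc_id_prefix and chunk index, so
        # indexing a second document with the same prefix will reuse IDs that
        # already exist in the collection; pass a unique prefix per file.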
        ids = [f"{doc_id_prefix}-{i}" for i in range(len(chunks))]

        self.col.add(
            ids=ids,
            documents=chunks,
            embeddings=vectors,
            metadatas=[{"source": pdf_path, "chunk": i} for i in range(len(chunks))],
        )
        return {"chunks_indexed": len(chunks), "best_strategy": summary}

    # 5) Retrieve  6) Ask LLM
    def query(self, question: str, k: int = 4) -> Dict:
        # Embed the question with the same embed_texts helper used at indexing
        # time. Passing query_texts here would make Chroma embed the query with
        # its default embedding function, which may not match the vectors that
        # were stored via embed_texts.
        results = self.col.query(query_embeddings=embed_texts([question]), n_results=k)
        chunks: List[str] = results["documents"][0] if results.get("documents") else []

        context = "\n\n".join(chunks)
        prompt = f"""You are an extraction assistant. Use ONLY the Context to answer.
Rules:
- If the answer is explicitly present in Context, return that substring EXACTLY.
- Do not paraphrase. Do not add words. Return a verbatim span from Context.
- If the answer is not in Context, reply exactly: I don't know

Question: {question}

Context:
{context}

Answer (verbatim from Context):"""
        resp = self.llm.invoke(prompt)
        answer = resp.content.strip()

        # Fallback if the model still hedges
        if (not answer or answer.lower().startswith("i don't know")) and context.strip():
            answer = chunks[0] if chunks else "I don't know"

        return {
            "answer": answer,
            "used_chunks": len(chunks),
            "context_preview": context[:500],
        }


if __name__ == "__main__":
    rag = RAGPipeline()
    info = rag.index_document("sample.pdf")  # ensure day3/sample.pdf exists
    print("Indexed:", info)

    out = rag.query("What text does the PDF contain?")
    print("Answer:", out["answer"])
    print("Used chunks:", out["used_chunks"])