Commit e7e9247 by Hamid Omarov
Parent(s): b834d82

HF Space app + minimal pipeline code (no secrets)
Files changed:
- day3/chunking_test.py +83 -0
- day3/embeddings.py +12 -0
- day3/optimal_chunker.py +113 -0
- day3/pdf_loader.py +13 -0
- day3/rag_system.py +89 -0
- day3/vector_store.py +15 -0
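None of the files in this commit is the Space entry point itself. As a sketch of how these modules might be wired into a Gradio UI for the Space (the file name app.py, the Gradio layout, and all labels below are assumptions, not part of this diff):

# app.py — illustrative sketch only; the actual Space app is not in this diff.
# Assumes it sits next to the day3 modules so their flat imports resolve.
import gradio as gr
from rag_system import RAGPipeline

rag = RAGPipeline()

def ask(pdf_path, question):
    # Index the uploaded PDF, then answer from the retrieved chunks
    rag.index_document(pdf_path)
    return rag.query(question)["answer"]

demo = gr.Interface(
    fn=ask,
    inputs=[gr.File(label="PDF", type="filepath"), gr.Textbox(label="Question")],
    outputs=gr.Textbox(label="Answer"),
)

if __name__ == "__main__":
    demo.launch()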
day3/chunking_test.py
ADDED
@@ -0,0 +1,83 @@
# chunking_test.py
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    TokenTextSplitter,
)
from pdf_loader import load_pdf

# GPT/Copilot: "utility to flatten pages into a single string"
def docs_to_text(docs):
    return "\n\n".join([d.page_content for d in docs])

# GPT/Copilot: "run a splitter on text and return list[str]"
def split_text(text, splitter):
    return splitter.split_text(text)

# GPT/Copilot: "compute metrics: chunk count, average size (chars or tokens), and overlap setting"
def compute_metrics(chunks, unit="chars", chunk_size=None, chunk_overlap=None):
    sizes = [len(c) for c in chunks]
    avg = sum(sizes) / len(sizes) if sizes else 0
    if unit == "chars":
        return {
            "chunks": len(chunks),
            "avg_chars": round(avg, 1),
            "overlap": chunk_overlap,
        }
    else:
        # Token mode: chunks are still plain strings here, so we report
        # their character length rather than re-tokenizing.
        return {
            "chunks": len(chunks),
            "avg_len_str": round(avg, 1),
            "overlap": chunk_overlap,
        }

def run_comparison(pdf_path="sample.pdf"):
    docs = load_pdf(pdf_path)
    text = docs_to_text(docs)

    # 1) Fixed size (CharacterTextSplitter)
    fixed = CharacterTextSplitter(
        chunk_size=800, chunk_overlap=100, separator="\n"
    )
    fixed_chunks = split_text(text, fixed)
    fixed_metrics = compute_metrics(
        fixed_chunks, unit="chars", chunk_size=800, chunk_overlap=100
    )

    # 2) Recursive (RecursiveCharacterTextSplitter)
    recursive = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=100,
        separators=["\n\n", "\n", " ", ""],
    )
    recursive_chunks = split_text(text, recursive)
    recursive_metrics = compute_metrics(
        recursive_chunks, unit="chars", chunk_size=800, chunk_overlap=100
    )

    # 3) Token-based (TokenTextSplitter)
    token = TokenTextSplitter(
        chunk_size=512,
        chunk_overlap=64,
    )
    token_chunks = split_text(text, token)
    token_metrics = compute_metrics(
        token_chunks, unit="tokens", chunk_size=512, chunk_overlap=64
    )

    print("=== Chunking Comparison ===")
    print("Fixed (chars):    ", fixed_metrics)
    print("Recursive (chars):", recursive_metrics)
    print("Token-based:      ", token_metrics)

    # Optional: show first chunk samples for sanity
    print("\n--- Sample Chunks ---")
    for name, chunks in [("Fixed", fixed_chunks), ("Recursive", recursive_chunks), ("Token", token_chunks)]:
        if not chunks:
            print(f"{name}: no chunks produced")
            continue
        preview = chunks[0][:200].replace("\n", " ") + ("..." if len(chunks[0]) > 200 else "")
        print(f"{name} #1 →", preview)

if __name__ == "__main__":
    run_comparison("sample.pdf")
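The three splitters differ mainly in where they are allowed to cut; chunk_overlap is what keeps boundary text intact across cuts. A toy sketch of the overlap behavior (not part of this commit; the text and sizes are arbitrary):

# overlap_demo.py — illustrative only
from langchain.text_splitter import RecursiveCharacterTextSplitter

text = "one two three four five six seven eight nine ten " * 5
splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=10, separators=[" ", ""])
chunks = splitter.split_text(text)
# Consecutive chunks share up to ~10 characters, so a span cut at a
# boundary still appears whole in at least one chunk.
print(len(chunks))
print(chunks[0], "||", chunks[1])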
day3/embeddings.py
ADDED
@@ -0,0 +1,12 @@
from typing import List, Dict
from sentence_transformers import SentenceTransformer

_embedder = SentenceTransformer("all-MiniLM-L6-v2")

def embed_texts(texts: List[str]) -> List[List[float]]:
    # Return as Python lists of floats (Chroma-compatible)
    return _embedder.encode(texts, convert_to_numpy=True).tolist()

def create_embeddings(chunks: List[str]) -> Dict:
    vectors = embed_texts(chunks)
    return {"embeddings": vectors, "count": len(vectors)}
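A quick smoke test for the module above (illustrative; all-MiniLM-L6-v2 produces 384-dimensional vectors, which is the dimensionality the Chroma collections elsewhere in this commit end up storing):

# embeddings_demo.py — illustrative only
from embeddings import create_embeddings

out = create_embeddings(["hello world", "chunking strategies"])
print(out["count"])               # 2
print(len(out["embeddings"][0]))  # 384 for all-MiniLM-L6-v2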
day3/optimal_chunker.py
ADDED
@@ -0,0 +1,113 @@
# optimal_chunker.py
from typing import Dict, List, Tuple
from statistics import mean
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    TokenTextSplitter,
)
from pdf_loader import load_pdf

# --- Helpers ---
def docs_to_text(docs) -> str:
    return "\n\n".join([d.page_content for d in docs])

def run_splitter(text: str, splitter) -> List[str]:
    return splitter.split_text(text)

def metrics(chunks: List[str]) -> Dict:
    if not chunks:
        return {"chunks": 0, "avg_len": 0, "max_len": 0}
    lens = [len(c) for c in chunks]
    return {
        "chunks": len(chunks),
        "avg_len": round(mean(lens), 1),
        "max_len": max(lens),
    }

# --- Strategy evaluation ---
def evaluate_strategies(
    text: str,
    char_size: int = 800,
    char_overlap: int = 100,
    token_size: int = 512,
    token_overlap: int = 64,
) -> Dict[str, Dict]:
    fixed = CharacterTextSplitter(chunk_size=char_size, chunk_overlap=char_overlap, separator="\n")
    recursive = RecursiveCharacterTextSplitter(
        chunk_size=char_size, chunk_overlap=char_overlap, separators=["\n\n", "\n", " ", ""]
    )
    token = TokenTextSplitter(chunk_size=token_size, chunk_overlap=token_overlap)

    fixed_chunks = run_splitter(text, fixed)
    rec_chunks = run_splitter(text, recursive)
    tok_chunks = run_splitter(text, token)

    return {
        "fixed": {"chunks": fixed_chunks, "metrics": metrics(fixed_chunks), "meta": {"size": char_size, "overlap": char_overlap, "unit": "chars"}},
        "recursive": {"chunks": rec_chunks, "metrics": metrics(rec_chunks), "meta": {"size": char_size, "overlap": char_overlap, "unit": "chars"}},
        "token": {"chunks": tok_chunks, "metrics": metrics(tok_chunks), "meta": {"size": token_size, "overlap": token_overlap, "unit": "tokens"}},
    }

def score(candidate: Dict, target_avg: int = 800, hard_max: int = 1500) -> float:
    """Lower is better: distance to target + penalty if max chunk too large."""
    m = candidate["metrics"]
    dist = abs(m["avg_len"] - target_avg)
    penalty = 0 if m["max_len"] <= hard_max else (m["max_len"] - hard_max)
    # Favor more, smaller chunks over one giant chunk
    few_chunk_penalty = 500 if m["chunks"] <= 1 else 0
    return dist + penalty + few_chunk_penalty

def select_best(evals: Dict[str, Dict], target_avg: int = 800, hard_max: int = 1500) -> Tuple[str, Dict]:
    scored = [(name, score(info, target_avg, hard_max)) for name, info in evals.items()]
    scored.sort(key=lambda x: x[1])
    return scored[0][0], evals[scored[0][0]]

# --- Final pipeline API ---
class OptimalChunker:
    def __init__(
        self,
        char_size: int = 800,
        char_overlap: int = 100,
        token_size: int = 512,
        token_overlap: int = 64,
        target_avg: int = 800,
        hard_max: int = 1500,
    ):
        self.char_size = char_size
        self.char_overlap = char_overlap
        self.token_size = token_size
        self.token_overlap = token_overlap
        self.target_avg = target_avg
        self.hard_max = hard_max
        self.best_name = None
        self.best_info = None

    def fit_on_text(self, text: str) -> Dict:
        evals = evaluate_strategies(
            text,
            char_size=self.char_size,
            char_overlap=self.char_overlap,
            token_size=self.token_size,
            token_overlap=self.token_overlap,
        )
        self.best_name, self.best_info = select_best(evals, self.target_avg, self.hard_max)
        return {"best": self.best_name, "metrics": self.best_info["metrics"], "meta": self.best_info["meta"]}

    def transform(self) -> List[str]:
        assert self.best_info is not None, "Call fit_on_text first."
        return self.best_info["chunks"]

    def fit_transform_pdf(self, pdf_path: str) -> Tuple[str, List[str], Dict]:
        docs = load_pdf(pdf_path)
        text = docs_to_text(docs)
        summary = self.fit_on_text(text)
        return self.best_name, self.transform(), summary

if __name__ == "__main__":
    # Demo on sample.pdf
    ch = OptimalChunker()
    best, chunks, summary = ch.fit_transform_pdf("sample.pdf")
    print("=== Best Strategy ===")
    print(best, summary)
    print(f"First chunk preview:\n{chunks[0][:300] if chunks else ''}")
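The scoring rule is easiest to read with numbers plugged in. A toy check against the score() defined above (the metric values are made up):

# score_demo.py — illustrative only
from optimal_chunker import score

many_small = {"metrics": {"chunks": 10, "avg_len": 750.0, "max_len": 1200}}
one_giant = {"metrics": {"chunks": 1, "avg_len": 800.0, "max_len": 9000}}

print(score(many_small))  # |750 - 800| = 50.0, no penalties
print(score(one_giant))   # 0 + (9000 - 1500) + 500 single-chunk penalty = 8000.0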
day3/pdf_loader.py
ADDED
@@ -0,0 +1,13 @@
from langchain_community.document_loaders import PyPDFLoader

def load_pdf(file_path):
    loader = PyPDFLoader(file_path)
    pages = loader.load()
    return pages

if __name__ == "__main__":
    docs = load_pdf("sample.pdf")
    print(f"✅ Loaded {len(docs)} pages")
    for i, page in enumerate(docs, start=1):
        print(f"--- Page {i} ---")
        print(page.page_content)
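PyPDFLoader needs the pypdf package at runtime, and a Space running these modules has to declare its dependencies. A plausible requirements.txt inferred from the imports in this diff (the exact package set and the missing version pins are assumptions; gradio is only needed for the assumed app.py sketched earlier):

# requirements.txt — illustrative; inferred from imports, pins omitted
langchain
langchain-community
langchain-groq
sentence-transformers
chromadb
pypdf
python-dotenv
gradio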
day3/rag_system.py
ADDED
@@ -0,0 +1,89 @@
# rag_system.py
from typing import List, Dict
import chromadb
from pdf_loader import load_pdf
from optimal_chunker import OptimalChunker
from embeddings import embed_texts
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import os

load_dotenv()


class RAGPipeline:
    def __init__(self, persist_dir: str = "./chroma_db", collection_name: str = "pdf_docs"):
        # Vector DB (Chroma 1.x new API)
        self.client = chromadb.PersistentClient(path=persist_dir)
        self.col = self.client.get_or_create_collection(name=collection_name)

        # Chunker
        self.chunker = OptimalChunker()

        # LLM (Groq)
        self.llm = ChatGroq(
            model="llama3-8b-8192",
            temperature=0.0,
            api_key=os.getenv("GROQ_API_KEY"),
        )

    # 1) Load  2) Chunk  3) Embed  4) Upsert to Chroma
    def index_document(self, pdf_path: str, doc_id_prefix: str = "doc") -> Dict:
        docs = load_pdf(pdf_path)
        text = "\n\n".join(d.page_content for d in docs)

        summary = self.chunker.fit_on_text(text)
        chunks = self.chunker.transform()

        # embeddings: list[list[float]]
        vectors = embed_texts(chunks)
        ids = [f"{doc_id_prefix}-{i}" for i in range(len(chunks))]

        self.col.add(
            ids=ids,
            documents=chunks,
            embeddings=vectors,
            metadatas=[{"source": pdf_path, "chunk": i} for i in range(len(chunks))],
        )
        return {"chunks_indexed": len(chunks), "best_strategy": summary}

    # 5) Retrieve  6) Ask LLM
    def query(self, question: str, k: int = 4) -> Dict:
        # Embed the query with the same model used at index time; passing
        # query_texts would fall back to Chroma's default embedding function,
        # which is not guaranteed to match the indexed vector space.
        results = self.col.query(query_embeddings=embed_texts([question]), n_results=k)
        chunks: List[str] = results["documents"][0] if results.get("documents") else []

        context = "\n\n".join(chunks)
        prompt = f"""You are an extraction assistant. Use ONLY the Context to answer.
Rules:
- If the answer is explicitly present in Context, return that substring EXACTLY.
- Do not paraphrase. Do not add words. Return a verbatim span from Context.
- If the answer is not in Context, reply exactly: I don't know

Question: {question}

Context:
{context}

Answer (verbatim from Context):"""
        resp = self.llm.invoke(prompt)
        answer = resp.content.strip()

        # Fallback if the model still hedges
        if (not answer or answer.lower().startswith("i don't know")) and context.strip():
            answer = chunks[0] if chunks else "I don't know"

        return {
            "answer": answer,
            "used_chunks": len(chunks),
            "context_preview": context[:500],
        }


if __name__ == "__main__":
    rag = RAGPipeline()
    info = rag.index_document("sample.pdf")  # ensure day3/sample.pdf exists
    print("Indexed:", info)

    out = rag.query("What text does the PDF contain?")
    print("Answer:", out["answer"])
    print("Used chunks:", out["used_chunks"])
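The pipeline reads GROQ_API_KEY through load_dotenv(), which matches the "no secrets" note in the commit message: the key lives in an untracked .env file (or an HF Space secret), never in the repo. A placeholder example:

# .env — never committed; placeholder value only
GROQ_API_KEY=your-groq-api-key-here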
day3/vector_store.py
ADDED
@@ -0,0 +1,15 @@
# vector_store.py
import chromadb

# New persistent client (replaces Settings / duckdb+parquet)
client = chromadb.PersistentClient(path="./chroma_db")

# Create or get collection
collection = client.get_or_create_collection("pdf_docs")

def reset_db():
    # Drop and recreate the collection; rebind the module-level handle so
    # callers importing `collection` through this module don't keep a stale one.
    global collection
    client.delete_collection("pdf_docs")
    collection = client.get_or_create_collection("pdf_docs")
    return collection

if __name__ == "__main__":
    print("ChromaDB ready. Collections:", [c.name for c in client.list_collections()])
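A quick usage sketch for the module above (illustrative; the 384-dim zero vector is an arbitrary stand-in matching the all-MiniLM-L6-v2 dimensionality used elsewhere in this commit):

# vector_store_demo.py — illustrative only
from vector_store import reset_db

col = reset_db()  # start from an empty collection
col.add(ids=["t-0"], documents=["hello"], embeddings=[[0.0] * 384])
print(col.count())  # 1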