# NOTE: snapshot metadata from the hosting page (commit e7e9247, 4,303 bytes);
# kept as a comment so the module remains valid Python.
# optimal_chunker.py
from typing import Dict, List, Tuple
from statistics import mean
from langchain.text_splitter import (
CharacterTextSplitter,
RecursiveCharacterTextSplitter,
TokenTextSplitter,
)
from pdf_loader import load_pdf
# --- Helpers ---
def docs_to_text(docs) -> str:
    """Join the page contents of loaded documents into one blank-line-separated string."""
    pages = (doc.page_content for doc in docs)
    return "\n\n".join(pages)
def run_splitter(text: str, splitter) -> List[str]:
    """Apply *splitter* to *text* and return the resulting list of chunks."""
    chunks = splitter.split_text(text)
    return chunks
def metrics(chunks: List[str]) -> Dict:
    """Summarize a chunk list: count, average length, and maximum length."""
    if not chunks:
        # No chunks at all — report zeros rather than failing on mean([])/max([]).
        return {"chunks": 0, "avg_len": 0, "max_len": 0}
    sizes = [len(chunk) for chunk in chunks]
    return {
        "chunks": len(sizes),
        "avg_len": round(mean(sizes), 1),
        "max_len": max(sizes),
    }
# --- Strategy evaluation ---
def evaluate_strategies(
    text: str,
    char_size: int = 800,
    char_overlap: int = 100,
    token_size: int = 512,
    token_overlap: int = 64,
) -> Dict[str, Dict]:
    """Split *text* with three strategies and report chunks, metrics, and meta for each.

    Strategies: "fixed" (newline-separated characters), "recursive"
    (paragraph → line → word → char fallback), and "token" (token-based).
    """
    # Table of (splitter, meta) per strategy; insertion order matches the
    # fixed/recursive/token ordering of the returned dict.
    strategies = {
        "fixed": (
            CharacterTextSplitter(chunk_size=char_size, chunk_overlap=char_overlap, separator="\n"),
            {"size": char_size, "overlap": char_overlap, "unit": "chars"},
        ),
        "recursive": (
            RecursiveCharacterTextSplitter(
                chunk_size=char_size, chunk_overlap=char_overlap, separators=["\n\n", "\n", " ", ""]
            ),
            {"size": char_size, "overlap": char_overlap, "unit": "chars"},
        ),
        "token": (
            TokenTextSplitter(chunk_size=token_size, chunk_overlap=token_overlap),
            {"size": token_size, "overlap": token_overlap, "unit": "tokens"},
        ),
    }
    report: Dict[str, Dict] = {}
    for name, (splitter, meta) in strategies.items():
        chunks = run_splitter(text, splitter)
        report[name] = {"chunks": chunks, "metrics": metrics(chunks), "meta": meta}
    return report
def score(candidate: Dict, target_avg: int = 800, hard_max: int = 1500) -> float:
    """Lower is better: distance to target + penalty if max chunk too large."""
    stats = candidate["metrics"]
    closeness = abs(stats["avg_len"] - target_avg)
    # Oversize penalty kicks in only past the hard ceiling, scaled by the excess.
    oversize = max(0, stats["max_len"] - hard_max)
    # Favor more, smaller chunks over 1 giant chunk.
    single_chunk = 500 if stats["chunks"] <= 1 else 0
    return closeness + oversize + single_chunk
def select_best(evals: Dict[str, Dict], target_avg: int = 800, hard_max: int = 1500) -> Tuple[str, Dict]:
    """Return (name, eval entry) of the lowest-scoring strategy; ties keep first insertion order."""
    winner = min(evals, key=lambda name: score(evals[name], target_avg, hard_max))
    return winner, evals[winner]
# --- Final pipeline API ---
class OptimalChunker:
    """Evaluate several chunking strategies on a text and expose the best one.

    Call ``fit_on_text`` (or ``fit_transform_pdf``) first; ``transform`` then
    returns the winning strategy's chunks.
    """

    def __init__(
        self,
        char_size: int = 800,
        char_overlap: int = 100,
        token_size: int = 512,
        token_overlap: int = 64,
        target_avg: int = 800,
        hard_max: int = 1500,
    ):
        # Splitter parameters: chars for fixed/recursive, tokens for the token splitter.
        self.char_size = char_size
        self.char_overlap = char_overlap
        self.token_size = token_size
        self.token_overlap = token_overlap
        # Scoring targets forwarded to select_best().
        self.target_avg = target_avg
        self.hard_max = hard_max
        # Populated by fit_on_text(); None until then.
        self.best_name = None
        self.best_info = None

    def fit_on_text(self, text: str) -> Dict:
        """Run all strategies on *text*, remember the winner, and return its summary."""
        evals = evaluate_strategies(
            text,
            char_size=self.char_size,
            char_overlap=self.char_overlap,
            token_size=self.token_size,
            token_overlap=self.token_overlap,
        )
        self.best_name, self.best_info = select_best(evals, self.target_avg, self.hard_max)
        return {"best": self.best_name, "metrics": self.best_info["metrics"], "meta": self.best_info["meta"]}

    def transform(self) -> List[str]:
        """Return the chunks produced by the winning strategy.

        Raises:
            RuntimeError: if called before fit_on_text().
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this guard into a confusing
        # "'NoneType' is not subscriptable" further down.
        if self.best_info is None:
            raise RuntimeError("Call fit_on_text first.")
        return self.best_info["chunks"]

    def fit_transform_pdf(self, pdf_path: str) -> Tuple[str, List[str], Dict]:
        """Load a PDF, fit on its concatenated text, and return (best_name, chunks, summary)."""
        docs = load_pdf(pdf_path)
        text = docs_to_text(docs)
        summary = self.fit_on_text(text)
        return self.best_name, self.transform(), summary
if __name__ == "__main__":
    # Quick demo: pick the best chunking strategy for sample.pdf and preview it.
    chunker = OptimalChunker()
    strategy, pieces, report = chunker.fit_transform_pdf("sample.pdf")
    print("=== Best Strategy ===")
    print(strategy, report)
    preview = pieces[0][:300] if pieces else ''
    print(f"First chunk preview:\n{preview}")