Spaces:
Sleeping
Sleeping
# optimal_chunker.py
from typing import Dict, List, Tuple | |
from statistics import mean | |
from langchain.text_splitter import ( | |
CharacterTextSplitter, | |
RecursiveCharacterTextSplitter, | |
TokenTextSplitter, | |
) | |
from pdf_loader import load_pdf | |
# --- Helpers ---
def docs_to_text(docs) -> str:
    """Merge the ``page_content`` of every document into one string.

    Consecutive documents are separated by a blank line (``"\\n\\n"``).
    """
    paragraph_break = "\n\n"
    return paragraph_break.join(doc.page_content for doc in docs)
def run_splitter(text: str, splitter) -> List[str]:
    """Split *text* with *splitter* (anything exposing ``split_text``) and return the pieces."""
    split = splitter.split_text
    pieces = split(text)
    return pieces
def metrics(chunks: List[str]) -> Dict:
    """Summarise chunk lengths, measured in characters.

    Returns a dict with:
      * ``"chunks"``  - number of chunks;
      * ``"avg_len"`` - mean chunk length rounded to one decimal place;
      * ``"max_len"`` - length of the longest chunk.
    All three values are 0 when *chunks* is empty.
    """
    if not chunks:
        return {"chunks": 0, "avg_len": 0, "max_len": 0}

    sizes = list(map(len, chunks))
    count = len(chunks)
    average = round(mean(sizes), 1)
    longest = max(sizes)
    return {"chunks": count, "avg_len": average, "max_len": longest}
# --- Strategy evaluation ---
def evaluate_strategies(
    text: str,
    char_size: int = 800,
    char_overlap: int = 100,
    token_size: int = 512,
    token_overlap: int = 64,
) -> Dict[str, Dict]:
    """Split *text* with three LangChain splitters and report each outcome.

    Strategies, in this order:
      * ``"fixed"``     - CharacterTextSplitter splitting on ``"\\n"``;
      * ``"recursive"`` - RecursiveCharacterTextSplitter trying
        ``"\\n\\n"``, ``"\\n"``, ``" "`` and ``""`` in turn;
      * ``"token"``     - TokenTextSplitter, sized and overlapped in tokens.

    Returns:
        A dict mapping each strategy name to ``{"chunks": [...], "metrics":
        metrics(chunks), "meta": {"size", "overlap", "unit"}}``.
    """
    # All splitters are constructed before any of them is run.
    splitters = {
        "fixed": CharacterTextSplitter(
            chunk_size=char_size, chunk_overlap=char_overlap, separator="\n"
        ),
        "recursive": RecursiveCharacterTextSplitter(
            chunk_size=char_size,
            chunk_overlap=char_overlap,
            separators=["\n\n", "\n", " ", ""],
        ),
        "token": TokenTextSplitter(chunk_size=token_size, chunk_overlap=token_overlap),
    }
    # (size, overlap, unit) reported in each strategy's "meta" entry.
    settings = {
        "fixed": (char_size, char_overlap, "chars"),
        "recursive": (char_size, char_overlap, "chars"),
        "token": (token_size, token_overlap, "tokens"),
    }

    report: Dict[str, Dict] = {}
    for name, splitter in splitters.items():
        pieces = run_splitter(text, splitter)
        size, overlap, unit = settings[name]
        report[name] = {
            "chunks": pieces,
            "metrics": metrics(pieces),
            # A fresh dict per strategy, so entries never share a meta object.
            "meta": {"size": size, "overlap": overlap, "unit": unit},
        }
    return report
def score(candidate: Dict, target_avg: int = 800, hard_max: int = 1500) -> float:
    """Rate one strategy result; a smaller number is a better fit.

    Reads ``candidate["metrics"]`` and adds three terms:
      * the absolute gap between ``avg_len`` and *target_avg*;
      * how far ``max_len`` overshoots *hard_max* (nothing when within it);
      * a flat 500 when ``chunks`` is at most 1, so one giant chunk is
        discouraged in favour of more, smaller chunks.
    """
    stats = candidate["metrics"]

    gap = abs(stats["avg_len"] - target_avg)

    if stats["max_len"] <= hard_max:
        oversize = 0
    else:
        oversize = stats["max_len"] - hard_max

    lone_chunk = 500 if stats["chunks"] <= 1 else 0

    return gap + oversize + lone_chunk
def select_best(evals: Dict[str, Dict], target_avg: int = 800, hard_max: int = 1500) -> Tuple[str, Dict]:
    """Return ``(name, info)`` for the strategy in *evals* with the lowest ``score``.

    The sort is stable, so on a tie the strategy listed first in *evals* wins.
    An empty *evals* raises ``IndexError``.
    """
    ranked = sorted(
        evals.items(),
        key=lambda entry: score(entry[1], target_avg, hard_max),
    )
    winner = ranked[0][0]
    return winner, evals[winner]
# --- Final pipeline API ---
class OptimalChunker:
    """Choose the best chunking strategy for a text and keep its chunks.

    ``fit_on_text`` (or ``fit_transform_pdf``) evaluates the strategies produced
    by ``evaluate_strategies`` and keeps the winner picked by ``select_best`` in
    ``best_name`` / ``best_info``. ``transform`` then returns that winner's chunks.
    """

    def __init__(
        self,
        char_size: int = 800,
        char_overlap: int = 100,
        token_size: int = 512,
        token_overlap: int = 64,
        target_avg: int = 800,
        hard_max: int = 1500,
    ):
        """Store splitter settings and scoring targets.

        Args:
            char_size: chunk size for the character-based strategies.
            char_overlap: chunk overlap for the character-based strategies.
            token_size: chunk size for the token strategy.
            token_overlap: chunk overlap for the token strategy.
            target_avg: desired average chunk length, passed to ``select_best``.
            hard_max: maximum chunk length before a penalty, passed to ``select_best``.
        """
        self.char_size = char_size
        self.char_overlap = char_overlap
        self.token_size = token_size
        self.token_overlap = token_overlap
        self.target_avg = target_avg
        self.hard_max = hard_max
        # Both stay None until fit_on_text has run.
        self.best_name = None
        self.best_info = None

    def fit_on_text(self, text: str) -> Dict:
        """Evaluate every strategy on *text* and remember the best one.

        Returns:
            ``{"best": name, "metrics": ..., "meta": ...}`` for the winner.
        """
        evals = evaluate_strategies(
            text,
            char_size=self.char_size,
            char_overlap=self.char_overlap,
            token_size=self.token_size,
            token_overlap=self.token_overlap,
        )
        self.best_name, self.best_info = select_best(evals, self.target_avg, self.hard_max)
        return {"best": self.best_name, "metrics": self.best_info["metrics"], "meta": self.best_info["meta"]}

    def transform(self) -> List[str]:
        """Return the chunks of the strategy chosen by the last fit.

        The result is a new list, so mutating it leaves the fitted state intact.

        Raises:
            RuntimeError: if no fit has been run yet.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
        if self.best_info is None:
            raise RuntimeError("Call fit_on_text first.")
        return list(self.best_info["chunks"])

    def fit_transform_pdf(self, pdf_path: str) -> Tuple[str, List[str], Dict]:
        """Load *pdf_path*, fit on its combined text and return the result.

        The documents returned by ``load_pdf`` are joined with ``docs_to_text``.
        That helper reads ``page_content`` from each document.

        Returns:
            ``(best_name, chunks, summary)``, where ``summary`` is the dict
            returned by ``fit_on_text``.
        """
        docs = load_pdf(pdf_path)
        text = docs_to_text(docs)
        summary = self.fit_on_text(text)
        return self.best_name, self.transform(), summary
if __name__ == "__main__":
    # Demo: fit on sample.pdf, report the winning strategy, preview its first chunk.
    demo = OptimalChunker()
    best, chunks, summary = demo.fit_transform_pdf("sample.pdf")
    print("=== Best Strategy ===")
    print(best, summary)
    first_preview = chunks[0][:300] if chunks else ""
    print(f"First chunk preview:\n{first_preview}")