# optimal_chunker.py
from typing import Dict, List, Tuple
from statistics import mean
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
    TokenTextSplitter,
)
from pdf_loader import load_pdf
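# NOTE: pdf_loader is a local module not shown in this file. load_pdf(path) is
# assumed to return LangChain Document objects exposing .page_content (e.g. the
# output of a PDF loader's .load()); only that attribute is relied on below.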

# --- Helpers ---
def docs_to_text(docs) -> str:
    """Join the page_content of loaded Document objects into a single string."""
    return "\n\n".join(d.page_content for d in docs)

def run_splitter(text: str, splitter) -> List[str]:
    """Apply a LangChain text splitter to raw text and return its chunks."""
    return splitter.split_text(text)

def metrics(chunks: List[str]) -> Dict:
    """Summarize chunk count plus average and max chunk length, in characters."""
    if not chunks:
        return {"chunks": 0, "avg_len": 0, "max_len": 0}
    lens = [len(c) for c in chunks]
    return {
        "chunks": len(chunks),
        "avg_len": round(mean(lens), 1),
        "max_len": max(lens),
    }

# --- Strategy evaluation ---
def evaluate_strategies(
    text: str,
    char_size: int = 800,
    char_overlap: int = 100,
    token_size: int = 512,
    token_overlap: int = 64,
) -> Dict[str, Dict]:
    """Chunk the same text with fixed-size, recursive, and token-based splitters and collect metrics for each."""
    fixed = CharacterTextSplitter(chunk_size=char_size, chunk_overlap=char_overlap, separator="\n")
    recursive = RecursiveCharacterTextSplitter(
        chunk_size=char_size, chunk_overlap=char_overlap, separators=["\n\n", "\n", " ", ""]
    )
    token = TokenTextSplitter(chunk_size=token_size, chunk_overlap=token_overlap)

    fixed_chunks = run_splitter(text, fixed)
    rec_chunks = run_splitter(text, recursive)
    tok_chunks = run_splitter(text, token)

    return {
        "fixed":   {"chunks": fixed_chunks, "metrics": metrics(fixed_chunks), "meta": {"size": char_size, "overlap": char_overlap, "unit": "chars"}},
        "recursive": {"chunks": rec_chunks, "metrics": metrics(rec_chunks), "meta": {"size": char_size, "overlap": char_overlap, "unit": "chars"}},
        "token":   {"chunks": tok_chunks, "metrics": metrics(tok_chunks), "meta": {"size": token_size, "overlap": token_overlap, "unit": "tokens"}},
    }

def score(candidate: Dict, target_avg: int = 800, hard_max: int = 1500) -> float:
    """Lower is better: distance to target + penalty if max chunk too large."""
    m = candidate["metrics"]
    dist = abs(m["avg_len"] - target_avg)
    penalty = 0 if m["max_len"] <= hard_max else (m["max_len"] - hard_max)
    # Favor more, smaller chunks over 1 giant chunk
    few_chunk_penalty = 500 if m["chunks"] <= 1 else 0
    return dist + penalty + few_chunk_penalty

def select_best(evals: Dict[str, Dict], target_avg: int = 800, hard_max: int = 1500) -> Tuple[str, Dict]:
    """Return the name and info of the lowest-scoring (best) strategy."""
    scored = [(name, score(info, target_avg, hard_max)) for name, info in evals.items()]
    scored.sort(key=lambda x: x[1])
    return scored[0][0], evals[scored[0][0]]

# --- Final pipeline API ---
class OptimalChunker:
    """Evaluates several chunking strategies on a document and keeps the best-scoring one."""
    def __init__(
        self,
        char_size: int = 800,
        char_overlap: int = 100,
        token_size: int = 512,
        token_overlap: int = 64,
        target_avg: int = 800,
        hard_max: int = 1500,
    ):
        self.char_size = char_size
        self.char_overlap = char_overlap
        self.token_size = token_size
        self.token_overlap = token_overlap
        self.target_avg = target_avg
        self.hard_max = hard_max
        self.best_name = None
        self.best_info = None

    def fit_on_text(self, text: str) -> Dict:
        evals = evaluate_strategies(
            text,
            char_size=self.char_size,
            char_overlap=self.char_overlap,
            token_size=self.token_size,
            token_overlap=self.token_overlap,
        )
        self.best_name, self.best_info = select_best(evals, self.target_avg, self.hard_max)
        return {"best": self.best_name, "metrics": self.best_info["metrics"], "meta": self.best_info["meta"]}

    def transform(self) -> List[str]:
        """Return the chunks produced by the strategy selected in fit_on_text()."""
        if self.best_info is None:
            raise RuntimeError("Call fit_on_text first.")
        return self.best_info["chunks"]

    def fit_transform_pdf(self, pdf_path: str) -> Tuple[str, List[str], Dict]:
        docs = load_pdf(pdf_path)
        text = docs_to_text(docs)
        summary = self.fit_on_text(text)
        return self.best_name, self.transform(), summary

if __name__ == "__main__":
    # Demo on sample.pdf
    ch = OptimalChunker()
    best, chunks, summary = ch.fit_transform_pdf("sample.pdf")
    print("=== Best Strategy ===")
    print(best, summary)
    print(f"First chunk preview:\n{chunks[0][:300] if chunks else ''}")