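"""Greedy longest-match tokenizer ("Claude Tokenizer") with two backends.

The vocabulary, pre-tokenization regex, and special tokens are loaded from a
JSON config file; this module reads the keys "vocab", "n_vocab_size",
"pat_str", and "special_tokens". A minimal config has the shape below
(all values are illustrative only):

    {
        "vocab": ["a", "ab", "hello", " world"],
        "n_vocab_size": 65000,
        "pat_str": "[A-Za-z]+|[0-9]+|[^A-Za-z0-9]+",
        "special_tokens": {"<META>": 0}
    }

Text is first split with the regex pattern, then each part is segmented by
repeatedly taking the longest vocabulary entry that prefixes the remaining
text, either via a character trie ("trie") or via a bisect-based search over
the sorted vocabulary ("linear").

Command-line usage (script name illustrative):

    python tokenizer.py --text "Hello world" --algo trie
    python tokenizer.py --file input.txt --algo linear
"""
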
import re
import json
import argparse
from typing import List, Dict
import bisect

class ClaudeTokenizer:
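    """Greedy longest-prefix tokenizer configured from a JSON vocab file.

    algorithm="trie" walks a character trie over the vocabulary;
    algorithm="linear" binary-searches the lexicographically sorted
    vocabulary. Both produce the same longest-match segmentation.
    """
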
    def __init__(self, config_file: str, algorithm: str = "trie"):
        if algorithm not in ("trie", "linear"):
            raise ValueError("Invalid algorithm. Choose 'trie' or 'linear'.")
        self.algorithm = algorithm

        with open(config_file, "r") as f:
            config = json.load(f)

        # Token ids follow the order of the vocab list in the config file.
        self.vocab = config["vocab"]
        self.vocab_size = config["n_vocab_size"]
        self.pat_str = config["pat_str"]
        self.special_tokens = config["special_tokens"]

        self.token_to_id = {token: i for i, token in enumerate(self.vocab)}
        self.id_to_token = {i: token for token, i in self.token_to_id.items()}

        # Special tokens carry explicit ids in the config and are added on top.
        for token, token_id in self.special_tokens.items():
            self.token_to_id[token] = token_id
            self.id_to_token[token_id] = token

        self.pat = re.compile(self.pat_str)
        self.vocab_trie = self._build_trie(self.vocab)
        # Sorted copy of the vocab, used only by the binary-search backend.
        self.sorted_vocab = sorted(self.vocab)

    def _build_trie(self, vocab: List[str]) -> Dict:
        # Character trie; a node's "*" key stores the full token ending there.
        trie = {}
        for token in vocab:
            current = trie
            for char in token:
                current = current.setdefault(char, {})
            current["*"] = token
        return trie

    def tokenize(self, text: str) -> List[str]:
        # Pre-split with the regex pattern, then greedily tokenize each part
        # with the selected backend so both algorithms see the same input.
        tokenize_part = (
            self._tokenize_part_trie
            if self.algorithm == "trie"
            else self._tokenize_part_linear
        )
        tokens = []
        for part in self.pat.findall(text):
            tokens.extend(tokenize_part(part))
        return tokens

    def encode(self, text: str) -> List[int]:
        tokens = self.tokenize(text)
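        # Tokens not present in the vocab fall back to the "<META>" special token id.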
        return [
            self.token_to_id.get(token, self.special_tokens["<META>"])
            for token in tokens
        ]

    def decode(self, ids: List[int]) -> str:
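        # Ids with no vocab entry decode to the empty string.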
        return "".join(self.id_to_token.get(id, "") for id in ids)

    def _tokenize_part_trie(self, text: str) -> List[str]:
        # Greedy longest match: walk the trie as far as the text allows, emit
        # the longest complete token seen, then restart after it.
        tokens = []
        while text:
            current = self.vocab_trie
            longest_match = ""
            for char in text:
                if char not in current:
                    break
                current = current[char]
                if "*" in current:
                    longest_match = current["*"]
            if longest_match:
                tokens.append(longest_match)
                text = text[len(longest_match):]
            else:
                # No vocab entry starts with this character; emit it verbatim.
                tokens.append(text[0])
                text = text[1:]
        return tokens

    def _tokenize_part_linear(self, text: str) -> List[str]:
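        # Same greedy longest-match loop as the trie backend, but the prefix
        # lookup uses a binary search over the sorted vocabulary.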
        tokens = []
        while text:
            longest_match = self._binary_search_prefix(text)
            if longest_match:
                tokens.append(longest_match)
                text = text[len(longest_match):]
            else:
                tokens.append(text[0])
                text = text[1:]
        return tokens

    def _binary_search_prefix(self, text: str) -> str:
        # Longest vocab entry that is a prefix of `text`. Any such entry sorts
        # at or before `text` in the sorted vocab, so inspect the largest
        # entry <= the current key; if it is not a prefix, the answer must be
        # a prefix of their common prefix, so shrink the key and retry.
        key = text
        while key:
            idx = bisect.bisect_right(self.sorted_vocab, key)
            if idx == 0:
                return ""
            candidate = self.sorted_vocab[idx - 1]
            if key.startswith(candidate):
                return candidate
            common_len = 0
            for a, b in zip(key, candidate):
                if a != b:
                    break
                common_len += 1
            key = key[:common_len]
        return ""

def process_file(file_path: str, tokenizer: ClaudeTokenizer) -> List[Dict]:
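    """Tokenize a text file and return per-token records plus a total count."""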
    # latin-1 accepts any byte sequence, so it acts as a lossless last-resort fallback.
    encodings = ['utf-8', 'utf-16', 'latin-1']

    for encoding in encodings:
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                text = f.read()
            break
        except UnicodeDecodeError:
            continue
    else:
        raise ValueError(f"Unable to decode the file {file_path} with any of the attempted encodings.")

    tokens = tokenizer.tokenize(text)
    encoded = tokenizer.encode(text)

    result = [{"token": token, "id": id} for token, id in zip(tokens, encoded)]
    result.append({"total": len(tokens)})

    return result

def main():
    parser = argparse.ArgumentParser(description="Tokenize text using Claude Tokenizer")
    parser.add_argument("--text", type=str, help="Text to tokenize")
    parser.add_argument("--file", type=str, help="File to tokenize")
    parser.add_argument("--algo", type=str, choices=["linear", "trie"], required=True, help="Tokenization algorithm")
    args = parser.parse_args()

    if not args.text and not args.file:
        parser.error("Either --text or --file must be specified")

    try:
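        # The config path is relative, so tokenizer_config.json is expected
        # in the current working directory.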
        tokenizer = ClaudeTokenizer("tokenizer_config.json", algorithm=args.algo)

        if args.file:
            result = process_file(args.file, tokenizer)
            output_file = args.file + ".tokens"
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(result, f, indent=2, ensure_ascii=False)
            print(f"Tokenization results saved to {output_file}")
        else:
            tokens = tokenizer.tokenize(args.text)
            encoded = tokenizer.encode(args.text)
            result = [{"token": token, "id": id} for token, id in zip(tokens, encoded)]
            result.append({"total": len(tokens)})
            print(json.dumps(result, indent=2, ensure_ascii=False))
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()