import argparse
import bisect
import json
import re
import traceback
from typing import Dict, List


class ClaudeTokenizer:
    def __init__(self, config_file: str, algorithm: str = "trie"):
        if algorithm not in ("trie", "linear"):
            raise ValueError("Invalid algorithm. Choose 'trie' or 'linear'.")
        self.algorithm = algorithm

        with open(config_file, "r") as f:
            config = json.load(f)

        # Keep the vocab sorted so the linear algorithm can binary-search it.
        # Token ids are assigned from this sorted order.
        self.vocab = sorted(config["vocab"])
        self.vocab_size = config["n_vocab_size"]
        self.pat_str = config["pat_str"]
        self.special_tokens = config["special_tokens"]

        self.token_to_id = {token: i for i, token in enumerate(self.vocab)}
        self.id_to_token = {i: token for token, i in self.token_to_id.items()}
        for token, token_id in self.special_tokens.items():
            self.token_to_id[token] = token_id
            self.id_to_token[token_id] = token

        self.pat = re.compile(self.pat_str)
        self.vocab_trie = self._build_trie(self.vocab)

    def _build_trie(self, vocab: List[str]) -> Dict:
        # Character trie. The None key marks a complete token, so a literal "*"
        # character inside a token cannot collide with the end-of-token marker.
        trie: Dict = {}
        for token in vocab:
            current = trie
            for char in token:
                current = current.setdefault(char, {})
            current[None] = token
        return trie

    def tokenize(self, text: str) -> List[str]:
        if self.algorithm == "trie":
            tokens = []
            for part in self.pat.findall(text):
                tokens.extend(self._tokenize_part_trie(part))
            return tokens
        # The linear algorithm runs over the raw text without the regex pre-split.
        return self._tokenize_part_linear(text)

    def encode(self, text: str) -> List[int]:
        tokens = self.tokenize(text)
        # Tokens missing from the vocab fall back to the special token
        # registered under the empty-string key.
        return [self.token_to_id.get(token, self.special_tokens[""]) for token in tokens]

    def decode(self, ids: List[int]) -> str:
        return "".join(self.id_to_token.get(token_id, "") for token_id in ids)

    def _tokenize_part_trie(self, text: str) -> List[str]:
        tokens = []
        while text:
            # Walk the trie as far as the text allows, remembering the longest
            # complete token seen along the way.
            current = self.vocab_trie
            longest_match = ""
            for char in text:
                if char not in current:
                    break
                current = current[char]
                if None in current:
                    longest_match = current[None]
            if longest_match:
                tokens.append(longest_match)
                text = text[len(longest_match):]
            else:
                # No vocab entry starts with this character; emit it on its own.
                tokens.append(text[0])
                text = text[1:]
        return tokens

    def _tokenize_part_linear(self, text: str) -> List[str]:
        tokens = []
        while text:
            longest_match = self._binary_search_prefix(text)
            if longest_match:
                tokens.append(longest_match)
                text = text[len(longest_match):]
            else:
                tokens.append(text[0])
                text = text[1:]
        return tokens

    def _binary_search_prefix(self, text: str) -> str:
        # Every vocab entry that is a prefix of `text` sorts at or before the
        # insertion point of `text`, and longer prefixes sort later, so the
        # first prefix found while scanning backwards is the longest one.
        idx = bisect.bisect_right(self.vocab, text)
        for i in range(idx - 1, -1, -1):
            candidate = self.vocab[i]
            if text.startswith(candidate):
                return candidate
            if candidate[:1] != text[:1]:
                # Entries further left cannot share the first character either.
                break
        return ""


def process_file(file_path: str, tokenizer: ClaudeTokenizer) -> List[Dict]:
    # latin-1 accepts any byte sequence, so it acts as the final fallback.
    encodings = ["utf-8", "utf-16", "latin-1"]
    for encoding in encodings:
        try:
            with open(file_path, "r", encoding=encoding) as f:
                text = f.read()
            break
        except UnicodeDecodeError:
            continue
    else:
        raise ValueError(
            f"Unable to decode the file {file_path} with any of the attempted encodings."
        )
    tokens = tokenizer.tokenize(text)
    encoded = tokenizer.encode(text)
    result = [{"token": token, "id": token_id} for token, token_id in zip(tokens, encoded)]
    result.append({"total": len(tokens)})
    return result
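# The config loader in ClaudeTokenizer.__init__ expects a JSON file shaped
# roughly as below. This is a sketch inferred from the keys the code reads
# ("vocab", "n_vocab_size", "pat_str", "special_tokens"); the values shown are
# illustrative placeholders, not the real Claude vocabulary or split pattern.
#
# tokenizer_config.json (illustrative):
# {
#     "vocab": ["a", "ab", "hello", " world"],
#     "n_vocab_size": 6,
#     "pat_str": "\\S+|\\s+",
#     "special_tokens": {"": 4, "<EOT>": 5}
# }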
def main():
    parser = argparse.ArgumentParser(description="Tokenize text using Claude Tokenizer")
    parser.add_argument("--text", type=str, help="Text to tokenize")
    parser.add_argument("--file", type=str, help="File to tokenize")
    parser.add_argument(
        "--algo",
        type=str,
        choices=["linear", "trie"],
        required=True,
        help="Tokenization algorithm",
    )
    args = parser.parse_args()

    if not args.text and not args.file:
        parser.error("Either --text or --file must be specified")

    try:
        tokenizer = ClaudeTokenizer("tokenizer_config.json", algorithm=args.algo)
        if args.file:
            result = process_file(args.file, tokenizer)
            output_file = args.file + ".tokens"
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(result, f, indent=2, ensure_ascii=False)
            print(f"Tokenization results saved to {output_file}")
        else:
            tokens = tokenizer.tokenize(args.text)
            encoded = tokenizer.encode(args.text)
            result = [
                {"token": token, "id": token_id}
                for token, token_id in zip(tokens, encoded)
            ]
            result.append({"total": len(tokens)})
            print(json.dumps(result, indent=2, ensure_ascii=False))
    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()
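# Example invocations (the script file name claude_tokenizer.py is assumed;
# tokenizer_config.json must sit in the working directory, as hard-coded in
# main()):
#
#   python claude_tokenizer.py --text "Hello, world!" --algo trie
#   python claude_tokenizer.py --file notes.txt --algo linear   # writes notes.txt.tokens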