# claude-3-tokenizer / tokenize.py
# Author: leafspark
# Commit a52e760 (verified): "model: add test files and support command line arguments"
import re
import json
import argparse
from typing import List, Dict
import bisect
class ClaudeTokenizer:
    """Greedy longest-match tokenizer configured from a JSON file.

    The config file must contain:

    * ``vocab`` -- list of token strings; a token's id is its index here.
    * ``n_vocab_size`` -- stored as ``vocab_size`` (not otherwise used here).
    * ``pat_str`` -- regex used to pre-split text before matching.
    * ``special_tokens`` -- mapping of token string to id. :meth:`encode`
      uses the ``"<META>"`` entry as the id for out-of-vocabulary tokens.

    Two matching algorithms are available: ``"trie"`` (walk a character trie)
    and ``"linear"`` (binary search over a sorted copy of the vocabulary).
    Both pre-split with ``pat_str`` and then take the longest vocabulary token
    at each position, so they yield the same tokens.
    """

    # Trie key marking "a vocabulary token ends at this node". It must not be
    # a one-character string: the previous marker "*" collided with the real
    # "*" character, so tokens containing "*" were dropped from the trie and
    # inputs such as "a*a" crashed the walk.
    _END = ""

    def __init__(self, config_file: str, algorithm: str = "trie"):
        """Load the config and build the lookup structures.

        Args:
            config_file: Path to the UTF-8 JSON config described on the class.
            algorithm: ``"trie"`` or ``"linear"``.

        Raises:
            ValueError: If ``algorithm`` is not a supported name.
        """
        # Validate before doing any file I/O or building large tables.
        if algorithm not in ("trie", "linear"):
            raise ValueError("Invalid algorithm. Choose 'trie' or 'linear'.")
        self.algorithm = algorithm
        # Explicit UTF-8: the vocabulary is not ASCII-only, and the platform
        # default encoding (e.g. cp1252 on Windows) would mis-decode it.
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Keep the config order so that a token's id is its index in the list.
        self.vocab: List[str] = config["vocab"]
        # A separate sorted copy backs the binary search. Sorting self.vocab
        # itself (as before) silently renumbered every token id.
        self._sorted_vocab: List[str] = sorted(self.vocab)
        self.vocab_size = config["n_vocab_size"]
        self.pat_str = config["pat_str"]
        self.special_tokens = config["special_tokens"]
        self.token_to_id = {token: i for i, token in enumerate(self.vocab)}
        self.id_to_token = {i: token for token, i in self.token_to_id.items()}
        # Special tokens override any regular entry with the same text or id.
        for token, token_id in self.special_tokens.items():
            self.token_to_id[token] = token_id
            self.id_to_token[token_id] = token
        self.pat = re.compile(self.pat_str)
        self.vocab_trie = self._build_trie(self.vocab)

    def _build_trie(self, vocab: List[str]) -> Dict:
        """Build a nested-dict character trie over ``vocab``.

        Each node maps a single character to its child node; the ``_END`` key
        (never a single character) holds the token that ends at that node.
        """
        trie: Dict = {}
        for token in vocab:
            node = trie
            for char in token:
                node = node.setdefault(char, {})
            node[self._END] = token
        return trie

    def tokenize(self, text: str) -> List[str]:
        """Split ``text`` into vocabulary tokens.

        The text is first pre-split with ``pat_str``. Each piece is then
        matched greedily (longest vocabulary token first) with the configured
        algorithm. A character that starts no vocabulary token is emitted as a
        one-character token. Characters that the regex does not match are
        dropped.
        """
        # Both algorithms now share the same pre-split. Previously "linear"
        # ran on the raw text, so its tokens could span regex boundaries.
        if self.algorithm == "trie":
            split_part = self._tokenize_part_trie
        else:
            split_part = self._tokenize_part_linear
        tokens: List[str] = []
        # finditer + group(0) always yields the whole match. findall would
        # return group tuples if pat_str contained capturing groups.
        for match in self.pat.finditer(text):
            tokens.extend(split_part(match.group(0)))
        return tokens

    def encode(self, text: str) -> List[int]:
        """Tokenize ``text`` and map each token to its id.

        Tokens missing from ``token_to_id`` map to the ``"<META>"`` special
        token id.

        Raises:
            KeyError: If ``text`` yields any token and the config defines no
                ``"<META>"`` special token.
        """
        tokens = self.tokenize(text)
        if not tokens:
            return []
        # Looked up once rather than once per token.
        unknown_id = self.special_tokens["<META>"]
        return [self.token_to_id.get(token, unknown_id) for token in tokens]

    def decode(self, ids: List[int]) -> str:
        """Concatenate the token strings for ``ids``; unknown ids become ""."""
        return "".join(self.id_to_token.get(token_id, "") for token_id in ids)

    def _tokenize_part_trie(self, text: str) -> List[str]:
        """Greedy longest-match tokenization of one piece using the trie."""
        tokens: List[str] = []
        end_key = self._END
        pos = 0
        n = len(text)
        # Walk by index rather than re-slicing ``text`` after every token.
        while pos < n:
            node = self.vocab_trie
            match = ""
            for i in range(pos, n):
                node = node.get(text[i])
                if node is None:
                    break
                # Remember the longest token that ends here, if any.
                match = node.get(end_key, match)
            piece = match or text[pos]  # fall back to a single character
            tokens.append(piece)
            pos += len(piece)
        return tokens

    def _tokenize_part_linear(self, text: str) -> List[str]:
        """Greedy longest-match tokenization of one piece via binary search."""
        tokens: List[str] = []
        pos = 0
        while pos < len(text):
            match = self._binary_search_prefix(text[pos:])
            piece = match or text[pos]  # fall back to a single character
            tokens.append(piece)
            pos += len(piece)
        return tokens

    def _binary_search_prefix(self, text: str) -> str:
        """Return the longest vocabulary token that is a prefix of ``text``.

        Returns "" if no vocabulary token is a prefix of ``text``.

        Any vocabulary prefix ``p`` of ``text`` lies between ``p`` and ``text``
        in sorted order. It is therefore also a prefix of the greatest
        vocabulary entry ``c <= text``. If ``c`` is itself a prefix of
        ``text``, it is the longest one. Otherwise every candidate is a prefix
        of the common prefix of ``c`` and ``text``, so we shrink the query to
        that common prefix and search again. The query strictly shrinks on
        every iteration, so the loop terminates.
        """
        sorted_vocab = self._sorted_vocab
        prefix = text
        while prefix:
            i = bisect.bisect_right(sorted_vocab, prefix)
            if i == 0:
                return ""
            candidate = sorted_vocab[i - 1]
            if prefix.startswith(candidate):
                return candidate
            common = 0
            for a, b in zip(candidate, prefix):
                if a != b:
                    break
                common += 1
            prefix = prefix[:common]
        return ""
def _read_text_file(file_path: str) -> str:
    """Read ``file_path`` as text, guessing its encoding.

    UTF-16 is attempted only when the file starts with a UTF-16 byte-order
    mark. Without a BOM, the utf-16 codec decodes most even-length byte
    strings (e.g. cp1252 text) into garbage instead of failing. ``utf-8-sig``
    decodes plain UTF-8 and strips a UTF-8 BOM, so the BOM is not tokenized
    as content. ``latin-1`` maps every byte to a character and is the final
    fallback.

    Raises:
        ValueError: If no attempted encoding can decode the file.
    """
    with open(file_path, "rb") as f:
        head = f.read(2)
    if head in (b"\xff\xfe", b"\xfe\xff"):  # UTF-16 LE / BE byte-order marks
        encodings = ("utf-16", "latin-1")
    else:
        encodings = ("utf-8-sig", "latin-1")
    for encoding in encodings:
        try:
            with open(file_path, "r", encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
    # Defensive only: latin-1 accepts every byte sequence.
    raise ValueError(f"Unable to decode the file {file_path} with any of the attempted encodings.")


def process_file(file_path: str, tokenizer: "ClaudeTokenizer") -> List[Dict]:
    """Tokenize a text file and return per-token records.

    Args:
        file_path: Path of the file to read (encoding detected as described
            in ``_read_text_file``).
        tokenizer: Tokenizer providing ``tokenize`` and ``encode``.

    Returns:
        A list of ``{"token": str, "id": int}`` dicts, one per token, followed
        by a final ``{"total": <number of tokens>}`` entry.

    Raises:
        ValueError: If the file cannot be decoded.
    """
    text = _read_text_file(file_path)
    tokens = tokenizer.tokenize(text)
    encoded = tokenizer.encode(text)
    result: List[Dict] = [
        {"token": token, "id": token_id}
        for token, token_id in zip(tokens, encoded)
    ]
    result.append({"total": len(tokens)})
    return result
def main(argv=None):
    """Command-line entry point: tokenize ``--text`` or ``--file``.

    With ``--file``, the records are written as JSON to ``<file>.tokens``.
    With ``--text``, they are printed to stdout. If both are given,
    ``--file`` wins. Each record is ``{"token": ..., "id": ...}``, followed by
    a final ``{"total": n}``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 2 on usage errors (via argparse). On any other failure
    it prints the error and traceback to stderr and exits with status 1.
    """
    import sys
    import traceback

    parser = argparse.ArgumentParser(description="Tokenize text using Claude Tokenizer")
    parser.add_argument("--text", type=str, help="Text to tokenize")
    parser.add_argument("--file", type=str, help="File to tokenize")
    parser.add_argument("--algo", type=str, choices=["linear", "trie"], required=True, help="Tokenization algorithm")
    parser.add_argument(
        "--config",
        type=str,
        default="tokenizer_config.json",
        help="Path to the tokenizer config JSON (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    # Compare with None so that an explicitly empty --text "" is accepted.
    if args.text is None and args.file is None:
        parser.error("Either --text or --file must be specified")
    try:
        tokenizer = ClaudeTokenizer(args.config, algorithm=args.algo)
        if args.file is not None:
            result = process_file(args.file, tokenizer)
            output_file = args.file + ".tokens"
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(result, f, indent=2, ensure_ascii=False)
            print(f"Tokenization results saved to {output_file}")
        else:
            tokens = tokenizer.tokenize(args.text)
            encoded = tokenizer.encode(args.text)
            result = [{"token": token, "id": token_id} for token, token_id in zip(tokens, encoded)]
            result.append({"total": len(tokens)})
            print(json.dumps(result, indent=2, ensure_ascii=False))
    except Exception as e:
        # Top-level boundary: report the error, then signal failure to the
        # shell instead of exiting with status 0.
        print(f"An error occurred: {str(e)}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)
# Script entry point: run the CLI and use main()'s return value as the exit
# status (SystemExit(None) exits with status 0).
if __name__ == "__main__":
    raise SystemExit(main())