import tiktoken

from aworld.logs.util import logger

# TODO: merge to `models` package
# Maps a model name to the name of the tiktoken encoding it uses.
# Models missing from this table fall back to `_DEFAULT_ENCODING`.
MODEL_TO_ENCODING = {
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    "text-davinci-003": "p50k_base",
    "text-embedding-ada-002": "cl100k_base",
    "text-curie-001": "r50k_base",
    "text-babbage-001": "r50k_base",
    "text-ada-001": "r50k_base",
}

# Encoding used when a model name is not present in MODEL_TO_ENCODING.
_DEFAULT_ENCODING = "cl100k_base"


def get_encoding_for_model(model_name: str) -> str:
    """Return the tiktoken encoding name to use for ``model_name``.

    The model is looked up in ``MODEL_TO_ENCODING``. Unknown models log a
    warning and fall back to ``"cl100k_base"``.

    Args:
        model_name: Model identifier, e.g. ``"gpt-4"``.

    Returns:
        The encoding *name* (a ``str``), suitable for passing to
        ``tiktoken.get_encoding`` -- not an ``Encoding`` object.
    """
    encoding_name = MODEL_TO_ENCODING.get(model_name)
    if encoding_name is None:
        logger.warning(f"model '{model_name}' not found in mapping table.")
        return _DEFAULT_ENCODING
    return encoding_name


def count_tokens(model_name: str, content: str) -> int:
    """Count the tokens ``content`` occupies under ``model_name``'s encoding.

    Args:
        model_name: Model identifier used to choose the encoding (see
            ``get_encoding_for_model``).
        content: Text to tokenize. Empty or ``None`` content counts as 0.

    Returns:
        The number of tokens produced by encoding ``content``.
    """
    if not content:
        return 0
    encoding = tiktoken.get_encoding(get_encoding_for_model(model_name))
    # By default tiktoken raises ValueError when the text contains special-token
    # strings such as "<|endoftext|>". For counting arbitrary user content we
    # want those encoded as ordinary text instead of crashing.
    tokens = encoding.encode(content, disallowed_special=())
    return len(tokens)