Husnain
committed on
⚡ [Enhance] Use nous-mixtral-8x7b as default model
Browse files- messagers/token_checker.py +46 -0
messagers/token_checker.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tclogger import logger
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
|
| 4 |
+
from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TokenChecker:
    """Count prompt tokens for a chat model and check them against its limit.

    Unknown model names fall back to the default "nous-mixtral-8x7b".
    """

    # As some models are gated, we need to fetch tokenizers from alternatives.
    # Hoisted to a class constant so the dict is not rebuilt on every __init__.
    # NOTE(review): assumes the mirror repos ship tokenizers compatible with
    # the gated originals — confirm against the upstream model cards.
    GATED_MODEL_MAP = {
        "llama3-70b": "NousResearch/Meta-Llama-3-70B",
        "gemma-7b": "unsloth/gemma-7b",
        "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
        "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
    }

    def __init__(self, input_str: str, model: str):
        """Store the prompt and resolve the model name and its tokenizer.

        Args:
            input_str: Prompt text whose tokens will be counted.
            model: Short model name; must be a key of MODEL_MAP, otherwise
                the default "nous-mixtral-8x7b" is used instead.
        """
        self.input_str = input_str

        # `key in dict` instead of `key in dict.keys()` (same semantics).
        if model in MODEL_MAP:
            self.model = model
        else:
            self.model = "nous-mixtral-8x7b"

        self.model_fullname = MODEL_MAP[self.model]

        if self.model in self.GATED_MODEL_MAP:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.GATED_MODEL_MAP[self.model]
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)

    def count_tokens(self) -> int:
        """Tokenize the stored prompt and return its token count (logged)."""
        token_count = len(self.tokenizer.encode(self.input_str))
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def get_token_limit(self) -> int:
        """Return the context token limit for the resolved model."""
        return TOKEN_LIMIT_MAP[self.model]

    def get_token_redundancy(self) -> int:
        """Return tokens remaining after the prompt and the reserved budget."""
        return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())

    def check_token_limit(self) -> bool:
        """Ensure the prompt fits within the model's usable token budget.

        Returns:
            True when the prompt fits.

        Raises:
            ValueError: If the prompt leaves no token redundancy.
        """
        # Tokenize once: the original re-encoded (and re-logged) the prompt
        # up to three times per check via nested helper calls.
        token_count = self.count_tokens()
        token_limit = self.get_token_limit()
        if token_limit - TOKEN_RESERVED - token_count <= 0:
            raise ValueError(
                f"Prompt exceeded token limit: {token_count} > {token_limit}"
            )
        return True