improve: Enhance code readability of prompt_tokenizers.py (#707)
1 changed file: src/axolotl/prompt_tokenizers.py (+80 −107)
@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
         self.prompter = prompter
         self.tokenizer: PreTrainedTokenizer = tokenizer
         self.train_on_inputs = train_on_inputs
+        # sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
+        # TODO: Document how they are different.
         self.sequence_len = sequence_len
         self.max_length = sequence_len
 
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
     ) -> BatchEncoding:
-        result: BatchEncoding
+        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
         if not prompt:
             LOG.warning("Empty text requested for tokenization.")
-            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
-        else:
-            result = self.tokenizer(
-                prompt,
-                truncation=True,
-                max_length=self.max_length,
-                padding=False,
-                return_tensors=None,
-            )
+            return empty
+
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.max_length,
+            padding=False,
+            return_tensors=None,
+        )
         if len(result["input_ids"]) == 0:
             LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
+            return empty
+
         if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][-1] != self.tokenizer.eos_token_id
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
             and len(result["input_ids"]) < self.max_length
             and add_eos_token
         ):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
 
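The rewrite replaces the nested `if`/`else` with guard clauses: both empty-input paths now return an empty `BatchEncoding` immediately, which is what lets the later checks drop their defensive `len(result["input_ids"]) > 0` guards. A runnable sketch of the resulting EOS/BOS contract, using invented token ids (`bos_id=1`, `eos_id=2`, LLaMA-style) and a hypothetical `postprocess` helper rather than anything from the PR:

```python
def postprocess(input_ids, attention_mask, max_length,
                add_eos_token=True, strip_bos_token=False,
                bos_id=1, eos_id=2):
    """Mirror the refactored _tokenize: early-out, append EOS, strip BOS."""
    if not input_ids:  # the guard clauses above return an empty encoding here
        return input_ids, attention_mask
    # Append EOS only when it is missing and truncation left room for it.
    if input_ids[-1] != eos_id and len(input_ids) < max_length and add_eos_token:
        input_ids = input_ids + [eos_id]
        attention_mask = attention_mask + [1]
    # Strip a leading BOS when this turn is concatenated after earlier turns.
    if input_ids[0] == bos_id and strip_bos_token:
        input_ids, attention_mask = input_ids[1:], attention_mask[1:]
    return input_ids, attention_mask

assert postprocess([1, 5, 6], [1, 1, 1], max_length=8) == ([1, 5, 6, 2], [1, 1, 1, 1])
assert postprocess([1, 5, 6], [1, 1, 1], max_length=8, strip_bos_token=True) == (
    [5, 6, 2],
    [1, 1, 1],
)
```

`strip_bos_token=True` is the path the conversation strategies use when a turn is appended after text that already carries the BOS.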
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         if not self.train_on_inputs:
             user_prompt_len = len(tokenized_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
-            tokenized_prompt["labels"] = [-100] * user_prompt_len
+            tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
         tokenized_res_prompt = self._tokenize(
             response, strip_bos_token=True, add_eos_token=True
         )
@@ -270,7 +269,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
             tokenized_full_prompt["labels"] = [
-                -100
+                IGNORE_INDEX
             ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
 
         return tokenized_full_prompt
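This hunk and the one above swap the bare `-100` literal for the named `IGNORE_INDEX` constant. The value is not arbitrary: `-100` is the default `ignore_index` of PyTorch's cross-entropy loss, so masked prompt positions contribute nothing to the loss or gradient. A self-contained illustration (vocab size and label values are invented):

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100

logits = torch.randn(6, 32)  # 6 token positions over a toy vocab of 32
labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 7, 3, 9, 2])

# Positions labeled IGNORE_INDEX are skipped entirely by the loss.
loss = F.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX)
```

Only the four unmasked (response) positions drive the loss, which is exactly the train-on-inputs masking these strategies implement.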
@@ -334,6 +333,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         return prompt["conversations"]
 
     def tokenize_prompt(self, prompt):
+        # Initial values. We will append to these as we go through the conversation.
         result, current_len = tokenize_prompt_default()
         conversation: Conversation = (
             self.prompter._conversation.copy()  # pylint: disable=protected-access
@@ -355,62 +355,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
             for _, part in enumerate(
                 self.prompter.build_prompt(self.get_conversation_thread(prompt))
             ):
-                if isinstance(part, tuple):
-                    ... (55 more deleted lines, old 359-413, collapsed)
+                if not isinstance(part, tuple):
+                    LOG.warning(f"expected tuple, got {part}")
+                    continue
+
+                user, assistant = conversation.roles
+                role, content = part
+
+                # Uses "in" because role contains extra characters
+                if user in role:
+                    role = (
+                        role.replace(role_remap[0]["from"], role_remap[0]["to"])
+                        if role_remap
+                        else role
+                    )
+                    turn = role + content
+                    # this is still the user query, we should
+                    if not content.strip():
+                        LOG.warning(f"user turn has empty text: {prompt}")
+                    res = self._tokenize(
+                        turn,
+                        add_eos_token=False,
+                        strip_bos_token=True,
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                elif assistant in role:
+                    # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
+                    role = (
+                        role.replace(role_remap[1]["from"], role_remap[1]["to"])
+                        if role_remap
+                        else role
+                    )
+                    turn = role + content
+                    # this should be the assistant response, should end with an eos token
+                    if not content.strip():
+                        LOG.warning(f"assistant turn has empty text: {prompt}")
+                    res = self._tokenize(
+                        turn,
+                        add_eos_token=True,
+                        strip_bos_token=True,
+                    )
+                    role_res = self._tokenize(
+                        role.rstrip(),
+                        add_eos_token=False,
+                        strip_bos_token=True,
+                    )
+                    # not masked out from labels
+                    labels = copy.deepcopy(res["input_ids"])
+                    len_role = len(role_res["input_ids"])
+                    labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
+                elif role == "":
+                    turn = content
+                    # this is only ever the first part, should include the bos token and the user query
+                    res = self._tokenize(
+                        turn, add_eos_token=False, strip_bos_token=False
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                else:
+                    LOG.warning(f"unhandled role: {role}")
+                    continue
 
                 # pylint: disable=duplicate-code
                 result, current_len = parse_tokenized_to_result(
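The inlined loop makes the labeling policy explicit: user turns and the bare first part are fully masked with `IGNORE_TOKEN_ID`, while assistant turns keep their token ids as labels except for the role prefix, whose length is measured by tokenizing the role separately. A compressed sketch of that rule (`label_turn` is a hypothetical helper, not code from the PR; token ids are invented):

```python
IGNORE_TOKEN_ID = -100  # same sentinel the loop uses for masked positions

def label_turn(input_ids, role_len, is_assistant):
    if not is_assistant:
        # user turns (and the first, role-less part) are never trained on
        return [IGNORE_TOKEN_ID] * len(input_ids)
    labels = list(input_ids)
    # assistant turns train on the response but mask the role prefix
    labels[:role_len] = [IGNORE_TOKEN_ID] * min(role_len, len(labels))
    return labels

assert label_turn([11, 12, 13], role_len=1, is_assistant=False) == [-100, -100, -100]
assert label_turn([21, 22, 23, 24], role_len=2, is_assistant=True) == [-100, -100, 23, 24]
```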
@@ -424,38 +429,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         except (KeyError, AssertionError, IndexError) as err:
             raise InvalidDataException(str(err)) from err
 
-    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
-        if not prompt.strip():
-            LOG.warning("Empty text requested for tokenization.")
-            result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
-        else:
-            result = self.tokenizer(
-                prompt,
-                truncation=True,
-                max_length=self.sequence_len,
-                padding=False,
-                return_tensors=None,
-            )
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][-1] != self.tokenizer.eos_token_id
-            and len(result["input_ids"]) < self.sequence_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(self.tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        if (
-            len(result["input_ids"]) > 0
-            and result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
-            result["input_ids"] = result["input_ids"][1:]
-            result["attention_mask"] = result["attention_mask"][1:]
-
-        result["labels"] = result["input_ids"].copy()
-        return result
-
 
 def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
     """
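With this override deleted, `ShareGPTPromptTokenizingStrategy` falls back to the base `PromptTokenizingStrategy._tokenize`. The copy differed mainly in truncating to `self.sequence_len` where the base class uses `self.max_length`; since `__init__` assigns both attributes from the same `sequence_len` argument, behavior should be equivalent apart from the base method's new early returns on empty input.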
|