Add debug option for RL dataset preprocessing (#1404)
* adding debug option for RL dataset preprocessing
* Refine formatting of debugging code in RL dataset preprocessing
* Update __init__.py
* chore: fix lint
---------
Co-authored-by: NanoCode012 <[email protected]>
- src/axolotl/cli/__init__.py +17 -0
- src/axolotl/utils/tokenization.py +58 -3
src/axolotl/cli/__init__.py

@@ -433,6 +433,23 @@ def load_rl_datasets(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
 
+    if cli_args.debug or cfg.debug:
+        LOG.info("check_dataset_labels...")
+
+        tokenizer = load_tokenizer(cfg)
+        check_dataset_labels(
+            train_dataset.select(
+                [
+                    random.randrange(0, len(train_dataset) - 1)  # nosec
+                    for _ in range(cli_args.debug_num_examples)
+                ]
+            ),
+            tokenizer,
+            num_examples=cli_args.debug_num_examples,
+            text_only=cli_args.debug_text_only,
+            rl_mode=True,
+        )
+
     return TrainDatasetMeta(
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
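The new branch fires when either `cli_args.debug` or `cfg.debug` is set. Below is a minimal sketch of the inputs it expects; the `SimpleNamespace` objects are stand-ins for axolotl's real CLI-args and config types, with field names taken from the diff above and all values purely illustrative:

import random
from types import SimpleNamespace

# Hypothetical stand-ins for the cli_args/cfg objects used in the diff.
cli_args = SimpleNamespace(
    debug=True,             # enables the check_dataset_labels branch
    debug_num_examples=3,   # number of random rows to inspect
    debug_text_only=False,  # False also prints each token's id
)
cfg = SimpleNamespace(debug=False)

if cli_args.debug or cfg.debug:
    # Same sampling expression as the diff: indices drawn with replacement,
    # over a hypothetical dataset of 1000 rows.
    dataset_len = 1000
    indices = [
        random.randrange(0, dataset_len - 1)  # nosec
        for _ in range(cli_args.debug_num_examples)
    ]
    print(indices)  # e.g. [412, 7, 907] -- these feed train_dataset.select(...)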
src/axolotl/utils/tokenization.py

@@ -1,6 +1,5 @@
 """Module for tokenization utilities"""
 
-
 import logging
 import re
 from typing import Dict, List
@@ -10,10 +9,19 @@ from termcolor import colored
 LOG = logging.getLogger("axolotl")
 
 
-def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
+def check_dataset_labels(
+    dataset,
+    tokenizer,
+    num_examples=5,
+    text_only=False,
+    rl_mode=False,
+):
     # the dataset is already shuffled, so let's just check the first 5 elements
     for idx in range(num_examples):
-        check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+        if not rl_mode:
+            check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+        else:
+            check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
 
 
 def check_example_labels(example, tokenizer, text_only=False):
@@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
     return " ".join(colored_tokens)
 
 
+def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
+    """Helper function to color tokens based on their type."""
+    colored_text = colored(decoded_token, color)
+    return (
+        colored_text
+        if text_only
+        else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
+    )
+
+
+def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
+    """Helper function to process and color tokens."""
+    colored_tokens = [
+        color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
+        for token in tokenizer.encode(tokens)
+    ]
+    return colored_tokens
+
+
+def check_rl_example_labels(example, tokenizer, text_only=False):
+    field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
+
+    input_tokens = example[field_prompt]
+    labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
+
+    # Process and color each type of token
+    colored_tokens = process_tokens_for_rl_debug(
+        input_tokens, "yellow", tokenizer, text_only
+    )
+    colored_chosens = process_tokens_for_rl_debug(
+        labels_chosen, "green", tokenizer, text_only
+    )
+    colored_rejecteds = process_tokens_for_rl_debug(
+        labels_rejected, "red", tokenizer, text_only
+    )
+
+    # Create a delimiter based on text_only flag
+    delimiter = "" if text_only else " "
+
+    # Logging information
+    LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
+    LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
+    LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+
+    return delimiter.join(colored_tokens)
+
+
 GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
 GLAIVE_TO_SHAREGPT_ROLE = {
     "SYSTEM": "system",
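For reference, a self-contained sketch of what the new `check_rl_example_labels` helper consumes: a DPO-style row with plain-text `prompt`/`chosen`/`rejected` fields, which the helper tokenizes itself. The `gpt2` tokenizer and the example strings are assumptions for illustration only; any Hugging Face tokenizer would do:

import logging

from transformers import AutoTokenizer

from axolotl.utils.tokenization import check_rl_example_labels

logging.basicConfig(level=logging.INFO)  # make the LOG.info output visible

# Arbitrary pretrained tokenizer, chosen only for this demo.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# A DPO-style preference row, matching the field names used in the helper.
example = {
    "prompt": "What is the capital of France?",
    "chosen": " The capital of France is Paris.",
    "rejected": " France has no capital city.",
}

# Logs the prompt in yellow, the chosen response in green, and the rejected
# response in red; with text_only=False, each decoded token is followed by
# its token id in parentheses.
check_rl_example_labels(example, tokenizer, text_only=False)

The coloring comes from termcolor, so the output reads best in a terminal that supports ANSI colors.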