Feat/chatml add system message (#1117)
* add system message to template
* readme update
* added code to register new system message
* register chatml template for test
---------
Co-authored-by: Mads Henrichsen <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
    	
README.md CHANGED

@@ -613,6 +613,8 @@ rl:
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
 # Currently supports chatml and inst (mistral/mixtral)
 chat_template: chatml
+# Changes the default system message
+default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
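
Taken in context, the new key slots into an existing config next to the dataset options. A minimal sketch (the `base_model` and `datasets` entries here are illustrative placeholders, not part of this PR):

    base_model: NousResearch/Llama-2-7b-hf  # placeholder model
    datasets:
      - path: your/sharegpt-dataset         # placeholder dataset
        type: sharegpt
        conversation: chatml
    chat_template: chatml
    default_system_message: You are a helpful assistant. Please give a long and detailed answer.
    dataset_prepared_path: data/last_run_prepared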
    	
src/axolotl/cli/preprocess.py CHANGED

@@ -18,6 +18,7 @@ from axolotl.cli import (
 )
 from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+from axolotl.prompt_strategies.sharegpt import register_chatml_template

 LOG = logging.getLogger("axolotl.cli.preprocess")

@@ -34,6 +35,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         return_remaining_strings=True
     )

+    if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
+        LOG.info(
+            f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
+        )
+        register_chatml_template(parsed_cfg.default_system_message)
+
     if not parsed_cfg.dataset_prepared_path:
         msg = (
             Fore.RED
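
Registering the template before the dataset checks means ChatML conversations are tokenized with the custom system message during preprocessing, not just at train time (train.py below mirrors the same guard). The entry point itself is unchanged, e.g. with your own config file:

    python -m axolotl.cli.preprocess config.yml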
    	
src/axolotl/cli/train.py CHANGED

@@ -18,6 +18,7 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
+from axolotl.prompt_strategies.sharegpt import register_chatml_template
 from axolotl.train import train

 LOG = logging.getLogger("axolotl.cli.train")

@@ -37,7 +38,12 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
+    if cfg.chat_template == "chatml" and cfg.default_system_message:
+        LOG.info(
+            f"ChatML set. Adding default system message: {cfg.default_system_message}"
+        )
+        register_chatml_template(cfg.default_system_message)
+
     if cfg.rl:
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    	
src/axolotl/prompt_strategies/sharegpt.py CHANGED

@@ -6,16 +6,19 @@ from fastchat.conversation import Conversation, SeparatorStyle, register_conv_te
 from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
 from axolotl.prompters import ShareGPTPrompterV2

-register_conv_template(
-    Conversation(
-        name="chatml",
-        system_template="<|im_start|>system\n{system_message}",
-        system_message="You are a helpful assistant.",
-        roles=["<|im_start|>user", "<|im_start|>assistant"],
-        sep_style=SeparatorStyle.CHATML,
-        sep="<|im_end|>",
+
+def register_chatml_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="chatml",
+            system_template="<|im_start|>system\n{system_message}",
+            system_message=system_message,
+            roles=["<|im_start|>user", "<|im_start|>assistant"],
+            sep_style=SeparatorStyle.CHATML,
+            sep="<|im_end|>",
+        )
     )
-)


 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
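
To see what the re-registered template produces, FastChat's standard lookup can be used directly. A minimal sketch (the system message and turns are invented; depending on your FastChat version, re-registering an already-registered name may require an override flag):

    from fastchat.conversation import get_conv_template

    from axolotl.prompt_strategies.sharegpt import register_chatml_template

    register_chatml_template("You are a pirate.")  # register "chatml" with a custom default

    conv = get_conv_template("chatml")  # returns a copy of the registered template
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open
    print(conv.get_prompt())
    # <|im_start|>system
    # You are a pirate.<|im_end|>
    # <|im_start|>user
    # Hello!<|im_end|>
    # <|im_start|>assistant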
    	
src/axolotl/utils/chat_templates.py CHANGED

@@ -20,7 +20,7 @@ def chat_templates(user_choice: str):

     templates = {
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
-        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+        "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     }

     if user_choice in templates:
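
With the added branching, an explicit system turn overrides the default, and conversations without one fall back to "You are a helpful assistant.". The rendering can be checked with plain Jinja2, which is also what transformers uses to apply chat templates. A sketch (the messages are invented):

    from jinja2 import Template

    from axolotl.utils.chat_templates import chat_templates

    tmpl = Template(chat_templates("chatml"))
    print(tmpl.render(
        messages=[
            {"role": "system", "content": "Answer in French."},
            {"role": "user", "content": "Hi"},
        ],
        add_generation_prompt=True,  # appends the open assistant header
    ))
    # <|im_start|>system
    # Answer in French.<|im_end|>
    # <|im_start|>user
    # Hi<|im_end|>
    # <|im_start|>assistant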
    	
src/axolotl/utils/models.py CHANGED

@@ -219,7 +219,13 @@ def load_tokenizer(cfg):
     LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

     if cfg.chat_template:
-        tokenizer.chat_template = chat_templates(cfg.chat_template)
+        chat_template_string = chat_templates(cfg.chat_template)
+        if cfg.default_system_message and cfg.chat_template == "chatml":
+            chat_template_string = chat_template_string.replace(
+                "You are a helpful assistant.", cfg.default_system_message
+            )
+
+        tokenizer.chat_template = chat_template_string
     else:
         LOG.info(
             "No Chat template selected. Consider adding a chat template for easier inference."
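
Note the substitution is a plain string replace on the Jinja source; it works because the stock default "You are a helpful assistant." appears verbatim in the chatml template above. The net effect once the tokenizer is loaded: conversations without an explicit system turn pick up the configured default. A sketch (assumes a cfg with `chat_template: chatml` and the README's example `default_system_message`):

    tokenizer = load_tokenizer(cfg)
    print(tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hi"}],
        tokenize=False,
        add_generation_prompt=True,
    ))
    # <|im_start|>system
    # You are a helpful assistant. Please give a long and detailed answer.<|im_end|>
    # <|im_start|>user
    # Hi<|im_end|>
    # <|im_start|>assistant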
    	
tests/prompt_strategies/test_sharegpt.py CHANGED

@@ -7,9 +7,14 @@ from tokenizers import AddedToken
 from transformers import AutoTokenizer

 from axolotl.datasets import TokenizedPromptDataset
-from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
+from axolotl.prompt_strategies.sharegpt import (
+    SimpleShareGPTPromptTokenizingStrategy,
+    register_chatml_template,
+)
 from axolotl.prompters import ShareGPTPrompterV2

+register_chatml_template()
+

 @pytest.fixture(name="sharegpt_dataset")
 def fixture_sharegpt_dataset():
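
The module-level `register_chatml_template()` call runs at import time, so every test in the file sees the "chatml" conversation template with the stock default system message. The file can be run on its own:

    pytest tests/prompt_strategies/test_sharegpt.py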
