|  | """Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class""" | 
					
						
						|  |  | 
					
						
import copy
import logging
from collections import defaultdict
from typing import Generator, List, Tuple

from axolotl.prompt_tokenizers import (
    PromptTokenizingStrategy,
    parse_tokenized_to_result,
    tokenize_prompt_default,
)
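
# Helpers from axolotl.prompt_tokenizers, as used below:
# tokenize_prompt_default() seeds an empty result (input_ids, attention_mask,
# labels) with a running length of 0; parse_tokenized_to_result() splices one
# tokenized segment, with its labels, onto that running result.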
					
						
LOG = logging.getLogger("axolotl")

# Label value ignored by the loss; matches the default ignore_index of
# torch.nn.CrossEntropyLoss, so masked positions do not contribute to it.
IGNORE_TOKEN_ID = -100


class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Pygmalion.
    """

    bot_prefix_token_ids: List[int] = []

    def __init__(self, prompter, tokenizer, *args, **kwargs):
        super().__init__(prompter, tokenizer, *args, **kwargs)
        # Cache the token ids of the "<|model|>" prefix so the label masking
        # below can skip exactly that many tokens at the start of a bot turn.
        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
        self.bot_prefix_token_ids = res["input_ids"]

    def tokenize_prompt(self, prompt):
        result, current_len = tokenize_prompt_default()
        for role, message in self.prompter.build_prompt(prompt["conversations"]):
            if role == "system":
                prefix = "<|system|>"
                # keep the bos token, add no eos token, and strip a trailing
                # "\n<START>" marker (8 characters) if present
                if message.endswith("\n<START>"):
                    message = message[:-8]
                res = self._tokenize(
                    prefix + "Persona: " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=False,
                )
                # the entire system segment is masked out of the labels
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
						|  | elif role == "human": | 
					
						
						|  | prefix = "<|user|>" | 
					
						
						|  | res = self._tokenize( | 
					
						
						|  | prefix + " " + message.strip(), | 
					
						
						|  | add_eos_token=False, | 
					
						
						|  | strip_bos_token=True, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) | 
					
						
						|  | elif role == "bot": | 
					
						
						|  | prefix = "<|model|>" | 
					
						
						|  | res = self._tokenize( | 
					
						
						|  | prefix + " " + message.strip(), | 
					
						
						|  | add_eos_token=True, | 
					
						
						|  | strip_bos_token=True, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [ | 
					
						
						|  | *copy.deepcopy(res["input_ids"]) | 
					
						
						|  | ][len(self.bot_prefix_token_ids) :] | 
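                # Illustration with hypothetical token ids: if "<|model|>"
                # tokenizes to [32000] and the full turn to [32000, 15043, 2],
                # the labels become [-100, 15043, 2].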
					
						
            else:
                LOG.warning(f"unknown role in conversation: {role}")
                # empty token lists make the parse step below a no-op; labels
                # must be reset as well, otherwise stale labels from an earlier
                # turn would be spliced into the result
                res = defaultdict(lambda: [])
                labels = []

            # splice this segment's tokens and labels onto the running result
            result, current_len = parse_tokenized_to_result(
                result,
                current_len,
                res,
                labels,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        return result
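
    # For a conversation such as
    #   [{"role": "system", "value": "Persona text\n<START>"},
    #    {"role": "human", "value": "Hi"},
    #    {"role": "bot", "value": "Hello!"}]
    # the tokenized text is roughly
    #   <|system|>Persona: Persona text<|user|> Hi<|model|> Hello!</s>
    # (eos shown here as </s>; the actual token depends on the tokenizer),
    # with labels set to IGNORE_TOKEN_ID everywhere except the bot reply
    # and its eos token.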
					
						


class PygmalionPrompter:
    """
    Prompter for Pygmalion.
    """

    def __init__(self, *args, **kwargs):
        pass

    def build_prompt(
        self, source, *args, **kwargs
    ) -> Generator[Tuple[str, str], None, None]:
        # yield (role, message) pairs; tokenize_prompt above expects the
        # roles "system", "human", and "bot"
        for msg in source:
            yield msg["role"], msg["value"]


def load(tokenizer, cfg):
    return PygmalionPromptTokenizingStrategy(
        PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
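

# Usage sketch (illustrative only; everything here other than load() and the
# role/value schema above is a hypothetical stand-in). Assumes a HuggingFace
# tokenizer and a cfg exposing train_on_inputs and sequence_len, matching the
# load() signature:
#
#   strategy = load(tokenizer, cfg)
#   sample = {
#       "conversations": [
#           {"role": "system", "value": "A terse assistant.\n<START>"},
#           {"role": "human", "value": "Hello!"},
#           {"role": "bot", "value": "Hi."},
#       ]
#   }
#   tokenized = strategy.tokenize_prompt(sample)
#   # tokenized holds parallel input_ids, attention_mask, and labels lists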