Commit: a1c704a
Parent(s): c3931e9
Update with h2oGPT hash 05d3ad444971c24fb021ea80c27f867c7a953699

Files changed:
- client_test.py +4 -2
- finetune.py +60 -10
- generate.py +98 -83
- gradio_runner.py +26 -10
- prompter.py +6 -5
- requirements.txt +1 -1
- stopping.py +49 -6
    	
client_test.py
CHANGED

@@ -53,13 +53,16 @@ def get_client():


 def test_client_basic():
+    return run_client_basic(instruction_nochat='Who are you?', prompt_type='human_bot')
+
+
+def run_client_basic(instruction_nochat, prompt_type):
     instruction = ''  # only for chat=True
     iinput = ''  # only for chat=True
     context = ''
     # streaming output is supported, loops over and outputs each generation in streaming mode
     # but leave stream_output=False for simple input/output mode
     stream_output = False
-    prompt_type = 'human_bot'
     temperature = 0.1
     top_p = 0.75
     top_k = 40
@@ -73,7 +76,6 @@ def test_client_basic():
     do_sample = True
     # only these 2 below used if pass chat=False
     chat = False
-    instruction_nochat = "Who are you?"
     iinput_nochat = ''

     args = [instruction,
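Side note (not part of the commit): the test above is now a thin wrapper, so other questions or prompt types can be exercised without duplicating the setup. A minimal sketch, assuming client_test.py is importable and the server it targets is running:

    # hypothetical usage of the run_client_basic helper introduced above
    from client_test import run_client_basic

    # same call the refactored test_client_basic makes
    run_client_basic(instruction_nochat='Who are you?', prompt_type='human_bot')
    # reuse the helper with a different question
    run_client_basic(instruction_nochat='What is Driverless AI?', prompt_type='human_bot')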
    	
finetune.py
CHANGED

@@ -28,6 +28,8 @@ class PromptType(Enum):
     instruct_vicuna = 7
     instruct_with_end = 8
     human_bot_orig = 9
+    prompt_answer = 10
+    open_assistant = 11


 prompt_type_to_model_name = {
@@ -46,6 +48,14 @@ prompt_type_to_model_name = {
         'philschmid/flan-t5-base-samsum',
         'gpt2',
         'distilgpt2',
+        'mosaicml/mpt-7b-storywriter',
+        'mosaicml/mpt-7b-instruct',  # internal code handles instruct
+        'mosaicml/mpt-7b-chat',  # NC, internal code handles instruct
+    ],
+    'prompt_answer': [
+        'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
+        'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
+        'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
     ],
     'instruct': [],
     'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -61,14 +71,12 @@ prompt_type_to_model_name = {
     'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
     'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
     'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
+    "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
 }

 inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
 inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}

-human = '<human>:'
-bot = "<bot>:"
-
 prompt_types_strings = []
 for p in PromptType:
     prompt_types_strings.extend([p.name])
@@ -277,8 +285,13 @@ def train(
                 layer_norm_names=["layer_norm", "layernorm"],  # keep all layer norms in higher precision
             )

-    from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
-
+    from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
+    try:
+        from peft import utils
+        lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
+    except AttributeError:
+        from peft import mapping
+        lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
     lora_mappings['distilgpt2'] = ["c_attn"]

     if lora_weights:
@@ -730,10 +743,10 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
     assert prompt_type is not None
     assert cutoff_len is not None
     assert tokenizer is not None
-    full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
+    full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, False, False)
     tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
     if not train_on_inputs:
-        user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
+        user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
         tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
         user_prompt_len = len(tokenized_user_prompt["input_ids"])
         if add_eos_token:
@@ -752,9 +765,11 @@ def get_prompt(prompt_type, chat, context, reduced):
     if prompt_type in [-1, "-1", "plain"]:
         promptA = promptB = PreInstruct = PreInput = PreResponse = ''
         terminate_response = []
+        chat_sep = ''
     elif prompt_type == 'simple_instruct':
         promptA = promptB = PreInstruct = PreInput = PreResponse = None
         terminate_response = []
+        chat_sep = '\n'
     elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
         promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
         promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
@@ -774,6 +789,7 @@ def get_prompt(prompt_type, chat, context, reduced):
             terminate_response = ['### End']
         else:
             terminate_response = None
+        chat_sep = '\n'
     elif prompt_type in [1, "1", "quality"]:
         promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
         promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
@@ -790,7 +806,10 @@ def get_prompt(prompt_type, chat, context, reduced):
 ### Response:
 """
         terminate_response = None
+        chat_sep = '\n'
     elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
+        human = '<human>:'
+        bot = "<bot>:"
         if reduced or context or prompt_type in [2, "2", "human_bot"]:
             preprompt = ''
         else:
@@ -819,6 +838,7 @@ Current Time: {}
             PreResponse = bot

         terminate_response = [start, PreResponse]
+        chat_sep = '\n'
     elif prompt_type in [3, "3", "dai_faq"]:
         promptA = ''
         promptB = 'Answer the following Driverless AI question.\n'
@@ -833,11 +853,13 @@ Current Time: {}
 ### Driverless AI documentation answer:
 """
         terminate_response = ['\n\n']
+        chat_sep = terminate_response
     elif prompt_type in [5, "5", "summarize"]:
         promptA = promptB = PreInput = ''
         PreInstruct = '## Main Text\n\n'
         PreResponse = '\n\n## Summary\n\n'
         terminate_response = None
+        chat_sep = '\n'
     elif prompt_type in [6, "6", "instruct_vicuna"]:
         promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
             "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
@@ -852,10 +874,37 @@ Current Time: {}
 ### Assistant:
 """
         terminate_response = ['### Human:']  # but only allow terminate after prompt is found correctly, else can't terminate
+        chat_sep = '\n'
+    elif prompt_type in [10, "10", "prompt_answer"]:
+        preprompt = ''
+        prompt_tokens = "<|prompt|>"
+        answer_tokens = "<|answer|>"
+        start = prompt_tokens
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = ""
+        PreInput = None
+        PreResponse = answer_tokens
+        eos = '<|endoftext|>'  # neox eos
+        terminate_response = [start, PreResponse, eos]
+        chat_sep = eos
+    elif prompt_type in [11, "11", "open_assistant"]:
+        # From added_tokens.json
+        preprompt = ''
+        prompt_tokens = "<|prompter|>"
+        answer_tokens = "<|assistant|>"
+        start = prompt_tokens
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = ""
+        PreInput = None
+        PreResponse = answer_tokens
+        pend = "<|prefix_end|>"
+        eos = "</s>"
+        terminate_response = [start, PreResponse, pend, eos]
+        chat_sep = eos
     else:
         raise RuntimeError("No such prompt_type=%s" % prompt_type)

-    return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
+    return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep


 def generate_prompt(data_point, prompt_type, chat, reduced):
@@ -867,7 +916,8 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
     output = data_point.get('output')
     prompt_type = data_point.get('prompt_type', prompt_type)
     assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
-    promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
+    promptA, promptB, PreInstruct, PreInput, PreResponse, \
+        terminate_response, chat_sep = get_prompt(prompt_type, chat, context, reduced)

     prompt = context if not reduced else ''

@@ -919,7 +969,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
     if output:
         prompt += f"""{output}"""

-    return prompt, pre_response, terminate_response
+    return prompt, pre_response, terminate_response, chat_sep


 def inject_newline(prompt_type, prompt):
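As a side note (illustrative only, not code from this commit): with the values chosen in the new get_prompt branches above, the two new prompt types simply wrap the instruction in their special tokens, and chat turns are separated by the model's EOS string rather than a newline:

    # sketch based on the token strings set in get_prompt above
    def render(prompt_token, answer_token, instruction):
        # promptA/promptB reduce to the prompt token; PreResponse is the answer token
        return f"{prompt_token}{instruction}{answer_token}"

    print(render("<|prompt|>", "<|answer|>", "Who are you?"))
    # '<|prompt|>Who are you?<|answer|>'        (prompt_answer; chat_sep = '<|endoftext|>')
    print(render("<|prompter|>", "<|assistant|>", "Who are you?"))
    # '<|prompter|>Who are you?<|assistant|>'   (open_assistant; chat_sep = '</s>')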
    	
        generate.py
    CHANGED
    
    | 
         @@ -9,7 +9,7 @@ from datetime import datetime 
     | 
|
| 9 | 
         
             
            import filelock
         
     | 
| 10 | 
         
             
            import psutil
         
     | 
| 11 | 
         | 
| 12 | 
         
            -
            from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread
         
     | 
| 13 | 
         | 
| 14 | 
         
             
            SEED = 1236
         
     | 
| 15 | 
         
             
            set_seed(SEED)
         
     | 
| 
         @@ -22,13 +22,13 @@ import pandas as pd 
     | 
|
| 22 | 
         
             
            import fire
         
     | 
| 23 | 
         
             
            import torch
         
     | 
| 24 | 
         
             
            from peft import PeftModel
         
     | 
| 25 | 
         
            -
            from transformers import GenerationConfig,  
     | 
| 26 | 
         
             
            from accelerate import init_empty_weights, infer_auto_device_map
         
     | 
| 27 | 
         | 
| 28 | 
         
             
            from prompter import Prompter
         
     | 
| 29 | 
         | 
| 30 | 
         
            -
            from finetune import get_loaders, example_data_points, generate_prompt,  
     | 
| 31 | 
         
            -
            from stopping import  
     | 
| 32 | 
         | 
| 33 | 
         
             
            eval_extra_columns = ['prompt', 'response', 'score']
         
     | 
| 34 | 
         | 
| 
         @@ -62,6 +62,7 @@ def main( 
     | 
|
| 62 | 
         
             
                    local_files_only: bool = False,
         
     | 
| 63 | 
         
             
                    resume_download: bool = True,
         
     | 
| 64 | 
         
             
                    use_auth_token: Union[str, bool] = False,
         
     | 
| 
         | 
|
| 65 | 
         | 
| 66 | 
         
             
                    src_lang: str = "English",
         
     | 
| 67 | 
         
             
                    tgt_lang: str = "Russian",
         
     | 
| 
         @@ -124,6 +125,7 @@ def main( 
     | 
|
| 124 | 
         
             
                :param local_files_only: whether to only use local files instead of doing to HF for models
         
     | 
| 125 | 
         
             
                :param resume_download: whether to resume downloads from HF for models
         
     | 
| 126 | 
         
             
                :param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
         
     | 
| 
         | 
|
| 127 | 
         
             
                :param src_lang: source languages to include if doing translation (None = all)
         
     | 
| 128 | 
         
             
                :param tgt_lang: target languages to include if doing translation (None = all)
         
     | 
| 129 | 
         
             
                :param gradio: whether to enable gradio, or to enable benchmark mode
         
     | 
| 
         @@ -168,15 +170,22 @@ def main( 
     | 
|
| 168 | 
         | 
| 169 | 
         
             
                if is_public:
         
     | 
| 170 | 
         
             
                    input_lines = 1  # ensure set, for ease of use
         
     | 
| 171 | 
         
            -
                    temperature = 0.2
         
     | 
| 172 | 
         
            -
                    top_p = 0.85
         
     | 
| 173 | 
         
            -
                    top_k = 70
         
     | 
| 174 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 175 | 
         
             
                    if is_low_mem:
         
     | 
| 176 | 
         
            -
                         
     | 
| 177 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 178 | 
         
             
                    else:
         
     | 
| 179 | 
         
            -
                        base_model = 'h2oai/h2ogpt-oasst1-512-20b'
         
     | 
| 180 | 
         
             
                if is_low_mem:
         
     | 
| 181 | 
         
             
                    load_8bit = True
         
     | 
| 182 | 
         
             
                if is_hf:
         
     | 
| 
         @@ -229,6 +238,11 @@ def main( 
     | 
|
| 229 | 
         
             
                                        do_sample,
         
     | 
| 230 | 
         
             
                                        )
         
     | 
| 231 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 232 | 
         
             
                if not gradio:
         
     | 
| 233 | 
         
             
                    if eval_sharegpt_prompts_only > 0:
         
     | 
| 234 | 
         
             
                        # override default examples with shareGPT ones for human-level eval purposes only
         
     | 
| 
         @@ -416,7 +430,11 @@ def get_device(): 
     | 
|
| 416 | 
         | 
| 417 | 
         
             
            def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
         
     | 
| 418 | 
         
             
                                   gpu_id=0,
         
     | 
| 419 | 
         
            -
                                   use_auth_token=False 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 420 | 
         
             
                """
         
     | 
| 421 | 
         
             
                Ensure model gets on correct device
         
     | 
| 422 | 
         
             
                :param base_model:
         
     | 
| 
         @@ -426,29 +444,47 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward 
     | 
|
| 426 | 
         
             
                :param reward_type:
         
     | 
| 427 | 
         
             
                :param gpu_id:
         
     | 
| 428 | 
         
             
                :param use_auth_token:
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 429 | 
         
             
                :return:
         
     | 
| 430 | 
         
             
                """
         
     | 
| 431 | 
         
             
                with init_empty_weights():
         
     | 
| 432 | 
         
             
                    from transformers import AutoConfig
         
     | 
| 433 | 
         
            -
                    config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token 
     | 
| 434 | 
         
            -
             
     | 
| 435 | 
         
            -
             
     | 
| 436 | 
         
            -
             
     | 
| 437 | 
         
            -
             
     | 
| 438 | 
         
            -
             
     | 
| 439 | 
         
            -
             
     | 
| 440 | 
         
            -
             
     | 
| 441 | 
         
            -
             
     | 
| 442 | 
         
            -
                     
     | 
| 443 | 
         
            -
             
     | 
| 444 | 
         
            -
             
     | 
| 445 | 
         
            -
             
     | 
| 446 | 
         
            -
                     
     | 
| 447 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 448 | 
         
             
                        dtype=torch.float16 if load_half else torch.float32,
         
     | 
| 449 | 
         
             
                    )
         
     | 
| 450 | 
         
            -
                     
     | 
| 451 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 452 | 
         | 
| 453 | 
         
             
                n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
         
     | 
| 454 | 
         | 
| 
         @@ -472,11 +508,13 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward 
     | 
|
| 472 | 
         
             
                if load_in_8bit or not load_half:
         
     | 
| 473 | 
         
             
                    model = model_loader.from_pretrained(
         
     | 
| 474 | 
         
             
                        base_model,
         
     | 
| 
         | 
|
| 475 | 
         
             
                        **model_kwargs,
         
     | 
| 476 | 
         
             
                    )
         
     | 
| 477 | 
         
             
                else:
         
     | 
| 478 | 
         
             
                    model = model_loader.from_pretrained(
         
     | 
| 479 | 
         
             
                        base_model,
         
     | 
| 
         | 
|
| 480 | 
         
             
                        **model_kwargs,
         
     | 
| 481 | 
         
             
                    ).half()
         
     | 
| 482 | 
         
             
                return model
         
     | 
| 
         @@ -495,6 +533,7 @@ def get_model( 
     | 
|
| 495 | 
         
             
                    local_files_only: bool = False,
         
     | 
| 496 | 
         
             
                    resume_download: bool = True,
         
     | 
| 497 | 
         
             
                    use_auth_token: Union[str, bool] = False,
         
     | 
| 
         | 
|
| 498 | 
         
             
                    compile: bool = True,
         
     | 
| 499 | 
         
             
                    **kwargs,
         
     | 
| 500 | 
         
             
            ):
         
     | 
| 
         @@ -513,6 +552,7 @@ def get_model( 
     | 
|
| 513 | 
         
             
                :param local_files_only: use local files instead of from HF
         
     | 
| 514 | 
         
             
                :param resume_download: resume downloads from HF
         
     | 
| 515 | 
         
             
                :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
         
     | 
| 
         | 
|
| 516 | 
         
             
                :param compile: whether to compile torch model
         
     | 
| 517 | 
         
             
                :param kwargs:
         
     | 
| 518 | 
         
             
                :return:
         
     | 
| 
         @@ -531,7 +571,8 @@ def get_model( 
     | 
|
| 531 | 
         
             
                )
         
     | 
| 532 | 
         | 
| 533 | 
         
             
                from transformers import AutoConfig
         
     | 
| 534 | 
         
            -
                config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token 
     | 
| 
         | 
|
| 535 | 
         
             
                llama_type_from_config = 'llama' in str(config).lower()
         
     | 
| 536 | 
         
             
                llama_type_from_name = "llama" in base_model.lower()
         
     | 
| 537 | 
         
             
                llama_type = llama_type_from_config or llama_type_from_name
         
     | 
| 
         @@ -548,6 +589,7 @@ def get_model( 
     | 
|
| 548 | 
         
             
                                                                 local_files_only=local_files_only,
         
     | 
| 549 | 
         
             
                                                                 resume_download=resume_download,
         
     | 
| 550 | 
         
             
                                                                 use_auth_token=use_auth_token,
         
     | 
| 
         | 
|
| 551 | 
         
             
                                                                 )
         
     | 
| 552 | 
         
             
                else:
         
     | 
| 553 | 
         
             
                    tokenizer = tokenizer_loader
         
     | 
| 
         @@ -563,13 +605,18 @@ def get_model( 
     | 
|
| 563 | 
         
             
                    model_kwargs = dict(local_files_only=local_files_only,
         
     | 
| 564 | 
         
             
                                        torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
         
     | 
| 565 | 
         
             
                                        resume_download=resume_download,
         
     | 
| 566 | 
         
            -
                                        use_auth_token=use_auth_token 
     | 
| 567 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 568 | 
         
             
                        model_kwargs.update(dict(load_in_8bit=load_8bit,
         
     | 
| 569 | 
         
             
                                                 device_map={"": 0} if load_8bit and device == 'cuda' else "auto",
         
     | 
| 570 | 
         
             
                                                 ))
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 571 | 
         
             
                    if 'OpenAssistant/reward-model'.lower() in base_model.lower():
         
     | 
| 572 | 
         
            -
                        # could put on other GPUs
         
     | 
| 573 | 
         
             
                        model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
         
     | 
| 574 | 
         
             
                        model_kwargs.pop('torch_dtype', None)
         
     | 
| 575 | 
         | 
| 
         @@ -577,7 +624,10 @@ def get_model( 
     | 
|
| 577 | 
         
             
                        with torch.device(device):
         
     | 
| 578 | 
         
             
                            if infer_devices:
         
     | 
| 579 | 
         
             
                                model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
         
     | 
| 580 | 
         
            -
                                                           gpu_id=gpu_id, 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 581 | 
         
             
                            else:
         
     | 
| 582 | 
         
             
                                if load_half and not load_8bit:
         
     | 
| 583 | 
         
             
                                    model = model_loader.from_pretrained(
         
     | 
| 
         @@ -599,6 +649,7 @@ def get_model( 
     | 
|
| 599 | 
         
             
                            local_files_only=local_files_only,
         
     | 
| 600 | 
         
             
                            resume_download=resume_download,
         
     | 
| 601 | 
         
             
                            use_auth_token=use_auth_token,
         
     | 
| 
         | 
|
| 602 | 
         
             
                            device_map={"": 0} if device == 'cuda' else {"": 'cpu'},  # seems to be required
         
     | 
| 603 | 
         
             
                        )
         
     | 
| 604 | 
         
             
                    else:
         
     | 
| 
         @@ -614,6 +665,7 @@ def get_model( 
     | 
|
| 614 | 
         
             
                                local_files_only=local_files_only,
         
     | 
| 615 | 
         
             
                                resume_download=resume_download,
         
     | 
| 616 | 
         
             
                                use_auth_token=use_auth_token,
         
     | 
| 
         | 
|
| 617 | 
         
             
                                device_map="auto",
         
     | 
| 618 | 
         
             
                            )
         
     | 
| 619 | 
         
             
                            if load_half:
         
     | 
| 
         @@ -782,49 +834,7 @@ def evaluate( 
     | 
|
| 782 | 
         
             
                if chat:
         
     | 
| 783 | 
         
             
                    # override, ignore user change
         
     | 
| 784 | 
         
             
                    num_return_sequences = 1
         
     | 
| 785 | 
         
            -
                 
     | 
| 786 | 
         
            -
                    if prompt_type == 'human_bot':
         
     | 
| 787 | 
         
            -
                        # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
         
     | 
| 788 | 
         
            -
                        # stopping only starts once output is beyond prompt
         
     | 
| 789 | 
         
            -
                        # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
         
     | 
| 790 | 
         
            -
                        stop_words = [human, bot, '\n' + human, '\n' + bot]
         
     | 
| 791 | 
         
            -
                        encounters = [1, 2]
         
     | 
| 792 | 
         
            -
                    elif prompt_type == 'instruct_vicuna':
         
     | 
| 793 | 
         
            -
                        # even below is not enough, generic strings and many ways to encode
         
     | 
| 794 | 
         
            -
                        stop_words = [
         
     | 
| 795 | 
         
            -
                            '### Human:',
         
     | 
| 796 | 
         
            -
                            """
         
     | 
| 797 | 
         
            -
            ### Human:""",
         
     | 
| 798 | 
         
            -
                            """
         
     | 
| 799 | 
         
            -
            ### Human:
         
     | 
| 800 | 
         
            -
            """,
         
     | 
| 801 | 
         
            -
                            '### Assistant:',
         
     | 
| 802 | 
         
            -
                            """
         
     | 
| 803 | 
         
            -
            ### Assistant:""",
         
     | 
| 804 | 
         
            -
                            """
         
     | 
| 805 | 
         
            -
            ### Assistant:
         
     | 
| 806 | 
         
            -
            """,
         
     | 
| 807 | 
         
            -
                        ]
         
     | 
| 808 | 
         
            -
                        encounters = [1, 2]
         
     | 
| 809 | 
         
            -
                    else:
         
     | 
| 810 | 
         
            -
                        # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
         
     | 
| 811 | 
         
            -
                        stop_words = ['### End']
         
     | 
| 812 | 
         
            -
                        encounters = [1]
         
     | 
| 813 | 
         
            -
                    stop_words_ids = [
         
     | 
| 814 | 
         
            -
| 814 | -                     tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
| 815 | -                 # handle single token case
| 816 | -                 stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
| 817 | -                 stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
| 818 | -                 # avoid padding in front of tokens
| 819 | -                 if tokenizer.pad_token:
| 820 | -                     stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
| 821 | -                 # handle fake \n added
| 822 | -                 stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
| 823 | -                 # build stopper
| 824 | -                 stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
| 825 | -             else:
| 826 | -                 stopping_criteria = StoppingCriteriaList()
| 827 | -
| 828 |               # help to avoid errors like:
| 829 |               # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
| 830 |               # RuntimeError: expected scalar type Half but found Float

@@ -903,7 +913,10 @@ def evaluate(
| 903 |                   prompt = inputs_decoded
| 904 |               elif inputs_decoded_raw == prompt:
| 905 |                   # some models specify special tokens that are part of normal prompt, so can't skip them
| 906 | -
| 907 |                   decoder = decoder_raw
| 908 |               else:
| 909 |                   print("WARNING: Special characters in prompt", flush=True)

@@ -1046,6 +1059,7 @@ def get_generate_params(model_lower, chat,
| 1046 |
| 1047 |     if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
| 1048 |         prompt_type = inv_prompt_type_to_model_lower[model_lower]
| 1049 |
| 1050 |     # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
| 1051 |     if show_examples is None:

@@ -1104,7 +1118,8 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
| 1104 |             placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
| 1105 |         placeholder_input = ""
| 1106 |         if model_lower:
| 1107 | -
| 1108 |         else:
| 1109 |             prompt_type = ''
| 1110 |         examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",

@@ -1133,9 +1148,9 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
| 1133 |         num_return_sequences = min(num_beams, num_return_sequences or 1)
| 1134 |         do_sample = False if do_sample is None else do_sample
| 1135 |     else:
| 1136 | -       temperature = 0.
| 1137 | -       top_p = 0.
| 1138 | -       top_k =
| 1139 |         if chat:
| 1140 |             num_beams = num_beams or 1
| 1141 |         else:

@@ -1143,7 +1158,7 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
| 1143 |         max_new_tokens = max_new_tokens or 256
| 1144 |         repetition_penalty = repetition_penalty or 1.07
| 1145 |         num_return_sequences = min(num_beams, num_return_sequences or 1)
| 1146 | -       do_sample =
| 1147 |     # doesn't include chat, instruction_nochat, iinput_nochat, added later
| 1148 |     params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
| 1149 |                    early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
| 9 |   import filelock
| 10 |  import psutil
| 11 |
| 12 | + from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash
| 13 |
| 14 |  SEED = 1236
| 15 |  set_seed(SEED)

| 22 |  import fire
| 23 |  import torch
| 24 |  from peft import PeftModel
| 25 | + from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
| 26 |  from accelerate import init_empty_weights, infer_auto_device_map
| 27 |
| 28 |  from prompter import Prompter
| 29 |
| 30 | + from finetune import get_loaders, example_data_points, generate_prompt, inv_prompt_type_to_model_lower
| 31 | + from stopping import get_stopping
| 32 |
| 33 |  eval_extra_columns = ['prompt', 'response', 'score']
| 34 |

| 62 |          local_files_only: bool = False,
| 63 |          resume_download: bool = True,
| 64 |          use_auth_token: Union[str, bool] = False,
| 65 | +        trust_remote_code: Union[str, bool] = True,
| 66 |
| 67 |          src_lang: str = "English",
| 68 |          tgt_lang: str = "Russian",

| 125 |     :param local_files_only: whether to only use local files instead of doing to HF for models
| 126 |     :param resume_download: whether to resume downloads from HF for models
| 127 |     :param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
| 128 | +   :param trust_remote_code: whether to use trust any code needed for HF model
| 129 |     :param src_lang: source languages to include if doing translation (None = all)
| 130 |     :param tgt_lang: target languages to include if doing translation (None = all)
| 131 |     :param gradio: whether to enable gradio, or to enable benchmark mode

| 170 |
| 171 |     if is_public:
| 172 |         input_lines = 1  # ensure set, for ease of use
| 173 | +       temperature = 0.2 if temperature is None else temperature
| 174 | +       top_p = 0.85 if top_p is None else top_p
| 175 | +       top_k = 70 if top_k is None else top_k
| 176 | +       if is_hf:
| 177 | +           do_sample = True if do_sample is None else do_sample
| 178 | +       else:
| 179 | +           # by default don't sample, too chatty
| 180 | +           do_sample = False if do_sample is None else do_sample
| 181 | +
| 182 |         if is_low_mem:
| 183 | +           if not base_model:
| 184 | +               base_model = 'h2oai/h2ogpt-oasst1-512-12b'
| 185 | +               # don't set load_8bit if passed base_model, doesn't always work so can't just override
| 186 | +               load_8bit = True
| 187 |         else:
| 188 | +           base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
| 189 |     if is_low_mem:
| 190 |         load_8bit = True
| 191 |     if is_hf:

| 238 |                         do_sample,
| 239 |                         )
| 240 |
| 241 | +   locals_dict = locals()
| 242 | +   locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
| 243 | +   print(f"Generating model with params:\n{locals_print}", flush=True)
| 244 | +   print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
| 245 | +
| 246 |     if not gradio:
| 247 |         if eval_sharegpt_prompts_only > 0:
| 248 |             # override default examples with shareGPT ones for human-level eval purposes only

| 430 |
| 431 | def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
| 432 |                        gpu_id=0,
| 433 | +                      use_auth_token=False,
| 434 | +                      trust_remote_code=True,
| 435 | +                      triton_attn=False,
| 436 | +                      long_sequence=True,
| 437 | +                      ):
| 438 |     """
| 439 |     Ensure model gets on correct device
| 440 |     :param base_model:

| 444 |     :param reward_type:
| 445 |     :param gpu_id:
| 446 |     :param use_auth_token:
| 447 | +   :param trust_remote_code:
| 448 | +   :param triton_attn:
| 449 | +   :param long_sequence:
| 450 |     :return:
| 451 |     """
| 452 |     with init_empty_weights():
| 453 |         from transformers import AutoConfig
| 454 | +       config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
| 455 | +                                           trust_remote_code=trust_remote_code)
| 456 | +       if triton_attn and 'mpt-' in base_model.lower():
| 457 | +           config.attn_config['attn_impl'] = 'triton'
| 458 | +       if long_sequence:
| 459 | +           if 'mpt-7b-storywriter' in base_model.lower():
| 460 | +               config.update({"max_seq_len": 83968})
| 461 | +           if 'mosaicml/mpt-7b-chat' in base_model.lower():
| 462 | +               config.update({"max_seq_len": 4096})
| 463 | +       if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
| 464 | +           model = AutoModel.from_config(
| 465 | +               config,
| 466 | +           )
| 467 | +       else:
| 468 | +           # can't infer
| 469 | +           model = None
| 470 | +
| 471 | +   if model is not None:
| 472 | +       # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
| 473 | +       # NOTE: Some models require avoiding sharding some layers,
| 474 | +       # then would pass no_split_module_classes and give list of those layers.
| 475 | +       device_map = infer_auto_device_map(
| 476 | +           model,
| 477 |             dtype=torch.float16 if load_half else torch.float32,
| 478 |         )
| 479 | +       if hasattr(model, 'model'):
| 480 | +           device_map_model = infer_auto_device_map(
| 481 | +               model.model,
| 482 | +               dtype=torch.float16 if load_half else torch.float32,
| 483 | +           )
| 484 | +           device_map.update(device_map_model)
| 485 | +       print('device_map: %s' % device_map, flush=True)
| 486 | +   else:
| 487 | +       device_map = "auto"
| 488 |
| 489 |     n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
| 490 |

| 508 |     if load_in_8bit or not load_half:
| 509 |         model = model_loader.from_pretrained(
| 510 |             base_model,
| 511 | +           config=config,
| 512 |             **model_kwargs,
| 513 |         )
| 514 |     else:
| 515 |         model = model_loader.from_pretrained(
| 516 |             base_model,
| 517 | +           config=config,
| 518 |             **model_kwargs,
| 519 |         ).half()
| 520 |     return model

| 533 |         local_files_only: bool = False,
| 534 |         resume_download: bool = True,
| 535 |         use_auth_token: Union[str, bool] = False,
| 536 | +       trust_remote_code: bool = True,
| 537 |         compile: bool = True,
| 538 |         **kwargs,
| 539 | ):

| 552 |     :param local_files_only: use local files instead of from HF
| 553 |     :param resume_download: resume downloads from HF
| 554 |     :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
| 555 | +   :param trust_remote_code: trust code needed by model
| 556 |     :param compile: whether to compile torch model
| 557 |     :param kwargs:
| 558 |     :return:

| 571 |     )
| 572 |
| 573 |     from transformers import AutoConfig
| 574 | +   config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
| 575 | +                                       trust_remote_code=trust_remote_code)
| 576 |     llama_type_from_config = 'llama' in str(config).lower()
| 577 |     llama_type_from_name = "llama" in base_model.lower()
| 578 |     llama_type = llama_type_from_config or llama_type_from_name

| 589 |                                                  local_files_only=local_files_only,
| 590 |                                                  resume_download=resume_download,
| 591 |                                                  use_auth_token=use_auth_token,
| 592 | +                                                trust_remote_code=trust_remote_code,
| 593 |                                                  )
| 594 |     else:
| 595 |         tokenizer = tokenizer_loader

| 605 |         model_kwargs = dict(local_files_only=local_files_only,
| 606 |                             torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
| 607 |                             resume_download=resume_download,
| 608 | +                           use_auth_token=use_auth_token,
| 609 | +                           trust_remote_code=trust_remote_code,
| 610 | +                           )
| 611 | +       if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
| 612 |             model_kwargs.update(dict(load_in_8bit=load_8bit,
| 613 |                                      device_map={"": 0} if load_8bit and device == 'cuda' else "auto",
| 614 |                                      ))
| 615 | +       if 'mpt-' in base_model.lower() and gpu_id >= 0:
| 616 | +           model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
| 617 | +
| 618 |         if 'OpenAssistant/reward-model'.lower() in base_model.lower():
| 619 | +           # FIXME: could put on other GPUs
| 620 |             model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
| 621 |             model_kwargs.pop('torch_dtype', None)
| 622 |

| 624 |         with torch.device(device):
| 625 |             if infer_devices:
| 626 |                 model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
| 627 | +                                          gpu_id=gpu_id,
| 628 | +                                          use_auth_token=use_auth_token,
| 629 | +                                          trust_remote_code=trust_remote_code,
| 630 | +                                          )
| 631 |             else:
| 632 |                 if load_half and not load_8bit:
| 633 |                     model = model_loader.from_pretrained(

| 649 |             local_files_only=local_files_only,
| 650 |             resume_download=resume_download,
| 651 |             use_auth_token=use_auth_token,
| 652 | +           trust_remote_code=trust_remote_code,
| 653 |             device_map={"": 0} if device == 'cuda' else {"": 'cpu'},  # seems to be required
| 654 |         )
| 655 |     else:

| 665 |                 local_files_only=local_files_only,
| 666 |                 resume_download=resume_download,
| 667 |                 use_auth_token=use_auth_token,
| 668 | +               trust_remote_code=trust_remote_code,
| 669 |                 device_map="auto",
| 670 |             )
| 671 |             if load_half:

| 834 |     if chat:
| 835 |         # override, ignore user change
| 836 |         num_return_sequences = 1
| 837 | +   stopping_criteria = get_stopping(prompt_type, tokenizer, device)
| 838 |     # help to avoid errors like:
| 839 |     # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
| 840 |     # RuntimeError: expected scalar type Half but found Float

| 913 |                 prompt = inputs_decoded
| 914 |             elif inputs_decoded_raw == prompt:
| 915 |                 # some models specify special tokens that are part of normal prompt, so can't skip them
| 916 | +               inputs_decoded = prompt = inputs_decoded_raw
| 917 | +               decoder = decoder_raw
| 918 | +           elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ', '') == prompt.replace('\n', ' ').replace(' ', ''):
| 919 | +               inputs_decoded = prompt = inputs_decoded_raw
| 920 |                 decoder = decoder_raw
| 921 |             else:
| 922 |                 print("WARNING: Special characters in prompt", flush=True)

| 1059 |
| 1060 |    if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
| 1061 |        prompt_type = inv_prompt_type_to_model_lower[model_lower]
| 1062 | +      print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
| 1063 |
| 1064 |    # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
| 1065 |    if show_examples is None:

| 1118 |            placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
| 1119 |        placeholder_input = ""
| 1120 |        if model_lower:
| 1121 | +          # default is plain, because might relly upon trust_remote_code to handle prompting
| 1122 | +          prompt_type = prompt_type or 'plain'
| 1123 |        else:
| 1124 |            prompt_type = ''
| 1125 |        examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",

| 1148 |        num_return_sequences = min(num_beams, num_return_sequences or 1)
| 1149 |        do_sample = False if do_sample is None else do_sample
| 1150 |    else:
| 1151 | +      temperature = 0.1 if temperature is None else temperature
| 1152 | +      top_p = 0.75 if top_p is None else top_p
| 1153 | +      top_k = 40 if top_k is None else top_k
| 1154 |        if chat:
| 1155 |            num_beams = num_beams or 1
| 1156 |        else:

| 1158 |        max_new_tokens = max_new_tokens or 256
| 1159 |        repetition_penalty = repetition_penalty or 1.07
| 1160 |        num_return_sequences = min(num_beams, num_return_sequences or 1)
| 1161 | +      do_sample = False if do_sample is None else do_sample
| 1162 |    # doesn't include chat, instruction_nochat, iinput_nochat, added later
| 1163 |    params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
| 1164 |                   early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
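Most of the generate.py changes above thread trust_remote_code through model loading and special-case MPT models (triton attention, longer max_seq_len). A minimal standalone sketch of that loading path, assuming AutoModelForCausalLM as the loader and the mosaicml/mpt-7b-chat checkpoint named in the diff (not the repo's exact load_model code):

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    base_model = 'mosaicml/mpt-7b-chat'  # example checkpoint from the diff

    # trust_remote_code lets the model's custom modeling code load; the diff then
    # overrides MPT-specific config fields before instantiating the model
    config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
    if 'mpt-' in base_model.lower():
        config.attn_config['attn_impl'] = 'triton'  # only if triton kernels are installed
        config.update({"max_seq_len": 4096})        # value the diff uses for mpt-7b-chat

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        config=config,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map={"": 0} if torch.cuda.is_available() else {"": 'cpu'},
    )

The repo goes further: get_non_lora_model builds the model on empty weights with init_empty_weights and pre-computes a device_map via infer_auto_device_map, while this sketch leans on from_pretrained's device_map instead.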
    	
        gradio_runner.py
    CHANGED
    
@@ -5,6 +5,7 @@ import os
| 5 |   import sys
| 6 |
| 7 |   from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
| 8 |   from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
| 9 |       ping
| 10 |  from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower

@@ -49,6 +50,7 @@ def go_gradio(**kwargs):
| 49 |                      """
| 50 |     else:
| 51 |         description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
| 52 |     description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
| 53 |
| 54 |     if kwargs['verbose']:

@@ -389,6 +391,7 @@ def go_gradio(**kwargs):
| 389 |            .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
| 390 |
| 391 |         # Get inputs to evaluate()
| 392 |         all_kwargs = kwargs.copy()
| 393 |         all_kwargs.update(locals())
| 394 |         inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])

@@ -516,9 +519,12 @@ def go_gradio(**kwargs):
| 516 |             :return:
| 517 |             """
| 518 |             args_list = list(args)
| 519 | -           user_message = args_list[
| 520 | -           input1 = args_list[
| 521 | -           context1 = args_list[
| 522 |             if input1 and not user_message.endswith(':'):
| 523 |                 user_message1 = user_message + ":" + input1
| 524 |             elif input1:

@@ -528,6 +534,8 @@ def go_gradio(**kwargs):
| 528 |             if sanitize_user_prompt:
| 529 |                 from better_profanity import profanity
| 530 |                 user_message1 = profanity.censor(user_message1)
| 531 |             if user_message1 in ['']:
| 532 |                 # e.g. when user just hits enter in textbox,
| 533 |                 # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual

@@ -559,7 +567,8 @@ def go_gradio(**kwargs):
| 559 |             :param retry:
| 560 |             :return:
| 561 |             """
| 562 | -
| 563 |             history = args_list[-1]  # model_state is -2
| 564 |             if retry and history:
| 565 |                 history.pop()

@@ -580,12 +589,18 @@ def go_gradio(**kwargs):
| 580 |                 context1 = ''
| 581 |                 for histi in range(len(history) - 1):
| 582 |                     data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
| 583 | -
| 584 | -
| 585 | -
| 586 | -
| 587 | -
| 588 | -
| 589 |             args_list[0] = instruction1  # override original instruction with history from user
| 590 |             # only include desired chat history
| 591 |             args_list[2] = context1[-kwargs['chat_history']:]

@@ -767,6 +782,7 @@ def go_gradio(**kwargs):
| 767 |                 lora_weights = no_lora_str
| 768 |                 return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
| 769 |
| 770 |             all_kwargs1 = all_kwargs.copy()
| 771 |             all_kwargs1['base_model'] = model_name.strip()
| 772 |             all_kwargs1['load_8bit'] = load_8bit

| 5 |   import sys
| 6 |
| 7 |   from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
| 8 | + from prompter import Prompter
| 9 |   from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
| 10 |      ping
| 11 |  from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower

| 50 |                      """
| 51 |     else:
| 52 |         description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
| 53 | +   description += "If this host is busy, try [gpt.h2o.ai 20B](https://gpt.h2o.ai) and [30B](http://gpu.hopto.org) and [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) and [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
| 54 |     description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
| 55 |
| 56 |     if kwargs['verbose']:

| 391 |            .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
| 392 |
| 393 |         # Get inputs to evaluate()
| 394 | +       # don't deepcopy, can contain model itself
| 395 |         all_kwargs = kwargs.copy()
| 396 |         all_kwargs.update(locals())
| 397 |         inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])

| 519 |             :return:
| 520 |             """
| 521 |             args_list = list(args)
| 522 | +           user_message = args_list[eval_func_param_names.index('instruction')]  # chat only
| 523 | +           input1 = args_list[eval_func_param_names.index('iinput')]  # chat only
| 524 | +           context1 = args_list[eval_func_param_names.index('context')]
| 525 | +           prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
| 526 | +           chat1 = args_list[eval_func_param_names.index('chat')]
| 527 | +           stream_output1 = args_list[eval_func_param_names.index('stream_output')]
| 528 |             if input1 and not user_message.endswith(':'):
| 529 |                 user_message1 = user_message + ":" + input1
| 530 |             elif input1:

| 534 |             if sanitize_user_prompt:
| 535 |                 from better_profanity import profanity
| 536 |                 user_message1 = profanity.censor(user_message1)
| 537 | +           # FIXME: WIP to use desired seperator when user enters nothing
| 538 | +           prompter = Prompter(prompt_type1, debug=kwargs['debug'], chat=chat1, stream_output=stream_output1)
| 539 |             if user_message1 in ['']:
| 540 |                 # e.g. when user just hits enter in textbox,
| 541 |                 # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual

| 567 |             :param retry:
| 568 |             :return:
| 569 |             """
| 570 | +           # don't deepcopy, can contain model itself
| 571 | +           args_list = list(args).copy()
| 572 |             history = args_list[-1]  # model_state is -2
| 573 |             if retry and history:
| 574 |                 history.pop()

| 589 |                 context1 = ''
| 590 |                 for histi in range(len(history) - 1):
| 591 |                     data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
| 592 | +                   prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
| 593 | +                                                                                        chat1, reduced=True)
| 594 | +                   # md -> back to text, maybe not super improtant if model trained enough
| 595 | +                   prompt = prompt.replace('<br>', chat_sep)
| 596 | +                   context1 += prompt
| 597 | +                   if not context1.endswith(chat_sep):
| 598 | +                       context1 += chat_sep
| 599 | +
| 600 | +               _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
| 601 | +                                                                               reduced=True)
| 602 | +               if context1 and not context1.endswith(chat_sep):
| 603 | +                   context1 += chat_sep  # ensure if terminates abruptly, then human continues on next line
| 604 |             args_list[0] = instruction1  # override original instruction with history from user
| 605 |             # only include desired chat history
| 606 |             args_list[2] = context1[-kwargs['chat_history']:]

| 782 |                 lora_weights = no_lora_str
| 783 |                 return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
| 784 |
| 785 | +           # don't deepcopy, can contain model itself
| 786 |             all_kwargs1 = all_kwargs.copy()
| 787 |             all_kwargs1['base_model'] = model_name.strip()
| 788 |             all_kwargs1['load_8bit'] = load_8bit
     | 
    	
        prompter.py
    CHANGED
    
    | 
         @@ -6,7 +6,8 @@ class Prompter(object): 
     | 
|
| 6 | 
         
             
                             allowed_repeat_line_length=10):
         
     | 
| 7 | 
         
             
                    self.prompt_type = prompt_type
         
     | 
| 8 | 
         
             
                    data_point = dict(instruction='', input='', output='')
         
     | 
| 9 | 
         
            -
                    _, self.pre_response, self.terminate_response 
     | 
| 
         | 
|
| 10 | 
         
             
                    self.debug = debug
         
     | 
| 11 | 
         
             
                    self.chat = chat
         
     | 
| 12 | 
         
             
                    self.stream_output = stream_output
         
     | 
| 
         @@ -15,7 +16,7 @@ class Prompter(object): 
     | 
|
| 15 | 
         | 
| 16 | 
         
             
                def generate_prompt(self, data_point):
         
     | 
| 17 | 
         
             
                    reduced = False
         
     | 
| 18 | 
         
            -
                    prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
         
     | 
| 19 | 
         
             
                    if self.debug:
         
     | 
| 20 | 
         
             
                        print("prompt: ", prompt, flush=True)
         
     | 
| 21 | 
         
             
                    self.prompt = prompt
         
     | 
| 
         @@ -25,12 +26,12 @@ class Prompter(object): 
     | 
|
| 25 | 
         
             
                    if isinstance(outputs, str):
         
     | 
| 26 | 
         
             
                        outputs = [outputs]
         
     | 
| 27 | 
         
             
                    if self.debug:
         
     | 
| 28 | 
         
            -
                        print("output 
     | 
| 29 | 
         
             
                    if prompt is not None:
         
     | 
| 30 | 
         
             
                        self.prompt = prompt
         
     | 
| 31 | 
         | 
| 32 | 
         
             
                    def clean_response(response):
         
     | 
| 33 | 
         
            -
                        meaningless_words = ['<pad>', '</s>', '<|endoftext|>' 
     | 
| 34 | 
         
             
                        for word in meaningless_words:
         
     | 
| 35 | 
         
             
                            response = response.replace(word, "")
         
     | 
| 36 | 
         
             
                        if sanitize_bot_response:
         
     | 
| 
         @@ -103,5 +104,5 @@ class Prompter(object): 
     | 
|
| 103 | 
         
             
                    # join all outputs, only one extra new line between outputs
         
     | 
| 104 | 
         
             
                    output = '\n'.join(outputs)
         
     | 
| 105 | 
         
             
                    if self.debug:
         
     | 
| 106 | 
         
            -
                        print("outputclean 
     | 
| 107 | 
         
             
                    return output
         
     | 
| 
         | 
|
| 6 | 
         
             
                             allowed_repeat_line_length=10):
         
     | 
| 7 | 
         
             
                    self.prompt_type = prompt_type
         
     | 
| 8 | 
         
             
                    data_point = dict(instruction='', input='', output='')
         
     | 
| 9 | 
         
            +
                    _, self.pre_response, self.terminate_response, self.chat_sep = \
         
     | 
| 10 | 
         
            +
                        generate_prompt(data_point, prompt_type, chat, False)
         
     | 
| 11 | 
         
             
                    self.debug = debug
         
     | 
| 12 | 
         
             
                    self.chat = chat
         
     | 
| 13 | 
         
             
                    self.stream_output = stream_output
         
     | 
| 
         | 
|
| 16 | 
         | 
| 17 | 
         
             
                def generate_prompt(self, data_point):
         
     | 
| 18 | 
         
             
                    reduced = False
         
     | 
| 19 | 
         
            +
                    prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
         
     | 
| 20 | 
         
             
                    if self.debug:
         
     | 
| 21 | 
         
             
                        print("prompt: ", prompt, flush=True)
         
     | 
| 22 | 
         
             
                    self.prompt = prompt
         
     | 
| 
         | 
|
| 26 | 
         
             
                    if isinstance(outputs, str):
         
     | 
| 27 | 
         
             
                        outputs = [outputs]
         
     | 
| 28 | 
         
             
                    if self.debug:
         
     | 
| 29 | 
         
            +
                        print("output:\n", '\n\n'.join(outputs), flush=True)
         
     | 
| 30 | 
         
             
                    if prompt is not None:
         
     | 
| 31 | 
         
             
                        self.prompt = prompt
         
     | 
| 32 | 
         | 
| 33 | 
         
             
                    def clean_response(response):
         
     | 
| 34 | 
         
            +
                        meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
         
     | 
| 35 | 
         
             
                        for word in meaningless_words:
         
     | 
| 36 | 
         
             
                            response = response.replace(word, "")
         
     | 
| 37 | 
         
             
                        if sanitize_bot_response:
         
     | 
| 
         | 
|
| 104 | 
         
             
                    # join all outputs, only one extra new line between outputs
         
     | 
| 105 | 
         
             
                    output = '\n'.join(outputs)
         
     | 
| 106 | 
         
             
                    if self.debug:
         
     | 
| 107 | 
         
            +
                        print("outputclean:\n", '\n\n'.join(outputs), flush=True)
         
     | 
| 108 | 
         
             
                    return output
         
     | 
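The clean_response helper in the prompter.py diff strips decoder special tokens (and, when sanitize_bot_response is set, profanity) before text reaches the UI. A simplified standalone version of that idea; the repo's exact cleanup differs in details such as repeated-line handling (see allowed_repeat_line_length):

    def clean_response(response, sanitize_bot_response=False):
        # strip decoder special tokens that carry no meaning for the user
        meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
        for word in meaningless_words:
            response = response.replace(word, "")
        if sanitize_bot_response:
            # same optional dependency the Gradio app imports for its profanity filter
            from better_profanity import profanity
            response = profanity.censor(response)
        return response.strip()

    print(clean_response("I am h2oGPT.</s><|endoftext|>"))  # -> I am h2oGPT.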
    	
        requirements.txt
    CHANGED
    
@@ -19,7 +19,7 @@ pandas==2.0.0
| 19 |   matplotlib==3.7.1
| 20 |   loralib==0.1.1
| 21 |   bitsandbytes==0.38.1
| 22 | - git+https://github.com/huggingface/peft.git@
| 22 | + git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
| 23 |   transformers==4.28.1
| 24 |   tokenizers==0.13.3
| 25 |   APScheduler==3.10.1
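The stopping.py diff below consolidates the stop-word handling that was removed from generate.py into a get_stopping() factory built on StoppingCriteriaSub. A simplified, self-contained sketch of the same mechanism (the repo's version additionally counts per-phrase encounters, which this sketch omits):

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnTokens(StoppingCriteria):
        """Stop generation once the tail of the sequence matches any stop phrase."""
        def __init__(self, stops):
            super().__init__()
            self.stops = stops

        def __call__(self, input_ids, scores, **kwargs):
            return any(torch.equal(stop.to(input_ids.device), input_ids[0][-len(stop):])
                       for stop in self.stops)

    def simple_get_stopping(tokenizer, stop_words=('<human>:', '<bot>:')):
        # tokenize each stop phrase; keep the single-token case as a 1-D tensor
        stop_ids = [tokenizer(w, return_tensors='pt')['input_ids'].squeeze() for w in stop_words]
        stop_ids = [x if x.dim() > 0 else x.unsqueeze(0) for x in stop_ids]
        return StoppingCriteriaList([StopOnTokens(stop_ids)])

    # usage: model.generate(**inputs, stopping_criteria=simple_get_stopping(tokenizer))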
    	
        stopping.py
    CHANGED
    
@@ -1,10 +1,5 @@
| 1 | - import traceback
| 2 | - from queue import Queue
| 3 | - from threading import Thread
| 4 | - import collections.abc
| 5 | -
| 6 |   import torch
| 7 | - from transformers import StoppingCriteria
| 8 |
| 9 |
| 10 |  class StoppingCriteriaSub(StoppingCriteria):

@@ -21,7 +16,55 @@ class StoppingCriteriaSub(StoppingCriteria):
| 21 |            if torch.all((stop == input_ids[0][-len(stop):])).item():
| 22 |                self.num_stops[stopi] += 1
| 23 |                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
| 24 |                    return True
| 25 |        # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
| 26 |        # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
| 27 |        return False

| 1 |   import torch
| 2 | + from transformers import StoppingCriteria, StoppingCriteriaList
| 3 |
| 4 |
| 5 |   class StoppingCriteriaSub(StoppingCriteria):

| 16 |           if torch.all((stop == input_ids[0][-len(stop):])).item():
| 17 |               self.num_stops[stopi] += 1
| 18 |               if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
| 19 | +                 # print("Stopped", flush=True)
| 20 |                   return True
| 21 |       # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
| 22 |       # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
| 23 |       return False
| 24 | +
| 25 | +
| 26 | + def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
| 27 | +     if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
| 28 | +         if prompt_type == 'human_bot':
| 29 | +             # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
| 30 | +             # stopping only starts once output is beyond prompt
| 31 | +             # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
| 32 | +             stop_words = [human, bot, '\n' + human, '\n' + bot]
| 33 | +             encounters = [1, 2]
| 34 | +         elif prompt_type == 'instruct_vicuna':
         
     | 
| 35 | 
         
            +
                        # even below is not enough, generic strings and many ways to encode
         
     | 
| 36 | 
         
            +
                        stop_words = [
         
     | 
| 37 | 
         
            +
                            '### Human:',
         
     | 
| 38 | 
         
            +
                            """
         
     | 
| 39 | 
         
            +
            ### Human:""",
         
     | 
| 40 | 
         
            +
                            """
         
     | 
| 41 | 
         
            +
            ### Human:
         
     | 
| 42 | 
         
            +
            """,
         
     | 
| 43 | 
         
            +
                            '### Assistant:',
         
     | 
| 44 | 
         
            +
                            """
         
     | 
| 45 | 
         
            +
            ### Assistant:""",
         
     | 
| 46 | 
         
            +
                            """
         
     | 
| 47 | 
         
            +
            ### Assistant:
         
     | 
| 48 | 
         
            +
            """,
         
     | 
| 49 | 
         
            +
                        ]
         
     | 
| 50 | 
         
            +
                        encounters = [1, 2]
         
     | 
| 51 | 
         
            +
                    else:
         
     | 
| 52 | 
         
            +
                        # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
         
     | 
| 53 | 
         
            +
                        stop_words = ['### End']
         
     | 
| 54 | 
         
            +
                        encounters = [1]
         
     | 
| 55 | 
         
            +
                    stop_words_ids = [
         
     | 
| 56 | 
         
            +
                        tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
         
     | 
| 57 | 
         
            +
                    # handle single token case
         
     | 
| 58 | 
         
            +
                    stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
         
     | 
| 59 | 
         
            +
                    stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
         
     | 
| 60 | 
         
            +
                    # avoid padding in front of tokens
         
     | 
| 61 | 
         
            +
                    if tokenizer.pad_token:
         
     | 
| 62 | 
         
            +
                        stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
         
     | 
| 63 | 
         
            +
                    # handle fake \n added
         
     | 
| 64 | 
         
            +
                    stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
         
     | 
| 65 | 
         
            +
                    # build stopper
         
     | 
| 66 | 
         
            +
                    stopping_criteria = StoppingCriteriaList(
         
     | 
| 67 | 
         
            +
                        [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
         
     | 
| 68 | 
         
            +
                else:
         
     | 
| 69 | 
         
            +
                    stopping_criteria = StoppingCriteriaList()
         
     | 
| 70 | 
         
            +
                return stopping_criteria
         
     |
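A minimal usage sketch for the new get_stopping() helper, not the repo's actual generate.py wiring: it assumes stopping.py is importable from the working directory, and the model name is a placeholder assumption purely for illustration (substitute whatever causal LM the Space actually serves).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from stopping import get_stopping  # assumes running from the repo root

model_name = "EleutherAI/pythia-70m"  # placeholder assumption, not the model used by this Space
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# human_bot prompting: generation should stop once the model starts a new "<human>:" turn,
# or emits a second "<bot>:" (the first "<bot>:" is the one supplied in the prompt).
prompt = "<human>: Who are you?\n<bot>:"
stopping_criteria = get_stopping('human_bot', tokenizer, device)

inputs = tokenizer(prompt, return_tensors="pt").to(device)
output_ids = model.generate(**inputs,
                            max_new_tokens=128,
                            stopping_criteria=stopping_criteria)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

For other prompt types the same call applies: 'instruct_vicuna' stops on the "### Human:" / "### Assistant:" turn markers, 'instruct_with_end' stops on '### End', and any other prompt_type returns an empty StoppingCriteriaList so generation is unaffected.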