Create README.md
README.md ADDED

---
license: openrail
datasets:
- shareAI/ShareGPT-Chinese-English-90k
- shareAI/CodeChat
language:
- zh
- en
library_name: transformers
tags:
- code
- chat
- codellama
- copilot
- codeAI
pipeline_tag: question-answering
---

## Chinese Chat Version of the CodeLlama Model (Supports Multi-Turn Dialogue)

Note: CodeLlama is designed specifically for code assistance; unlike ChineseLlama, it is suited to answering code-related questions.
Inference code for multi-turn dialogue is given below.
(It can be copied and run as-is; by default it will automatically download the model weights.)

Companion GitHub repository: https://github.com/CrazyBoyM/CodeLLaMA-chat

```python
# adapted from Firefly
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def main():
    model_name = 'shareAI/CodeLLaMA-chat-13b-Chinese'

    device = 'cuda'
    max_new_tokens = 500    # maximum number of tokens to generate per turn
    history_max_len = 1000  # maximum number of history tokens the model remembers
    top_p = 0.9
    temperature = 0.35
    repetition_penalty = 1.0

    # Note: with device_map='auto' the weights are already placed on the GPU;
    # the original script also called .to(device) here, which can fail on a
    # model dispatched by accelerate, so it is omitted.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_fast=False
    )

    # dialogue history as a growing (1, n) tensor of token ids
    history_token_ids = torch.tensor([[]], dtype=torch.long)

    user_input = input('User:')
    while True:
        # append the user turn, terminated with an EOS token
        input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
        user_input_ids = torch.concat([input_ids, eos_token_id], dim=1)
        history_token_ids = torch.concat((history_token_ids, user_input_ids), dim=1)
        # keep only the most recent history_max_len tokens
        model_input_ids = history_token_ids[:, -history_max_len:].to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
                temperature=temperature, repetition_penalty=repetition_penalty,
                eos_token_id=tokenizer.eos_token_id
            )
        # slice off the prompt, keep only the newly generated reply, and add it to the history
        model_input_ids_len = model_input_ids.size(1)
        response_ids = outputs[:, model_input_ids_len:]
        history_token_ids = torch.concat((history_token_ids, response_ids.cpu()), dim=1)
        response = tokenizer.batch_decode(response_ids)
        print("Bot:" + response[0].strip().replace(tokenizer.eos_token, ""))
        user_input = input('User:')


if __name__ == '__main__':
    main()
```
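
The script keeps the dialogue history at the token level: each user message and each model reply is appended to `history_token_ids` and terminated with the tokenizer's EOS token, so after one full round the model sees roughly `user</s>bot</s>`. Below is a minimal sketch for inspecting that layout; the example turns are illustrative and not part of the original script:

```python
# Minimal sketch: rebuild the </s>-separated history format used above and decode it.
# The example strings are illustrative; the tokenizer calls mirror the script above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    'shareAI/CodeLLaMA-chat-13b-Chinese', trust_remote_code=True, use_fast=False
)

turns = ['Write quicksort in Python', 'def quick_sort(arr): ...']  # user turn, bot turn
history = []
for text in turns:
    ids = tokenizer(text, add_special_tokens=False).input_ids
    history += ids + [tokenizer.eos_token_id]  # every turn ends with </s>

print(tokenizer.decode(history))  # the flat token stream the model is conditioned on
```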
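
In float16 the 13B weights need roughly 26 GB of GPU memory. If that does not fit, 4-bit quantized loading is one option; the following is only a sketch, not part of the original script, and assumes the `bitsandbytes` and `accelerate` packages are installed:

```python
# Sketch: load the model 4-bit quantized to reduce VRAM usage.
# Assumes a transformers version with bitsandbytes support; not part of the original script.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    'shareAI/CodeLLaMA-chat-13b-Chinese',
    trust_remote_code=True,
    quantization_config=quant_config,
    device_map='auto',
).eval()
```

The chat loop above works unchanged with a model loaded this way; only the loading call differs.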