import os
import pathlib

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def get_infer_kwargs(args) -> dict:
    """Build vLLM generation kwargs from command-line args, filling in defaults."""
    # Use `is not None` so an explicit temperature of 0 (greedy decoding) is preserved.
    temperature = args.temperature if args.temperature is not None else 1.0
    max_new_tokens = args.max_new_tokens if args.max_new_tokens is not None else 1024
    model_type = args.model_type if args.model_type else "chat_model"

    kwargs = {
        "temperature": temperature,
        "max_tokens": max_new_tokens,
        "model_type": model_type,
    }
    return kwargs
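
# Illustrative only (assumption, not a confirmed CLI for this project): get_infer_kwargs
# expects an argparse-style namespace with `temperature`, `max_new_tokens`, and
# `model_type` attributes, e.g.:
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--temperature", type=float, default=None)
#   parser.add_argument("--max_new_tokens", type=int, default=None)
#   parser.add_argument("--model_type", type=str, default=None)
#   args = parser.parse_args()
#   kwargs = get_infer_kwargs(args)  # -> {"temperature": ..., "max_tokens": ..., "model_type": ...}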


def load_tokenizer_and_template(model_name_or_path, template=None):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True
    )

    # Fall back to a bundled Jinja chat template when the tokenizer config does not ship one.
    if tokenizer.chat_template is None:
        if template is not None:
            chatml_jinja_path = (
                pathlib.Path(os.path.dirname(os.path.abspath(__file__)))
                / f"templates/template_{template}.jinja"
            )
            assert chatml_jinja_path.exists(), f"Template file not found: {chatml_jinja_path}"
            with open(chatml_jinja_path, "r") as f:
                tokenizer.chat_template = f.read()
        else:
            # No template available; downstream code must handle the missing chat_template.
            pass
            # raise ValueError("chat_template is not found in the config file, please provide the template parameter.")
    return tokenizer


def load_model(model_name_or_path, max_model_len=None, gpus_num=1):
    llm_args = {
        "model": model_name_or_path,
        "gpu_memory_utilization": 0.95,
        "trust_remote_code": True,
        "tensor_parallel_size": gpus_num,
        "dtype": "half",
    }

    if max_model_len:
        llm_args["max_model_len"] = max_model_len

    # Create an LLM.
    llm = LLM(**llm_args)
    return llm


def generate_outputs(messages_batch, llm_model, tokenizer, generate_args):
    """
    Generate completions for a batch of conversations with vLLM.

    Example inputs:

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        messages_batch = [messages]

        generate_args = {
            "temperature": 0.7,          # sampling temperature
            "max_tokens": 1024,          # maximum number of generated tokens
            "model_type": "chat_model",  # or "base_model"
        }
    """
    model_type = generate_args.pop("model_type", "chat_model")
    # Optionally add a default penalty to curb repetitive output from models with weak
    # instruction-following ability; consider exposing this as a configurable parameter.
    # generate_args["presence_penalty"] = 2.0
    sampling_params = SamplingParams(**generate_args)

    prompt_batch = []
    for messages in messages_batch:
        # For a base model, join the message contents directly to form the prompt.
        if model_type == "base_model":
            messages_content = [msg["content"] for msg in messages]
            prompt = "\n".join(messages_content)
        # For a chat model, render the messages with the chat template first.
        else:
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        prompt_batch.append(prompt)

    outputs = llm_model.generate(prompt_batch, sampling_params)

    outputs_batch = []
    for output in outputs:
        prompt_output = output.prompt
        generated_text = output.outputs[0].text
        outputs_batch.append(
            {"input_prompt": prompt_output, "output_text": generated_text}
        )

    return outputs_batch
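

if __name__ == "__main__":
    # Minimal end-to-end sketch, for illustration only. The model identifier, template
    # name, and generation settings below are placeholders, not values used by this project.
    model_path = "Qwen/Qwen1.5-7B-Chat"  # hypothetical model path or HF identifier

    tokenizer = load_tokenizer_and_template(model_path, template="chatml")  # assumes templates/template_chatml.jinja exists
    llm = load_model(model_path, max_model_len=4096, gpus_num=1)

    messages_batch = [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Explain tensor parallelism in one sentence."},
        ]
    ]
    generate_args = {"temperature": 0.7, "max_tokens": 256, "model_type": "chat_model"}

    for result in generate_outputs(messages_batch, llm, tokenizer, generate_args):
        print(result["output_text"])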