import os
import sys

import torch
from modelscope_agent.agent_types import AgentType
from swift import Swift
from transformers import AutoModelForCausalLM, AutoTokenizer

from modelscope import GenerationConfig, snapshot_download
from .base import LLM


class ModelScopeLLM(LLM):
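    """LLM backed by a local or hub-hosted ModelScope checkpoint.

    Loads the model and tokenizer via transformers, optionally attaches a
    LoRA adapter through Swift, and truncates generations at ``end_token``.
    """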

    def __init__(self, cfg):
        super().__init__(cfg)

        model_id = self.cfg.get('model_id', '')
        self.model_id = model_id
        model_revision = self.cfg.get('model_revision', None)
        cache_dir = self.cfg.get('cache_dir', None)

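        # Treat model_id as a local path when it exists; otherwise fetch the
        # checkpoint from the ModelScope hub.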
        if not os.path.exists(model_id):
            model_dir = snapshot_download(
                model_id, revision=model_revision, cache_dir=cache_dir)
        else:
            model_dir = model_id
        self.model_dir = model_dir
        sys.path.append(self.model_dir)

        self.model_cls = self.cfg.get('model_cls', AutoModelForCausalLM)
        self.tokenizer_cls = self.cfg.get('tokenizer_cls', AutoTokenizer)

        self.device_map = self.cfg.get('device_map', 'auto')
        self.generation_cfg = GenerationConfig(
            **self.cfg.get('generate_cfg', {}))

        self.use_lora = self.cfg.get('use_lora', False)
        self.lora_ckpt_dir = self.cfg.get('lora_ckpt_dir',
                                          None) if self.use_lora else None

        self.custom_chat = self.cfg.get('custom_chat', False)

        self.end_token = self.cfg.get('end_token', '<|endofthink|>')
        self.include_end = self.cfg.get('include_end', True)

        self.setup()
        self.agent_type = self.cfg.get('agent_type', AgentType.DEFAULT)

    def setup(self):
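        """Instantiate the model and tokenizer from the resolved model_dir."""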
        model_cls = self.model_cls
        tokenizer_cls = self.tokenizer_cls

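        # fp16 weights with trust_remote_code enabled, since ModelScope
        # checkpoints often ship custom modeling code.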
        self.model = model_cls.from_pretrained(
            self.model_dir,
            device_map=self.device_map,
            torch_dtype=torch.float16,
            trust_remote_code=True)
        self.tokenizer = tokenizer_cls.from_pretrained(
            self.model_dir, trust_remote_code=True)
        self.model = self.model.eval()

        if self.use_lora:
            self.load_from_lora()

        if self.cfg.get('use_raw_generation_config', False):
            self.model.generation_config = GenerationConfig.from_pretrained(
                self.model_dir, trust_remote_code=True)

    def generate(self, prompt, functions=None, **kwargs):
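        """Generate a response for ``prompt``, truncated at ``end_token``."""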

        if self.custom_chat and getattr(self.model, 'chat', None):
            response = self.model.chat(
                self.tokenizer, prompt, history=[], system='')[0]
        else:
            response = self.chat(prompt)

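        # Truncate at the configured end token; include_end controls whether
        # the token itself is kept in the returned text.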
        end_idx = response.find(self.end_token)
        if end_idx != -1:
            end_idx += len(self.end_token) if self.include_end else 0
            response = response[:end_idx]

        return response

    def load_from_lora(self):
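        """Attach LoRA adapter weights from lora_ckpt_dir to the base model."""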

        # Cast to bf16, then attach the LoRA adapter weights via Swift
        model = self.model.bfloat16()
        model = Swift.from_pretrained(model, self.lora_ckpt_dir)

        self.model = model

    def chat(self, prompt):
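        """Tokenize the prompt, generate, and decode only the new tokens."""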
        device = self.model.device
        input_ids = self.tokenizer(
            prompt, return_tensors='pt').input_ids.to(device)
        input_len = input_ids.shape[1]

        result = self.model.generate(
            input_ids=input_ids, generation_config=self.generation_cfg)

        result = result[0].tolist()[input_len:]
        response = self.tokenizer.decode(result)

        return response


class ModelScopeChatGLM(ModelScopeLLM):
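    """ChatGLM variant that also stops on the chat role-switch tokens."""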

    def chat(self, prompt):
        device = self.model.device
        input_ids = self.tokenizer(
            prompt, return_tensors='pt').input_ids.to(device)
        input_len = input_ids.shape[1]

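        # ChatGLM keeps emitting role turns unless the role-switch commands
        # are treated as extra end-of-sequence tokens.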
        eos_token_id = [
            self.tokenizer.eos_token_id,
            self.tokenizer.get_command('<|user|>'),
            self.tokenizer.get_command('<|observation|>')
        ]
        result = self.model.generate(
            input_ids=input_ids,
            generation_config=self.generation_cfg,
            eos_token_id=eos_token_id)

        result = result[0].tolist()[input_len:]
        response = self.tokenizer.decode(result)
        # Handle the case where '<|user|>' is emitted piecewise as
        # '<', '|', 'user', '|', '>': decode first, then split on the text.
        response = response.split('<|user|>')[0].split('<|observation|>')[0]

        return response
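

# Minimal usage sketch (illustration only): the cfg keys below mirror the
# ones read in ModelScopeLLM.__init__, but the model id and generation
# settings are placeholder assumptions, and this assumes the base LLM class
# stores the cfg dict as provided.
if __name__ == '__main__':
    cfg = {
        'model_id': 'qwen/Qwen-7B-Chat',  # hypothetical ModelScope hub id
        'device_map': 'auto',
        'generate_cfg': {
            'max_new_tokens': 512,
            'do_sample': True,
            'top_p': 0.8,
        },
        'end_token': '<|endofthink|>',
        'include_end': True,
    }
    llm = ModelScopeLLM(cfg)
    print(llm.generate('Briefly introduce ModelScope.'))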