import os
import json
# from peft import PeftConfig, PeftModel
# from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer, AutoConfig
from typing import Optional, Literal

from dsp.modules.lm import LM
# from dsp.modules.finetuning.finetune_hf import preprocess_prompt
from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
import functools

def openai_to_hf(**kwargs):
    """Map OpenAI-style generation kwargs to their Hugging Face `generate` equivalents."""
    hf_kwargs = {}
    for k, v in kwargs.items():
        if k == "n":
            hf_kwargs["num_return_sequences"] = v
        elif k == "frequency_penalty":
            hf_kwargs["repetition_penalty"] = 1.0 - v
        elif k == "presence_penalty":
            hf_kwargs["diversity_penalty"] = v
        elif k == "max_tokens":
            hf_kwargs["max_new_tokens"] = v
        elif k == "model":
            pass
        else:
            hf_kwargs[k] = v

    return hf_kwargs

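# Illustrative note (example values, not defaults): under the mapping above,
# openai_to_hf(n=3, max_tokens=50, temperature=0.7) yields
# {"num_return_sequences": 3, "max_new_tokens": 50, "temperature": 0.7};
# keys without a special case (such as "temperature") pass through unchanged.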

class HFModel(LM):
    def __init__(self, model: str, checkpoint: Optional[str] = None, is_client: bool = False,
                 hf_device_map: Literal["auto", "balanced", "balanced_low_0", "sequential"] = "auto"):
        """wrapper for Hugging Face models

        Args:
            model (str): HF model identifier to load and use
            checkpoint (str, optional): load specific checkpoints of the model. Defaults to None.
            is_client (bool, optional): whether to access models via client. Defaults to False.
            hf_device_map (str, optional): HF config strategy to load the model. 
                Recommeded to use "auto", which will help loading large models using accelerate. Defaults to "auto".
        """

        super().__init__(model)
        self.provider = "hf"
        self.is_client = is_client
        self.device_map = hf_device_map
        if not self.is_client:
            try:
                from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer, AutoConfig
                import torch
            except ImportError as exc:
                raise ModuleNotFoundError(
                    "You need to install Hugging Face transformers library to use HF models."
                ) from exc
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            try:
                # Inspect the model config to pick the right auto class:
                # seq2seq (encoder-decoder) vs. decoder-only (causal) generation.
                architecture = AutoConfig.from_pretrained(model).__dict__["architectures"][0]
                self.encoder_decoder_model = ("ConditionalGeneration" in architecture) or ("T5WithLMHeadModel" in architecture)
                self.decoder_only_model = ("CausalLM" in architecture) or ("GPT2LMHeadModel" in architecture)
                assert self.encoder_decoder_model or self.decoder_only_model, f"Unknown HuggingFace model class: {model}"
                self.tokenizer = AutoTokenizer.from_pretrained(model if checkpoint is None else checkpoint)

                self.rationale = True
                AutoModelClass = AutoModelForSeq2SeqLM if self.encoder_decoder_model else AutoModelForCausalLM
                if checkpoint:
                    # with open(os.path.join(checkpoint, '..', 'compiler_config.json'), 'r') as f:
                    #     config = json.load(f)
                    self.rationale = False #config['rationale']
                    # if config['peft']:
                    #     peft_config = PeftConfig.from_pretrained(checkpoint)
                    #     self.model = AutoModelClass.from_pretrained(peft_config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map=hf_device_map)
                    #     self.model = PeftModel.from_pretrained(self.model, checkpoint)
                    # else:
                    self.model = AutoModelClass.from_pretrained(checkpoint).to(self.device)
                else:
                    self.model = AutoModelClass.from_pretrained(model).to(self.device)
                self.drop_prompt_from_output = False
            except ValueError:
                # Fallback for architectures the config check above does not recognize:
                # load as a causal LM and strip the echoed prompt from the generated output.
                self.model = AutoModelForCausalLM.from_pretrained(
                    model if checkpoint is None else checkpoint,
                    device_map=hf_device_map
                )
                self.tokenizer = AutoTokenizer.from_pretrained(model)
                self.drop_prompt_from_output = True
        self.history = []

    def basic_request(self, prompt, **kwargs):
        raw_kwargs = kwargs
        kwargs = {**self.kwargs, **kwargs}
        response = self._generate(prompt, **kwargs)

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    def _generate(self, prompt, **kwargs):
        assert not self.is_client
        # TODO: Add caching
        kwargs = {**openai_to_hf(**self.kwargs), **openai_to_hf(**kwargs)}
        # print(prompt)
        if isinstance(prompt, dict):
            try:
                prompt = prompt['messages'][0]['content']
            except (KeyError, IndexError, TypeError):
                print("Failed to extract 'content' from the prompt.")
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # print(kwargs)
        outputs = self.model.generate(**inputs, **kwargs)
        if self.drop_prompt_from_output:
            input_length = inputs.input_ids.shape[1]
            outputs = outputs[:, input_length:]
        completions = [
            {"text": c}
            for c in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ]
        response = {
            "prompt": prompt,
            "choices": completions,
        }
        return response

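    # Note on shapes: _generate above returns an OpenAI-style dict, e.g. (illustrative)
    #   {"prompt": "Question: ...", "choices": [{"text": "Answer: ..."}]},
    # and __call__ below unwraps "choices" into a plain list of completion strings.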
    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        if kwargs.get("n", 1) > 1 or kwargs.get("temperature", 0.0) > 0.1:
            kwargs["do_sample"] = True

        response = self.request(prompt, **kwargs)
        return [c["text"] for c in response["choices"]]


# @functools.lru_cache(maxsize=None if cache_turn_on else 0)
# @NotebookCacheMemory.cache
# def cached_generate(self, prompt, **kwargs):
#      return self._generate(prompt, **kwargs)
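

# Minimal usage sketch (illustrative, not part of the module itself). Assumptions:
# `transformers` and `torch` are installed, and the arbitrary model id below
# ("google/flan-t5-small") can be fetched from the Hugging Face Hub.
if __name__ == "__main__":
    lm = HFModel("google/flan-t5-small")
    completions = lm("Question: What is the capital of France?\nAnswer:", max_tokens=10)
    print(completions)  # a list with one generated string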