File size: 15,243 Bytes
cd0431b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 |
import copy
from typing import List, Optional, Tuple, Union
import numpy
import transformers
from tqdm import tqdm
import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
# Harness-wide logger re-exported from ``lm_eval.utils``; used in this module to
# warn about ignored model kwargs and invalid batch sizes.
eval_logger = utils.eval_logger


@register_model("sparseml")
class SparseMLLM(HFLM):
    """
    SparseML is an open-source model optimization toolkit that enables you to create
    inference-optimized sparse models using pruning, quantization, and distillation
    algorithms. Models optimized with SparseML can then be exported to the ONNX format and
    deployed with DeepSparse for GPU-class performance on CPU hardware.

    This class is a wrapper around the HuggingFace LM class to enable SparseML
    integration with the lm-evaluation-harness. It overrides only the model, config
    and tokenizer construction hooks of ``HFLM`` so that SparseML's
    ``SparseAutoModelForCausalLM`` / ``SparseAutoConfig`` / ``SparseAutoTokenizer``
    loaders are used.
    """

    # Message raised whenever one of the optional `sparseml` imports fails.
    _SPARSEML_MISSING_MSG = (
        "Package `sparseml` is not installed. "
        "Please install it via `pip install sparseml[transformers]`"
    )

    # The only extra kwargs forwarded to `SparseAutoModelForCausalLM.from_pretrained`;
    # every other kwarg given to `_create_model` is dropped and reported in a warning.
    _RELEVANT_MODEL_KWARGS = ("offload_folder", "device_map")

    def _create_model(
        self,
        pretrained: str,
        revision: Optional[str] = "main",
        dtype: Optional[str] = "auto",
        trust_remote_code: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """Load ``pretrained`` with SparseML and store it on ``self._model``.

        Args:
            pretrained: model name or path for
                ``SparseAutoModelForCausalLM.from_pretrained``.
            revision: model revision forwarded to ``from_pretrained``.
            dtype: dtype spec, converted with ``lm_eval.models.utils.get_dtype``.
            trust_remote_code: forwarded to ``from_pretrained``.
            **kwargs: extra model kwargs. Only ``offload_folder`` and ``device_map``
                are forwarded; the rest are ignored (and logged as a warning).

        Raises:
            Exception: if the ``sparseml`` package cannot be imported.
        """
        try:
            from sparseml.transformers import SparseAutoModelForCausalLM
        except ModuleNotFoundError as err:
            raise Exception(self._SPARSEML_MISSING_MSG) from err

        model_kwargs = dict(kwargs)
        if "device_map" not in model_kwargs:
            # set a device_map to initialize model on the right GPU.
            # this is needed because it seems that the default behavior
            # for quantized models now seems to be device_map="auto"
            # which breaks data-parallel mode.
            if hasattr(self, "accelerator"):
                model_kwargs["device_map"] = {
                    "": f"cuda:{self.accelerator.local_process_index}"
                }
            else:
                model_kwargs["device_map"] = {"": str(self.device)}

        relevant_kwargs = {
            k: v for k, v in model_kwargs.items() if k in self._RELEVANT_MODEL_KWARGS
        }
        # Report what is being dropped so users can see which options had no effect;
        # stay silent when nothing was ignored.
        ignored_kwargs = {
            k: v for k, v in model_kwargs.items() if k not in relevant_kwargs
        }
        if ignored_kwargs:
            eval_logger.warning(
                f"The sparseml integration is ignoring the following kwargs that are specified: {ignored_kwargs}"
            )

        self._model = SparseAutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision,
            torch_dtype=lm_eval.models.utils.get_dtype(dtype),
            trust_remote_code=trust_remote_code,
            **relevant_kwargs,
        )

    def _get_config(self, pretrained: str, **kwargs) -> None:
        """Load the model config with ``SparseAutoConfig`` into ``self._config``.

        Args:
            pretrained: model name or path.
            **kwargs: forwarded unchanged to ``SparseAutoConfig.from_pretrained``.

        Raises:
            Exception: if the ``sparseml`` package cannot be imported.
        """
        try:
            from sparseml.transformers import SparseAutoConfig
        except ModuleNotFoundError as err:
            raise Exception(self._SPARSEML_MISSING_MSG) from err

        self._config = SparseAutoConfig.from_pretrained(
            pretrained_model_name_or_path=pretrained, **kwargs
        )

    def _create_tokenizer(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ],
        **kwargs,
    ) -> None:
        """Set ``self.tokenizer``, loading it with ``SparseAutoTokenizer`` if needed.

        Args:
            pretrained: model name/path, or an already-built model object.
            tokenizer: a tokenizer name/path to load, a ready tokenizer instance to
                use as-is, or a falsy value to derive the tokenizer from the model.
            **kwargs: forwarded to ``SparseAutoTokenizer.from_pretrained``.

        Raises:
            Exception: if the ``sparseml`` package cannot be imported.
            AssertionError: if ``tokenizer`` is neither a string nor a HF tokenizer.
        """
        try:
            from sparseml.transformers import SparseAutoTokenizer
        except ModuleNotFoundError as err:
            raise Exception(self._SPARSEML_MISSING_MSG) from err

        if tokenizer:
            if isinstance(tokenizer, str):
                self.tokenizer = SparseAutoTokenizer.from_pretrained(
                    tokenizer,
                    **kwargs,
                )
            else:
                assert isinstance(
                    tokenizer,
                    (
                        transformers.PreTrainedTokenizer,
                        transformers.PreTrainedTokenizerFast,
                    ),
                )
                self.tokenizer = tokenizer
        else:
            # Get tokenizer based on 'pretrained'; when it is a model object, use the
            # HF hub name recorded on the loaded model instead.
            model_name = (
                pretrained if isinstance(pretrained, str) else self.model.name_or_path
            )
            self.tokenizer = SparseAutoTokenizer.from_pretrained(
                model_name,
                **kwargs,
            )


@register_model("deepsparse")
class DeepSparseLM(LM):
    """
    Wrapper around DeepSparse, a sparsity-aware deep learning
    inference runtime for CPUs, to make it compatible with the
    lm-evaluation-harness.

    ``loglikelihood`` and ``generate_until`` are implemented;
    ``loglikelihood_rolling`` raises ``NotImplementedError``.
    """

    # Pipeline sequence length used when `max_length` is not given.
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        tokenizer: Optional[
            Union[
                str,
                transformers.PreTrainedTokenizer,
                transformers.PreTrainedTokenizerFast,
            ]
        ] = None,
        batch_size: Optional[Union[int, str]] = 1,
        max_gen_toks: Optional[int] = 256,
        max_length: Optional[int] = None,
    ):
        """
        Args:
            pretrained: model path/stub passed to ``deepsparse.TextGeneration``.
            tokenizer: a tokenizer instance, or a name/path loaded with
                ``transformers.AutoTokenizer``. Defaults to the pipeline's own
                tokenizer.
            batch_size: integer, or digit string, batch size. Any other string
                (e.g. ``"auto"``) falls back to 1 with a warning.
            max_gen_toks: default generation budget for requests that do not set
                ``max_gen_toks`` themselves.
            max_length: pipeline sequence length; defaults to
                ``_DEFAULT_MAX_LENGTH``.

        Raises:
            Exception: if the ``deepsparse`` package cannot be imported.
        """
        super().__init__()
        try:
            import deepsparse
        except ModuleNotFoundError as err:
            raise Exception(
                "Package `deepsparse` is not installed. "
                "Please install it via `pip install deepsparse[transformers]`"
            ) from err

        if isinstance(batch_size, str) and not batch_size.isdigit():
            eval_logger.warning(
                f"batch_size={batch_size} is not valid for deepsparse because it is not an integer. "
                "Ignoring and using the default of 1."
            )
            batch_size = 1
        self.batch_size = int(batch_size)
        self._max_length = max_length if max_length else self._DEFAULT_MAX_LENGTH
        self._max_gen_toks = max_gen_toks
        self.batch_sizes = {}

        # Initialize new model and tokenizer instances.
        self.model = deepsparse.TextGeneration(
            model_path=pretrained,
            sequence_length=self._max_length,
            # Use the normalized int: the raw argument may still be a digit string.
            batch_size=self.batch_size,
        )
        if isinstance(tokenizer, str) and tokenizer:
            # A name/path was given; a bare string cannot encode/decode tokens.
            tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
        self.tokenizer = tokenizer if tokenizer else self.model.tokenizer
        self.config = self.model.config

    def tok_encode(self, string: str) -> List[int]:
        """Encode ``string`` to token ids with the tokenizer's default settings."""
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode token ids back to text with the tokenizer's default settings."""
        return self.tokenizer.decode(tokens)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        # it is used as prefix for loglikelihood; fall back to EOS when there is no BOS
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self) -> int:
        return self._max_length

    @property
    def max_gen_toks(self) -> int:
        return self._max_gen_toks

    def loglikelihood(
        self, requests, disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        """
        Score each ``(context, continuation)`` request.

        Adapted from
        https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py

        Args:
            requests: instances whose ``args`` are ``(context, continuation)``.
            disable_tqdm: hide the progress bar when True.

        Returns:
            ``(log-prob, is-greedy)`` per request, in request order.

        Raises:
            NotImplementedError: if any request has an empty context.
        """
        new_reqs = []
        for context, continuation in (req.args for req in requests):
            if context == "":
                raise NotImplementedError(
                    "Implementing empty context is not supported yet"
                )
            context_enc, continuation_enc = self._encode_pair(context, continuation)
            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """
        Compute the loglikelihood of the continuation tokens given the context tokens.

        This function is an adapted version of the original function from
        https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py

        Args:
            requests: ``(cache_key, context_enc, continuation_enc)`` tuples, where
                ``cache_key`` is the original ``(context, continuation)`` pair.
            disable_tqdm: hide the progress bar when True.

        Returns:
            ``(summed continuation log-prob, whether greedy decoding reproduces
            the continuation exactly)`` per request, in the original order.
        """
        # Imported lazily, and once per call rather than once per response,
        # because deepsparse is an optional dependency.
        from deepsparse.utils.data import numpy_log_softmax

        res = []

        def _collate(x):
            """Sort longest-first; the token tuple makes the order deterministic."""
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(
            list(lm_eval.models.utils.chunks(re_ord.get_reordered(), self.batch_size)),
            disable=disable_tqdm,
        ):
            batch_inp = []
            batch_cache_key = []
            batch_continuation_enc = []
            # len(chunk) is at most batch_size
            for cache_key, context_enc, continuation_enc in chunk:
                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice # noqa: E501
                inp = (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
                # The pipeline is fed text prompts, so the token window is decoded.
                # NOTE(review): decode() keeps special tokens (e.g. BOS) by default, and
                # the pipeline may add them again when re-tokenizing — confirm.
                batch_inp.append(self.tokenizer.decode(inp))
                batch_cache_key.append(cache_key)
                batch_continuation_enc.append(continuation_enc)

            # max_new_tokens=0 with prompt logits: score the prompt, generate nothing.
            response = self.model(
                prompt=batch_inp,
                max_new_tokens=0,
                output_scores=True,
                include_prompt_logits=True,
            )

            for resp, continuation_enc, cache_key in zip(
                response.generations, batch_continuation_enc, batch_cache_key
            ):
                # (seq_len, vocab_size)
                multi_scores = resp.score
                # (seq_len, vocab_size) but with log-softmax applied
                multi_logits = numpy_log_softmax(multi_scores, axis=1)
                # toss out the context half of the sequence
                # (cont_len, vocab_size)
                continuation_multi_logits = multi_logits[-len(continuation_enc) :]
                # pick out the logits for the continuation tokens
                # (cont_len,)
                continuation_logits = continuation_multi_logits[
                    numpy.arange(len(continuation_enc)), continuation_enc
                ]
                # check if the tokens generated greedily are the same
                # as the expected continuation
                greedy_tokens = continuation_multi_logits.argmax(axis=1)
                max_equal = greedy_tokens.tolist() == continuation_enc

                # Answer: (log prob, is-exact-match)
                answer = (float(continuation_logits.sum()), bool(max_equal))
                res.append(answer)

                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return re_ord.get_original(res)

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Not supported by this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "The method not required by any of our current task integrations so far"
        )

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """
        Generate text for each request until a stop string or the token budget is hit.

        This function is an adapted version of the original function from
        https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py

        Args:
            requests: instances whose ``args`` are ``(context, gen_kwargs)``.
                ``gen_kwargs`` may contain ``until`` (str or list of str),
                ``max_gen_toks``, ``do_sample`` (dropped) and ``temperature``
                (default 0). Remaining keys are forwarded to the pipeline.
            disable_tqdm: hide the progress bar when True.

        Returns:
            One generated string per request, in request order.
        """
        if not requests:
            return []

        res = []
        request_args = [req.args for req in requests]

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(request_args, _collate)

        def sameuntil_chunks(xs, size):
            """Yield runs of at most ``size`` items that share identical gen_kwargs."""
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        pbar = tqdm(total=len(request_args), disable=disable_tqdm)
        for chunk, chunk_gen_kwargs in sameuntil_chunks(
            re_ord.get_reordered(), self.batch_size
        ):
            # Deep-copy: the popped keys below must not disturb the original
            # request dicts, which are still used as cache keys.
            gen_kwargs = copy.deepcopy(chunk_gen_kwargs)
            # Per-chunk budget; kept local so one request's `max_gen_toks` does not
            # become the default for later chunks.
            max_gen_toks = gen_kwargs.pop("max_gen_toks", self.max_gen_toks)
            until = gen_kwargs.pop("until", ["<|endoftext|>"])
            if isinstance(until, str):
                # A bare string must not be iterated character by character.
                until = [until]
            gen_kwargs.pop("do_sample", None)
            gen_kwargs["temperature"] = gen_kwargs.get("temperature", 0)

            # add context (prompts) to the list
            inps = [context for context, _ in chunk]

            # run inference
            # NOTE(review): asks for one fewer token than `max_gen_toks`; presumably
            # compensates for an extra token emitted by the pipeline — confirm.
            out = self.model(
                sequences=inps,
                max_new_tokens=max_gen_toks - 1,
                stop=until,
                **gen_kwargs,
            )

            for resp, (context, args_) in zip(out.generations, chunk):
                text = resp.text
                # split the text at the first occurrence of any of the until strings
                for term in until:
                    if len(term) > 0:
                        text = text.split(term)[0]

                res.append(text)
                # Key by the original (context, gen_kwargs), i.e. the request's args,
                # so cached results can be looked up again.
                self.cache_hook.add_partial("generate_until", (context, args_), text)
                pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        """
        Tokenize a (context, continuation) pair so that concatenation round-trips.

        Trailing whitespace of the context is moved onto the continuation, the
        full string is encoded once, and the continuation tokens are whatever
        follows the context's tokens.

        Copied directly from
        https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py
        """
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc
|