File size: 19,024 Bytes
cd0431b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 |
import copy
from importlib.metadata import version
from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple, Union
from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import Collator, undistribute
from lm_eval.utils import (
eval_logger,
get_rolling_token_windows,
make_disjoint_window,
)
try:
import ray
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
pass
# Self-assignment of the logger imported from lm_eval.utils above: it rebinds
# the same object to the same module-level name, so it has no runtime effect.
# NOTE(review): presumably kept to mark the name as a module attribute — confirm.
eval_logger = eval_logger
@register_model("vllm")
class VLLM(TemplateLM):
    """lm-eval model backend that serves requests through a vLLM engine.

    With ``data_parallel_size <= 1`` a single ``vllm.LLM`` engine is created
    in ``__init__`` and reused for every request.  With
    ``data_parallel_size > 1`` no engine is created up front; instead each
    call to ``_model_generate`` spawns one ray task per replica, and each task
    builds its own ``LLM`` from ``self.model_args``.
    """

    # Fallback context length when neither the user, the engine, the model
    # config nor the tokenizer yields a usable value.
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained="gpt2",
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: Optional[str] = None,
        trust_remote_code: Optional[bool] = False,
        tokenizer: Optional[str] = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: Optional[str] = None,
        add_bos_token: Optional[bool] = False,
        prefix_token_id: Optional[int] = None,
        tensor_parallel_size: int = 1,
        quantization: Optional[str] = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: Union[str, int] = 1,
        max_batch_size=None,
        max_length: Optional[int] = None,
        max_model_len: Optional[int] = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: Optional[str] = "cuda",
        data_parallel_size: int = 1,
        **kwargs,
    ):
        """Configure the vLLM engine arguments and load the tokenizer.

        :param pretrained: model name or path, forwarded to vLLM as ``model``.
        :param dtype, revision, trust_remote_code, tokenizer, tokenizer_mode,
            tokenizer_revision, tensor_parallel_size, quantization,
            swap_space, seed, gpu_memory_utilization: forwarded to ``LLM``
            through ``self.model_args`` (numeric ones are cast first).
        :param add_bos_token: default for ``add_special_tokens`` in
            ``tok_encode`` when the caller does not specify it.
        :param prefix_token_id: optional override for ``prefix_token_id``.
        :param max_gen_toks: default generation budget for ``generate_until``.
        :param batch_size: an int, or any string containing ``"auto"``.
        :param max_batch_size: accepted but not used by this class.
        :param max_length, max_model_len: two names for the same limit; at
            most one of them may be given.
        :param device: must be ``None`` or contain ``"cuda"``.
        :param data_parallel_size: number of engine replicas; values above 1
            switch to the ray-based path.
        :param kwargs: extra engine arguments merged into ``self.model_args``.
        :raises ModuleNotFoundError: if the ``vllm`` package cannot be found.
        :raises AssertionError: on an unsupported device, on both length
            arguments being set, or on data parallelism with vllm >= 0.3.3.
        """
        super().__init__()

        if not find_spec("vllm"):
            # ModuleNotFoundError is a subclass of Exception, so callers that
            # caught the previous bare Exception keep working.
            raise ModuleNotFoundError(
                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
            )

        # Test for None first: `"cuda" in None` would raise TypeError before
        # the None case could ever be accepted.
        assert device is None or "cuda" in device, "vLLM only supports CUDA"
        assert (
            max_length is None or max_model_len is None
        ), "Either max_length or max_model_len may be provided, but not both"

        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
        self.model_args = {
            "model": pretrained,
            "gpu_memory_utilization": float(gpu_memory_utilization),
            "revision": revision,
            "dtype": dtype,
            "tokenizer": tokenizer,
            "tokenizer_mode": tokenizer_mode,
            "tokenizer_revision": tokenizer_revision,
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
        }
        # Caller-supplied engine arguments win over the defaults above.
        self.model_args.update(kwargs)

        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else batch_size
        )

        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            assert parse_version(version("vllm")) < parse_version(
                "0.3.3"
            ), "data_parallel is only compatible with vllm < v0.3.3."
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
            self.model_args["worker_use_ray"] = True
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

            # No engine exists on this path, so the model config is loaded
            # separately; `max_length` reads the sequence length from it.
            from transformers import AutoConfig

            self._config = AutoConfig.from_pretrained(
                pretrained, trust_remote_code=trust_remote_code, revision=revision
            )

        self.tokenizer = get_tokenizer(
            tokenizer if tokenizer else pretrained,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
        )
        self.add_bos_token = add_bos_token
        self.custom_prefix_token_id = prefix_token_id
        if prefix_token_id is not None:
            eval_logger.info(
                f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
            )

        self._max_gen_toks = max_gen_toks

    @property
    def eot_token_id(self):
        """Return the tokenizer's EOS token id."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def prefix_token_id(self):
        """Return the token id used as prefix for loglikelihood.

        Preference order: the user-supplied id, then the tokenizer's BOS id,
        then the tokenizer's EOS id.
        """
        if self.custom_prefix_token_id is not None:
            return self.custom_prefix_token_id
        if self.tokenizer.bos_token_id is not None:
            return self.tokenizer.bos_token_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Return the maximum sequence length, in tokens.

        A manually configured length wins.  Otherwise the single-engine path
        asks the vLLM engine, and the data-parallel path probes the model
        config, then the tokenizer, and finally falls back to
        ``_DEFAULT_MAX_LENGTH``.
        """
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        if self.data_parallel_size <= 1:
            return self.model.llm_engine.model_config.max_model_len

        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self._config, attr):
                return getattr(self._config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            # This literal equals int(1e30); presumably the tokenizer's
            # "no limit configured" sentinel — TODO confirm.
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        """Return the default number of tokens to generate per request."""
        return self._max_gen_toks

    def tok_encode(
        self,
        string: str,
        left_truncate_len=None,
        add_special_tokens=None,
        truncation=False,
    ):
        """Encode ``string`` into token ids.

        :param string: text to tokenize.
        :param left_truncate_len: if truthy, keep only the last
            ``left_truncate_len`` tokens.
        :param add_special_tokens: forwarded to the tokenizer; ``None`` means
            "use ``self.add_bos_token``".  An explicit ``False`` is honoured.
        :param truncation: forwarded to the tokenizer.
        :return: the encoded token ids.
        """
        # Only fall back to the instance default when the caller gave no
        # value; `if not add_special_tokens` also overrode an explicit False.
        if add_special_tokens is None:
            add_special_tokens = self.add_bos_token
        encoding = self.tokenizer.encode(
            string, add_special_tokens=add_special_tokens, truncation=truncation
        )

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def _model_generate(
        self,
        requests: Optional[List[List[int]]] = None,
        generate: bool = False,
        max_tokens: Optional[int] = None,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
        """Run a batch of already-tokenized prompts through vLLM.

        :param requests: one list of token ids per prompt.
        :param generate: if true, sample using ``max_tokens``, ``stop`` and
            ``kwargs`` (normalised by ``modify_gen_kwargs``); otherwise
            request prompt logprobs only, generating a single token.
        :param max_tokens: generation budget (generation mode only).
        :param stop: stop strings (generation mode only).
        :param kwargs: extra sampling parameters (generation mode only).
        :return: the vLLM outputs, in the same order as ``requests``.
        """
        if generate:
            kwargs = self.modify_gen_kwargs(kwargs)
            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
        else:
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1
            )

        if self.data_parallel_size > 1:
            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
            # note: this has changed on 0.3.3, and it only works now if num_gpus are set.
            # but then tensor_parallel breaks
            @ray.remote
            def run_inference_one_model(
                model_args: dict, sampling_params, requests: List[List[int]]
            ):
                llm = LLM(**model_args)
                return llm.generate(
                    prompt_token_ids=requests, sampling_params=sampling_params
                )

            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
            # interleaved important to balance context lengths across workers
            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
            inputs = ((self.model_args, sampling_params, req) for req in requests)
            try:
                object_refs = [run_inference_one_model.remote(*x) for x in inputs]
                results = ray.get(object_refs)
            finally:
                # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
                # Doing it in `finally` also releases ray when a worker fails.
                ray.shutdown()
            # flatten results
            return undistribute(results)

        outputs = self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=self.batch_size == "auto",
        )
        return outputs

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Score each request string as a sum over rolling token windows.

        :param requests: instances whose ``args`` is a 1-tuple ``(string,)``.
        :param disable_tqdm: hide the per-string progress bar.
        :return: one summed log-likelihood per request, in request order.
        """
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
            windows = map(
                make_disjoint_window,
                get_rolling_token_windows(
                    token_list=self.tok_encode(string),
                    # Use the configurable loglikelihood prefix rather than
                    # EOS, so a user-supplied prefix_token_id is honoured
                    # here as well.
                    prefix_token=self.prefix_token_id,
                    max_seq_len=self.max_length - 1,
                    context_len=1,
                ),
            )

            # Prepend a None cache key so partial caching is skipped.
            rolling_token_windows = [(None,) + x for x in windows]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows,
            )

            # discard is_greedy and sum the per-window log-likelihoods
            loglikelihoods.append(sum(x[0] for x in string_nll))

        return loglikelihoods

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Generate a continuation for every request.

        :param requests: instances whose ``args`` is ``(context, gen_kwargs)``;
            ``gen_kwargs`` must be a dict and may carry ``until`` (a string or
            list of stop strings) and ``max_gen_toks`` besides sampling
            parameters.
        :param disable_tqdm: hide the progress bar.
        :return: generated texts, in the original request order.
        :raises ValueError: if ``gen_kwargs`` is not a dict or ``until`` is
            neither a string nor a list.
        """
        # Nothing to do; also avoids unpacking an empty zip() below.
        if not requests:
            return []

        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding = self.tokenizer(context, add_special_tokens=False).input_ids
        requests = [
            ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
        ]

        def _collate_gen(_requests):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )

        # The decoded EOS string is identical for every chunk, so decode once.
        eos = self.tokenizer.decode(self.eot_token_id)

        # for each different set of kwargs, we execute all requests, by batch.
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            if not isinstance(gen_kwargs, dict):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                )

            # unpack our keyword arguments.
            kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
            until = None
            if "until" in kwargs:
                until = kwargs.pop("until")
                if isinstance(until, str):
                    until = [until]
                elif not isinstance(until, list):
                    raise ValueError(
                        f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                    )

            # add EOS token to stop sequences
            if not until:
                until = [eos]
            else:
                until.append(eos)

            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            # max len for inputs = max length, minus room to generate the max new tokens
            # NOTE(review): if max_gen_toks >= max_length the bound below is
            # non-positive and the slice no longer left-truncates as intended
            # — confirm callers never hit this.
            max_ctx_len = self.max_length - max_gen_toks
            context_encoding = [x[-max_ctx_len:] for x in context_encoding]

            # perform batched generation
            cont = self._model_generate(
                requests=context_encoding,
                generate=True,
                max_tokens=max_gen_toks,
                stop=until,
                **kwargs,
            )

            # cache generations
            for output, ctx in zip(cont, context):
                generated_text = output.outputs[0].text
                res.append(generated_text)
                self.cache_hook.add_partial(
                    "generate_until", (ctx, gen_kwargs), generated_text
                )
                pbar.update(1)

        pbar.close()
        # reorder all group of results back to original unsorted form
        return re_ords.get_original(res)

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
    ) -> List[Tuple[float, bool]]:
        """Score tokenized (context, continuation) pairs.

        :param requests: ``(cache_key, context_tokens, continuation_tokens)``
            triples; a ``None`` cache key disables partial caching for that
            entry.
        :param disable_tqdm: hide the progress bar.
        :return: ``(log-likelihood, is_greedy)`` per request, in the original
            request order.
        """
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # Reorder requests by length and batch
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(
            total=len(requests),
            disable=disable_tqdm,
            desc="Running loglikelihood requests",
        )

        # `max_length` is a property; read it once instead of several times
        # per request.
        max_len = self.max_length

        for chunk in chunks:
            inputs = []
            ctxlens = []
            for _, context_enc, continuation_enc in chunk:
                full = context_enc + continuation_enc
                # Keep the rightmost max_len tokens and shrink the context
                # length by however many tokens were cut from the left.
                inputs.append(full[-max_len:])
                ctxlens.append(len(context_enc) - max(0, len(full) - max_len))

            outputs = self._model_generate(requests=inputs, generate=False)

            for output, ctxlen, (cache_key, _, _), inp in zip(
                outputs, ctxlens, chunk, inputs
            ):
                answer = self._parse_logprobs(
                    tokens=inp,
                    outputs=output,
                    ctxlen=ctxlen,
                )

                res.append(answer)

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                pbar.update(1)

        pbar.close()
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
            Input tokens (potentially left-truncated)
        :param outputs: RequestOutput
            Contains prompt_logprobs
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
            continuation_logprobs: float
                Log probabilities of continuation tokens
            is_greedy: bool
                Whether argmax matches given continuation exactly
        """

        def coerce_logprob_to_num(logprob):
            # vLLM changed the return type of logprobs from float
            # to a Logprob object storing the float value + extra data
            # (https://github.com/vllm-project/vllm/pull/3065).
            # If we are dealing with vllm's Logprob object, return
            # the logprob value stored as an attribute. Otherwise,
            # return the object itself (which should be a float
            # for older versions of vLLM).
            return getattr(logprob, "logprob", logprob)

        # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
        # assume ctxlen always >= 1, so that entry is sliced away here.
        # Only the continuation positions are needed, so only those are coerced.
        continuation_logprobs_dicts = [
            {
                token: coerce_logprob_to_num(logprob)
                for token, logprob in logprob_dict.items()
            }
            if logprob_dict is not None
            else None
            for logprob_dict in outputs.prompt_logprobs[ctxlen:]
        ]
        pairs = list(zip(tokens[ctxlen:], continuation_logprobs_dicts))

        # Calculate continuation_logprobs
        continuation_logprobs = sum(
            logprob_dict.get(token) for token, logprob_dict in pairs
        )

        # Determine if is_greedy: every continuation token must be the
        # highest-logprob entry at its position (empty/None dicts are skipped).
        is_greedy = all(
            max(logprob_dict, key=logprob_dict.get) == token
            for token, logprob_dict in pairs
            if logprob_dict
        )

        return continuation_logprobs, is_greedy

    @staticmethod
    def modify_gen_kwargs(kwargs: dict) -> dict:
        """Translate lm-eval generation kwargs into vLLM sampling kwargs.

        Mutates and returns ``kwargs``: ``do_sample`` is removed,
        ``temperature`` is forced to ``0.0`` when ``do_sample`` is ``False``
        or no temperature was given, and the two special-token flags default
        to ``False`` when absent.
        """
        # sampling_params
        do_sample = kwargs.pop("do_sample", None)
        if do_sample is False or "temperature" not in kwargs:
            kwargs["temperature"] = 0.0
        # hf defaults
        kwargs.setdefault("skip_special_tokens", False)
        kwargs.setdefault("spaces_between_special_tokens", False)
        return kwargs
|