File size: 21,677 Bytes
f238c4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 |
import collections
import fnmatch
import gc
import itertools
import time
from functools import wraps
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
import torch
import transformers
from lm_eval.utils import eval_logger
def chunks(iter, n: int = 0, fn=None):
    """
    Lazily split an iterable into consecutive lists ("chunks"); handy for batching.

    Parameters:
    - iter: The iterable to split.
    - n: Fixed chunk size, used when `fn` is not supplied. With the default of 0
      (and no `fn`) the size check never matches, so every element ends up in a
      single trailing chunk.
    - fn: Optional callable `fn(index, iter) -> int` returning the chunk size to
      compare against after the element at `index` has been added.

    Yields:
    Lists of elements in their original order. The last list may be shorter
    than the requested size.

    Example:
    ```
    for chunk in chunks([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 3):
        print(chunk)
    # [1, 2, 3]
    # [4, 5, 6]
    # [7, 8, 9]
    # [10]
    ```
    """
    pending = []
    for position, element in enumerate(iter):
        pending.append(element)
        # The target size is re-evaluated for every element so that `fn` can
        # vary it along the way.
        target = fn(position, iter) if fn else n
        if len(pending) == target:
            yield pending
            pending = []
    # Emit whatever is left over as a final (possibly short) chunk.
    if pending:
        yield pending
class MultiChoice:
    """
    Wraps a collection of choices and supports `in` checks against a
    comma-separated string of shell-style wildcard patterns.
    """

    def __init__(self, choices) -> None:
        self.choices = choices

    def __contains__(self, values) -> bool:
        """
        Return True when every comma-separated pattern in `values` matches at
        least one choice (fnmatch / linux filename pattern rules).

        Raises:
            ValueError: for the first pattern that matches nothing; the
            available choices are logged beforehand.
        """
        for pattern in values.split(","):
            if fnmatch.filter(self.choices, pattern):
                continue
            # Nothing matched: list what is available before failing.
            eval_logger.info("Available tasks to choose:")
            for choice in self.choices:
                eval_logger.info(f" - {choice}")
            raise ValueError(f"'{pattern}' is not in task list")
        return True

    def __iter__(self) -> Iterator:
        """Iterate over the wrapped choices in their stored order."""
        yield from self.choices
class Grouper:
    """
    Buckets the elements of `arr` by the key `fn(element)` while remembering
    each element's original position, so that per-bucket results can later be
    put back into the original order with `get_original`.

    `self.arr` maps each key to a list of `(original_index, element)` pairs.
    """

    def __init__(self, arr, fn) -> None:
        self.size = len(arr)
        indexed = list(enumerate(arr))
        buckets = collections.defaultdict(list)
        for pair in indexed:
            # pair is (original_index, element); the key is computed from the element only
            buckets[fn(pair[1])].append(pair)
        self.arr = buckets
        # Lazily-built cache for get_grouped()
        self._grouped = None

    def get_grouped(self):
        """Return `{key: [elements...]}` with the original indices stripped (cached)."""
        if self._grouped:
            return self._grouped
        stripped = {
            key: [element for _, element in members]
            for key, members in self.arr.items()
        }
        self._grouped = stripped
        return stripped

    def get_original(self, grouped_dict):
        """
        Flatten per-key results back into one list in the original order.

        `grouped_dict` must have exactly the same keys as the grouping, and the
        values under each key must be listed in the same order as that key's
        elements were grouped. Asserts that every original position was filled.
        """
        restored = [None] * self.size
        filled = [False] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key, values in grouped_dict.items():
            for (position, _), value in zip(self.arr[key], values):
                restored[position] = value
                filled[position] = True

        assert all(filled)
        return restored
def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
):
    """
    Zero-pad each tensor up to `max_length` and stack them along a new batch
    dimension. Used for batching inputs and continuations in seq2seq models.

    Parameters:
    - max_length: Target sequence length for every row.
    - tensors: Tensors of shape [seq] or [1, seq]. NOTE: this list is
      overwritten in place with the padded [1, max_length] rows.
    - padding_side: "right" appends the zeros, "left" prepends them.

    Returns:
    The concatenation of all rows along dim 0.

    Tensors that are already at least `max_length` long are neither padded nor
    truncated. The padding is created with dtype `torch.long` on the tensor's
    own device.
    """
    assert padding_side in (
        "left",
        "right",
    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    for idx, row in enumerate(tensors):
        if row.dim() == 2:
            # Accept [1, seq] inputs by dropping the leading batch dimension.
            row = row.squeeze(0)

        missing = max_length - row.shape[0]
        if missing <= 0:
            tensors[idx] = row.unsqueeze(0)
            continue

        filler = torch.zeros(missing, dtype=torch.long, device=row.device)
        pieces = [row, filler] if padding_side == "right" else [filler, row]
        tensors[idx] = torch.cat(pieces, dim=0).unsqueeze(0)

    return torch.cat(tensors, dim=0)
def clear_torch_cache() -> None:
    """Run Python garbage collection, then call `torch.cuda.empty_cache()`."""
    # Collect first so that objects dropped by Python are gone before the
    # CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """
    Resolve a dtype name to the matching attribute of the `torch` module,
    e.g. `"float16"` -> `torch.float16`. Does not use an instantiated HF AutoConfig.

    The literal string "auto" and any non-string value are returned unchanged.
    """
    if not isinstance(dtype, str) or dtype == "auto":
        return dtype
    # Look the name up on the torch module: `float16` -> `torch.float16`
    return getattr(torch, dtype)
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Stopping criterion that fires once every batch row has produced the given stop string."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.tokenizer = tokenizer
        self.sequence = sequence
        self.initial_decoder_input_length = initial_decoder_input_length
        # One flag per batch row; a row stays "done" once the stop string was seen.
        self.done_tracker = [False] * batch_size
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # Look back 2 tokens further than the encoded stop sequence is long.
        # The model may emit the same text under a different tokenization
        # (e.g. `['\n', '\n']` where `sequence` encodes as `['\n\n']`), and the
        # stop must still be detected in that case.
        # NOTE: the extra lookback carries a minor risk of reaching into the
        # model inputs and stopping immediately; __call__ therefore slices the
        # inputs off before taking the lookback window.
        self.sequence_id_len = len(self.sequence_ids) + 2

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True once the stop string has been seen in every batch row."""
        # Drop the initial decoder inputs, then keep only the trailing window;
        # decoding just that window is cheaper than decoding everything.
        generated = input_ids[:, self.initial_decoder_input_length :]
        window = generated[:, -self.sequence_id_len :]
        decoded_windows = self.tokenizer.batch_decode(window)

        for row, finished in enumerate(self.done_tracker):
            if finished:
                continue
            self.done_tracker[row] = self.sequence in decoded_windows[row]

        return False not in self.done_tracker
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """
    Build a `StoppingCriteriaList` holding one `MultiTokenEOSCriteria` per
    entry of `stop_sequences`, in the same order.
    """
    criteria = [
        MultiTokenEOSCriteria(
            stop_text, tokenizer, initial_decoder_input_length, batch_size
        )
        for stop_text in stop_sequences
    ]
    return transformers.StoppingCriteriaList(criteria)
def undistribute(iterable):
    """
    Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute .

    Re-interleaves results that have been split using more_itertools.distribute:
        >>> undistribute([[1, 3, 5], [2, 4, 6]])
        [1, 2, 3, 4, 5, 6]

    Handles non-uniform component lengths:
        >>> undistribute([[1, 4, 7], [2, 5], [3, 6]])
        [1, 2, 3, 4, 5, 6, 7]

    Also handles when some iterables are empty:
        >>> undistribute([[1], [2], [3], [], []])
        [1, 2, 3]

    Items that are themselves `None` are preserved:
        >>> undistribute([[None, "b"], ["a"]])
        [None, 'a', 'b']

    Parameters:
    - iterable: An iterable of iterables (the "children" produced by distribute).

    Returns:
    A list with the children's items interleaved round-robin.
    """
    # zip_longest pads the shorter children. Padding with a private sentinel
    # (instead of the default None) lets the padding be removed without also
    # dropping genuine None items from the data.
    padding = object()
    children = [list(child) for child in iterable]
    return [
        item
        for item in itertools.chain.from_iterable(
            itertools.zip_longest(*children, fillvalue=padding)
        )
        if item is not padding
    ]
def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry on an LLM Provider's rate limit error with exponential backoff

    For example, to use for OpenAI, do the following:
    ```
    from openai import RateLimitError

    # Recommend specifying max_retries to avoid infinite loops!
    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
    def completion(...):
        # Wrap OpenAI completion function here
        ...
    ```

    Parameters:
    - on_exceptions: Exception types that trigger a retry. Any other exception
      propagates immediately.
    - max_retries: Maximum number of attempts. `None` retries forever. The
      wrapped function is always attempted at least once.
    - backoff_time: Seconds to sleep after the first failure.
    - backoff_multiplier: Factor applied to the sleep time after each failure.
    - on_exception_callback: Optional `callback(exception, sleep_time)` invoked
      before each sleep, i.e. whenever another attempt is going to be made.

    Returns:
    A decorator. The decorated function returns the wrapped function's result.

    Raises:
    The last caught exception once `max_retries` attempts have all failed,
    rather than silently returning `None`.
    """
    retryable = tuple(on_exceptions)

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except retryable as e:
                    attempt += 1
                    if max_retries is not None and attempt >= max_retries:
                        # Out of attempts: surface the failure to the caller
                        # without a pointless final sleep.
                        raise
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier

        return wrapper

    return decorator
class Collator:
    """
    Reorders, optionally groups, and batches the elements of an array, while
    recording enough bookkeeping to map results back to the original order.

    The `group_by` option selects how elements are grouped before batching:
    - "gen_kwargs": requests are grouped by their gen_kwargs, and batches never
      mix groups.
    - "contexts": requests are grouped by context + cont[:-1]; one
      representative per group is batched and `get_cache` fans results out to
      the other members of the group.
    - None: no grouping; elements are only reordered by `sort_fn`.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable = lambda x: x,
        group_fn: Callable = lambda x: x[1],
        group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
    ) -> None:
        self._group_by = group_by
        # Elements are stored as (original_index, item) pairs, so the
        # user-supplied callables are wrapped to receive only the item.
        self._sort_fn = lambda indexed: sort_fn(indexed[1])
        self._group_fn = lambda indexed: group_fn(indexed[1])
        # Original indices in the order in which results will be produced.
        self._reorder_indices: List = []
        self._size = len(arr)
        # A flat tuple of (index, item) pairs; replaced by a dict of such
        # pairs per group key when grouping is enabled.
        self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
            enumerate(arr)
        )
        if self._group_by == "contexts":
            self._group_by_context()
        elif self._group_by == "gen_kwargs":
            self._group_by_index()

    def _group_by_index(self) -> None:
        """Replace the indexed array with a dict grouped in "gen_kwargs" mode."""
        grouped = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
        )
        self._arr_with_indices = grouped

    def _group_by_context(self) -> None:
        """Replace the indexed array with a dict grouped in "contexts" mode."""
        grouped = self.group(
            self._arr_with_indices, fn=self._group_fn, group_by="contexts"
        )
        self._arr_with_indices = grouped

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Lazily yield batches of the reordered elements.

        - group_by == "gen_kwargs": each group is reordered and batched on its
          own, so a batch only holds elements sharing the same gen_kwargs.
        - group_by == "contexts": only the first element of every group is
          reordered and batched; the rest are recovered through `get_cache`.
        - otherwise: the whole array is reordered and batched.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn ([Callable[[int, Iterable], int]] | None): Optional function
          deciding the batch size; forwarded to `get_chunks` as `fn`.

        Yields:
        Lists of elements (without their indices).
        """
        if self._group_by == "gen_kwargs":
            for members in self._arr_with_indices.values():  # type: ignore
                yield from self.get_chunks(self._reorder(members), n=n, fn=batch_fn)
            return

        if self._group_by == "contexts":
            # One representative per group key
            pool = [members[0] for members in self._arr_with_indices.values()]
        else:
            pool = self._arr_with_indices  # type: ignore
        yield from self.get_chunks(self._reorder(pool), n=n, fn=batch_fn)

    def get_cache(
        self,
        req_str: Tuple[str, str] = None,
        cxt_toks: List[int] = None,
        cont_toks: List[int] = None,
        logits: torch.Tensor = None,
    ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
        """
        Yield `(req_str, cont_toks, logits)` for every request that can be
        answered with the supplied logits, recording their original indices.

        When `group_by` is "contexts", the group stored under the key
        `tuple(cxt_toks + cont_toks[:-1])` is removed from the collator:
          1. Single match: the arguments are yielded back unchanged.
          2. Multiple matches: the logits are expanded along the batch
             dimension (a view, no copy) and split into one chunk per match;
             each match's own request strings and continuation tokens are
             yielded with its chunk.

        For any other `group_by`, the arguments are yielded back unchanged and
        no bookkeeping is done.

        Parameters:
        - req_str (tuple[str, str]): Original strings used for CachingLM.
        - cxt_toks (list[int]): Full context tokens used for lookup.
        - cont_toks (list[int]): Continuation tokens for which logits were generated.
        - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by
          the model given context and continuation keys.

        Yields:
        - req_str (tuple[str, str]): strings used for CachingLM.
        - cont_toks (list[int]): continuation tokens.
        - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits
          (one chunk per cache hit).
        """
        if self._group_by != "contexts":
            yield req_str, cont_toks, logits
            return

        lookup_key = tuple(cxt_toks + cont_toks[:-1])
        hits: List[
            Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
        ] = self._arr_with_indices.pop(lookup_key)
        hit_count = len(hits)

        if hit_count == 1:
            self._reorder_indices.extend(hit[0] for hit in hits)
            yield req_str, cont_toks, logits
            return

        # Several requests share this key: give each one a view of the logits
        # together with its own request strings and continuation tokens.
        per_hit_logits = logits.expand(hit_count, -1, -1).chunk(hit_count)
        positions, hit_strs, hit_conts = zip(
            *[(hit[0], hit[1][0], hit[-1][-1]) for hit in hits]
        )
        self._reorder_indices.extend(positions)
        for hit_str, hit_cont, hit_logits in zip(hit_strs, hit_conts, per_hit_logits):
            yield hit_str, hit_cont, hit_logits

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
        """
        Sort `(index, item)` pairs with the sort function and yield the items.

        Unless grouping by contexts (where `get_cache` records the indices),
        the original indices are appended to the reorder bookkeeping. Because
        this is a generator, that happens when iteration starts.

        Parameters:
        - arr (list | tuple[tuple[int, Any], ...]]): The pairs to be reordered.

        Yields:
        The items in sorted order, without their indices.
        """
        ordered = sorted(arr, key=self._sort_fn)
        if self._group_by != "contexts":
            self._reorder_indices.extend([pair[0] for pair in ordered])
        for pair in ordered:
            yield pair[1]

    def get_original(self, newarr: List) -> List:
        """
        Put results produced in reordered order back into the original order.

        Parameters:
        - newarr (list): One result per element, in the order recorded while
          batching.

        Returns:
        list: The results placed at their original positions. Asserts that
        every position was filled.
        """
        restored = [None] * self._size
        filled = [False] * self._size

        for position, value in zip(self._reorder_indices, newarr):
            restored[position] = value
            filled[position] = True

        assert all(filled)
        return restored

    def __len__(self):
        return self._size

    @staticmethod
    def group(
        arr: Iterable,
        fn: Callable,
        group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
    ) -> dict:
        """
        Bucket the elements of an iterable by a key derived from `fn(element)`.

        - group_by == "contexts": the key is `tuple(fn(element))`, i.e. elements
          are grouped by [context + cont][:-1].
        - otherwise ("gen_kwargs"): `fn(element)` is expected to be a dict. It is
          turned into a hashable key of sorted `(key, value)` pairs, with
          iterable values converted to tuples. If that fails with a TypeError
          or AttributeError, `tuple(fn(element))` is used as the key instead.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function producing the grouping value.
        - group_by: The grouping mode described above.

        Returns:
        A defaultdict(list) mapping each key to its elements, in encounter order.
        """
        buckets = collections.defaultdict(list)
        for ob in arr:
            if group_by == "contexts":
                # where fn(ob) == [context + cont][:-1]
                buckets[tuple(fn(ob))].append(ob)
                continue
            try:
                frozen_kwargs = tuple(
                    (
                        name,
                        tuple(setting)
                        if isinstance(setting, collections.abc.Iterable)
                        else setting,
                    )
                    for name, setting in sorted(fn(ob).items())
                )
                buckets[frozen_kwargs].append(ob)
            except (TypeError, AttributeError):
                buckets[tuple(fn(ob))].append(ob)
        return buckets

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Split an iterable into consecutive lists ("chunks"); useful for batching.
        The iterable is fully materialised into a tuple first.

        Parameters:
        - _iter: The input iterable to be divided into chunks.
        - n: Fixed chunk size, used when `fn` is not supplied. Default is 0.
        - fn: Optional function taking the current index and the materialised
          tuple and returning the size of the chunk. Default is None.

        Yields:
        Lists of elements; the last one may be shorter than requested.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in Collator.get_chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        items = tuple(_iter)
        pending = []
        for position, element in enumerate(items):
            pending.append(element)
            target = fn(position, items) if fn else n
            if len(pending) == target:
                yield pending
                pending = []
        if pending:
            yield pending
|