import os
import torch
import time
import numpy as np
import torch.distributed as dist
from transformers.utils import logging
from transformers import AutoTokenizer
from itertools import cycle
from typing import List

logger = logging.get_logger(__name__)


class Memory(torch.nn.Module):
    def __init__(
        self,
        model_config,
        k_seq_dim: int = 2,
        v_seq_dim: int = 2,
    ):
        """Setup necessary attributes."""
        super().__init__()

        self.config = model_config

        # initialize necessary parameters
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.rng = np.random.default_rng(42)

        self._post_validation()
        self.reset()

    @property
    def beacon_token(self):
        return self.config.vocab_size

    def _post_validation(self, verbose=True):
        assert self.config.beacon_window >= self.config.beacon_stride, f"Make sure the beacon_window {self.config.beacon_window} >= beacon_stride {self.config.beacon_stride}!"
        for ratio in self.config.beacon_ratio:
            assert ratio >= 0, f"Make sure all beacon ratios are greater than or equal to 0, found {self.config.beacon_ratio}!"
        assert self.config.beacon_attn in ["segmentation", "step-expansion", "full-coverage"], f"beacon_attn {self.config.beacon_attn} not implemented!"
        assert self.config.beacon_ratio_mix in ["instance-random", "step-random", "sequence"] or "adapt-" in self.config.beacon_ratio_mix, f"beacon_ratio_mix {self.config.beacon_ratio_mix} not implemented!"
        # assert self.config.beacon_pos in ["append", "interleave"], f"beacon_pos {self.config.beacon_pos} not implemented!"
        if self.config.beacon_pos == "interleave":
            assert self.config.beacon_window == self.config.beacon_stride, "Make sure the beacon_window equals the beacon_stride when using interleaving mode."
        if self.config.beacon_parallel_window > 1:
            assert self.config._attn_implementation != "flash_attention_2", "Currently the parallel window does not support flash_attention_2!"

        self._cpu = torch.device("cpu")

        if verbose:
            info = f"applying activation beacon on {self.config.beacon_param} (the beacon embedding is initialized from {'bos' if self.config.beacon_embed_init == 'bos' else 'eos'} embedding, the beacon tokens are positioned with '{self.config.beacon_pos}' method), with window size {self.config.beacon_window}, stride {self.config.beacon_stride}, {self.config.beacon_attn} attention{' (attending to previous beacons)' if self.config.beacon_attend_prev else ' (not attending to previous beacons)'}, sink size {self.config.beacon_sink_size}, compression ratio {self.config.beacon_ratio} (mixed by {self.config.beacon_ratio_mix})..."
            logger.info(info)

    def set(self, verbose=True, **kwargs):
        """
        Set attributes out of the constructor.
        """
        for k, v in kwargs.items():
            setattr(self.config, k, v)
        self._post_validation(verbose=verbose)

    def reset(self):
        """Initialize attributes for a new sequence."""
        # the cursor pointing to the start of the current window
        self.start_idx = 0
        # the cursor pointing to the end of the current window
        self.end_idx = 0
        # the beacon sizes of all strides
        self.all_beacon_sizes = []
        # the loss per batch
        self.batch_loss = None
        # the valid token number per batch
        self.valid_token_num = None
        # the step index for processing the input_ids
        self.step_idx = 0
        # used in set_compression_ratio
        self.compression_ratio = None
        # whether the previous input fulfilled a full window, defaults to True
        self.is_full_window = True
        # the number of raw activations to preserve in update_memory (only useful when beacon_stride < beacon_window)
        self.raw_size_to_cache = 0
        # the number of tokens in the previous stride that should be compressed by the upcoming beacon
        self.interleave_remainder = 0
        # compression ratio for the unfinished window
        self.interleave_compression_ratio = None
        self.beacon_indices = None

        self.all_input_ids = None
        self.all_attention_mask = None
        self.all_labels = None

        # NOTE: will be reset in prepare()
        self.beacon_skip_first = None
        self.beacon_skip_last = None

        # the raw activations of recent tokens
        self.raw_activations = [(None, None) for _ in range(self.config.num_hidden_layers)]
        # the attention sink activations
        self.sink_activations = [(None, None) for _ in range(self.config.num_hidden_layers)]
        # the beacon activations
        self.beacon_activations = [(None, None) for _ in range(self.config.num_hidden_layers)]

    @property
    def all_sequence_length(self):
        if self.all_input_ids is None:
            return 0
        else:
            return self.all_input_ids.shape[1]

    @property
    def batch_size(self):
        if self.all_input_ids is None:
            return 0
        else:
            return self.all_input_ids.shape[0]

    @property
    def finish(self):
        is_finish = self.end_idx == self.all_sequence_length
        return is_finish

    @property
    def dtype(self):
        return self.config.torch_dtype

    @property
    def min_value(self):
        return torch.finfo(self.dtype).min

    @property
    def max_position_embeddings(self):
        max_position_embeddings = self.config.max_position_embeddings
        if getattr(self.config, "rope_scaling", None) is not None:
            scaling_factor = self.config.rope_scaling["factor"]
            max_position_embeddings = max_position_embeddings * scaling_factor
        return max_position_embeddings
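
    # Illustrative example (config values assumed, not taken from this repo): with
    # config.max_position_embeddings = 4096 and rope_scaling = {"factor": 2.0}, the
    # property above returns 4096 * 2.0 = 8192, i.e. the scaled context limit that
    # get_max_length() inside set_compression_ratio() relies on.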

    @property
    def beacon_window(self):
        if (
            self.beacon_skip_last is not None
            and self.start_idx < self.beacon_skip_last
            and self.start_idx + self.config.beacon_window > self.beacon_skip_last
        ):
            return self.beacon_skip_last - self.start_idx
        else:
            return self.config.beacon_window

    @property
    def beacon_stride(self):
        if (
            self.beacon_skip_last is not None
            and self.start_idx < self.beacon_skip_last
            and self.start_idx + self.config.beacon_window > self.beacon_skip_last
        ):
            return self.beacon_skip_last - self.start_idx
        else:
            return self.config.beacon_stride

    def get_memory_size(self):
        """
        Sink memory size, beacon memory size and raw memory size.
        """
        sink_memory_size = 0
        beacon_memory_size = 0
        raw_memory_size = 0
        if self.sink_activations[0][0] is not None:
            sink_memory_size += self.sink_activations[0][0].shape[self.k_seq_dim]
        if self.beacon_activations[0][0] is not None:
            beacon_memory_size += self.beacon_activations[0][0].shape[self.k_seq_dim]
        if self.raw_activations[0][0] is not None:
            raw_memory_size += self.raw_activations[0][0].shape[self.k_seq_dim]
        return sink_memory_size, beacon_memory_size, raw_memory_size

    def prepare(self, input_ids, attention_mask, labels, skip_first=None, skip_last=None):
        """
        Prepare inputs for the model. These inputs belong to the same sequence.
        """
        # assert input_ids.shape[0] == 1, "Make sure the batch size is 1!"
        # assert attention_mask is None or (attention_mask == 1).all(), "Make sure there is no padding!"
        self._device = input_ids.device

        # accumulate input_ids
        if self.all_input_ids is None:
            self.all_input_ids = input_ids.cpu()
        else:
            self.all_input_ids = torch.cat([self.all_input_ids, input_ids.cpu()], dim=1)

        # accumulate attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, device=torch.device("cpu"))
        if self.all_attention_mask is None:
            self.all_attention_mask = attention_mask.cpu()
        else:
            self.all_attention_mask = torch.cat([self.all_attention_mask, attention_mask.cpu()], dim=1)

        # accumulate labels if they exist
        if labels is not None:
            # rotate labels in advance so that the loss of the last token is not ignored in every window
            labels = torch.cat([labels[:, 1:].cpu(), torch.tensor([-100]).expand(labels.shape[0], 1)], dim=1)
            if self.all_labels is None:
                self.all_labels = labels.cpu()
            else:
                self.all_labels = torch.cat([self.all_labels, labels], dim=1)
            assert self.all_input_ids.shape[1] == self.all_labels.shape[1], f"Found inconsistent all_input_ids {self.all_input_ids.shape} and all_labels {self.all_labels.shape}!"

        # how many tokens to skip at the beginning of the sequence? (They will be packed into a single chunk and processed by the model, after which their activations will be cached in sink_activations.)
        if skip_first is not None:
            assert self.config.beacon_parallel_window == 1, "Make sure the parallel window is set to 1 when using beacon_skip!"
            assert self.config.beacon_window == self.config.beacon_stride, "Make sure the beacon_window equals the beacon_stride when using beacon_skip."
            assert self.config.beacon_sink_size == 0, "Make sure the beacon_sink_size is set to 0 when using beacon_skip!"

        # stop compression after how many tokens
        if skip_last is not None:
            skip_first = skip_first if skip_first is not None else 0
            # assert (skip_last - skip_first) % self.config.beacon_window == 0, f"skip_last ({skip_last}) - skip_first ({skip_first}) = {skip_last - skip_first} is not divisible by window size {self.config.beacon_window}"
            assert self.config.beacon_sink_size == 0, "Make sure the beacon_sink_size is zero when using skip_last!"

        self.beacon_skip_first = skip_first
        self.beacon_skip_last = skip_last

    def set_compression_ratio(self, start_idx, end_idx):
        """Choose a condensing ratio from self.config.beacon_ratio"""

        def filter_ratio(ratios, stride):
            valid_ratios = []
            for ratio in ratios:
                # the stride must be no smaller than the condensing ratio because there must be at least one beacon
                if stride < ratio:
                    continue
                # the stride must be evenly divisible by the condensing ratio
                if ratio > 0 and (stride % ratio) != 0:
                    continue
                # when training, ratio=0 is valid only if no previous window used it and a later window will contain beacons
                if ratio == 0 and self.training:
                    previous_has_zero = -1 in self.all_beacon_sizes
                    following_has_nonzero = (start_idx + stride + self.beacon_window) <= self.all_sequence_length
                    if previous_has_zero or (not following_has_nonzero):
                        continue
                valid_ratios.append(ratio)
            assert len(valid_ratios), f"Cannot find valid condensing ratio (among {ratios}) for stride {stride}!"
            return valid_ratios

        def get_max_length(ratios):
            max_lengths = []
            for compression_ratio in ratios:
                if compression_ratio > 0:
                    # NOTE: here we must use the scaled position embeddings
                    max_lengths.append((self.max_position_embeddings - self.beacon_window) * compression_ratio + self.beacon_window)
                else:
                    max_lengths.append(self.max_position_embeddings)
            return max_lengths

        if len(self.config.beacon_ratio) == 1:
            return self.config.beacon_ratio[0]

        ratio_mix = self.config.beacon_ratio_mix

        beacon_ratio = filter_ratio(self.config.beacon_ratio, self.beacon_stride)

        if ratio_mix == "instance-random":
            if self.compression_ratio is None:
                beacon_ratio = self.rng.choice(beacon_ratio).tolist()
                self.compression_ratio = beacon_ratio
            else:
                beacon_ratio = self.compression_ratio

        elif ratio_mix == "step-random":
            beacon_ratio = self.rng.choice(beacon_ratio).tolist()

        elif ratio_mix == "sequence":
            if self.compression_ratio is None:
                self.compression_ratio = cycle(beacon_ratio)
            beacon_ratio = next(self.compression_ratio)

        elif "adapt" in ratio_mix:
            if self.compression_ratio is None:
                future_length = int(ratio_mix.split("-")[1])
                sequence_length = self.all_input_ids.shape[1] + future_length
                max_lengths = get_max_length(beacon_ratio)
                # ascendingly sort the max lengths
                valid_max_lengths_and_indices = [x for x in enumerate(max_lengths) if x[1] >= sequence_length]
                if len(valid_max_lengths_and_indices):
                    minimum_length_index = min(valid_max_lengths_and_indices, key=lambda x: x[1])[0]
                    # use the minimal possible length for this sequence (the smallest fold ratio)
                    beacon_ratio = beacon_ratio[minimum_length_index]
                else:
                    beacon_ratio = max(beacon_ratio)
                    # logger.warning(f"Failed to find valid fold window and size for sequence length {sequence_length}, as the maximum theoretical length is {max(max_lengths)}. Fall back to use the maximum one: {beacon_ratio}.")
                self.compression_ratio = beacon_ratio
            else:
                beacon_ratio = self.compression_ratio

        return beacon_ratio
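
    # Illustrative walk-through of the "adapt-*" branch above (all numbers assumed):
    # with beacon_ratio_mix="adapt-1024", beacon_window=1024, max_position_embeddings=4096,
    # and candidate ratios [2, 4, 8], get_max_length() yields
    #   (4096 - 1024) * 2 + 1024 = 7168
    #   (4096 - 1024) * 4 + 1024 = 13312
    #   (4096 - 1024) * 8 + 1024 = 25600
    # so a sequence of 10000 tokens (plus the 1024 assumed future tokens, 11024 total)
    # selects ratio 4, the smallest ratio whose theoretical maximum length covers it.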

    def step(self):
        # parallel does not support stride < window
        # parallel does not support non-compression
        # the input_ids must be long enough for parallel
        if (
            self.config.beacon_parallel_window > 1
            and self.config.beacon_stride == self.config.beacon_window
            and 0 not in self.config.beacon_ratio
            and self.all_input_ids[:, self.end_idx:].shape[1] >= self.config.beacon_parallel_window * self.config.beacon_window
        ):
            input_ids_list = []
            attention_mask_list = []
            position_ids_list = []
            labels_list = []

            beacon_size_list = []
            beacon_indices_list = []

            for i in range(self.config.beacon_parallel_window):
                if i == 0:
                    _input_ids, _attention_mask, _position_ids, _past_key_values, _labels = self._step()
                else:
                    _input_ids, _attention_mask, _position_ids, _past_key_values, _labels = self._step(ignore_memory=True)

                input_ids_list.append(_input_ids)
                attention_mask_list.append(_attention_mask)
                position_ids_list.append(_position_ids)
                labels_list.append(_labels)
                beacon_size_list.append(_past_key_values[0][2])
                beacon_indices_list.append(_past_key_values[0][3])

                if i == 0:
                    past_key_values = _past_key_values
                    if past_key_values[0][0] is None:
                        mem_size = 0
                    else:
                        mem_size = past_key_values[0][0].shape[self.k_seq_dim]
                else:
                    # no memory
                    assert _past_key_values[0][0] is None

            batch_size = self.all_input_ids.shape[0]
            # NOTE: we do not need to replicate beacon tokens for the last window
            seq_len = sum(x.shape[1] for x in input_ids_list) + sum(beacon_size_list) - beacon_size_list[-1]

            input_ids = _input_ids.new_zeros((batch_size, seq_len)) + self.beacon_token
            # everything is masked out (min_value) by default
            attention_mask = _attention_mask.new_zeros((batch_size, 1, seq_len, mem_size + seq_len)) + self.min_value
            position_ids = torch.arange(mem_size + seq_len, device=self._device).expand(batch_size, mem_size + seq_len)
            # 2 indicates the beacon token is used for replication
            beacon_indices = beacon_indices_list[0].new_zeros(seq_len) + 2
            if _labels is not None:
                # -100 because no loss on beacon tokens
                labels = _labels.new_zeros((batch_size, seq_len)) - 100
            else:
                labels = None

            start_idx = 0
            position_offset = mem_size
            for i in range(self.config.beacon_parallel_window):
                beacon_size = beacon_size_list[i]

                # populate input_ids
                _input_ids = input_ids_list[i]
                cur_seq_len = _input_ids.shape[1]
                input_ids[:, start_idx: start_idx + cur_seq_len] = _input_ids

                # populate attention_mask and position_ids
                _attention_mask = attention_mask_list[i]
                _position_ids = position_ids_list[i]
                # the attention mask in the first window contains the mask for memory, which is redundant here
                if i == 0:
                    _attention_mask = _attention_mask[:, :, :, mem_size:]
                    _position_ids = _position_ids[:, mem_size:] - mem_size

                attention_mask[:, :, start_idx: start_idx + cur_seq_len, mem_size + start_idx: mem_size + start_idx + cur_seq_len] = _attention_mask
                position_ids[:, mem_size + start_idx: mem_size + start_idx + cur_seq_len] = _position_ids + position_offset

                # populate beacon_indices
                _beacon_indices = beacon_indices_list[i]
                beacon_indices[start_idx: start_idx + cur_seq_len] = _beacon_indices

                # populate labels
                if labels is not None:
                    _labels = labels_list[i]
                    labels[:, start_idx: start_idx + cur_seq_len] = _labels

                # NOTE: when there are sink activations, we need to bias the position_ids for the first window
                if i == 0 and self.config.beacon_sink_size > 0 and self.sink_activations[0][0] is None:
                    position_offset += 1

                # modify the attention and position for replicated beacon tokens
                if i != self.config.beacon_parallel_window - 1:
                    replicate_beacon_row_start = start_idx + cur_seq_len
                    replicate_beacon_col_start = mem_size + start_idx + cur_seq_len
                    # NOTE: any attention mask is okay for replicated beacon tokens, but for convenience we use the causal mask
                    attention_mask[:, :, replicate_beacon_row_start: replicate_beacon_row_start + beacon_size, replicate_beacon_col_start: replicate_beacon_col_start + beacon_size] = _attention_mask.new_full((beacon_size, beacon_size), self.min_value).triu(1)
                    # NOTE: all future tokens can attend to the replicated beacon tokens
                    attention_mask[:, :, replicate_beacon_row_start + beacon_size:, replicate_beacon_col_start: replicate_beacon_col_start + beacon_size] = 0
                    # NOTE: the positions of replicated beacon tokens start from position_offset
                    position_ids[:, mem_size + start_idx + cur_seq_len: mem_size + start_idx + cur_seq_len + beacon_size] = torch.arange(position_offset, position_offset + beacon_size, device=_input_ids.device)[None]

                start_idx += cur_seq_len + beacon_size
                position_offset += beacon_size

            # the memory is visible to all subsequent tokens
            attention_mask[:, :, :, :max(mem_size, self.config.beacon_sink_size)] = 0

            # NOTE: modify beacon_indices
            for i, (key, value, _, _) in enumerate(past_key_values):
                past_key_values[i] = (key, value, sum(beacon_size_list), beacon_indices)

            # NOTE: update self.beacon_indices so that the next-token logits can be properly sliced out in self.output()
            self.beacon_indices = beacon_indices

            return input_ids, attention_mask, position_ids, past_key_values, labels

        else:
            return self._step()
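
    # Illustrative layout of the packed parallel-window batch built above (assuming
    # beacon_parallel_window = 3; each "window i" chunk already contains its own
    # appended beacons as produced by _step(), and B is that window's beacon size):
    #
    #   [ window 0 chunk | replicated beacons (B0) | window 1 chunk | replicated beacons (B1) | window 2 chunk ]
    #
    # Replicated beacons carry beacon_indices == 2, real beacons keep beacon_indices == 1,
    # and the 4D mask keeps only the memory plus each window's own causal prefix visible.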

    def _step(self, ignore_memory=False):
        """
        Yield inputs for the current sliding window, including the input_ids, attention_mask, position_ids, and past_key_values.
        """
        #============================================#
        # Check whether the inputs fulfill a window.
        #============================================#
        # the starting position of the current window w.r.t. the start of the current input sequence
        start_idx = self.start_idx
        # the end position of the current window w.r.t. the start of the current input sequence
        end_idx = start_idx + self.beacon_window

        # indicates if the current window is completely filled by raw activations and new tokens
        # we only append beacon tokens for full windows
        if end_idx > self.all_sequence_length:
            # the input is shorter than the initial window size
            end_idx = self.all_sequence_length
            is_full_window = False
        else:
            is_full_window = True

        # NOTE: in training, the entire sequence is input to the model at once
        # in the last window, we do not need to append beacons because they will not be used at all
        if self.training and end_idx == self.all_sequence_length:
            next_start_idx = start_idx
            is_full_window = False
            raw_size_to_cache = -1
            beacon_size = 0
            compression_ratio = -1

        # NOTE: we do not compress the beacon_skip_first tokens at the beginning of the sequence
        elif self.step_idx == 0 and self.beacon_skip_first is not None:
            end_idx = start_idx + self.beacon_skip_first
            assert end_idx <= self.all_sequence_length
            next_start_idx = end_idx
            is_full_window = True
            raw_size_to_cache = -1
            beacon_size = 0
            compression_ratio = -1

        # NOTE: we do not compress tokens after the first beacon_skip_last tokens
        elif self.beacon_skip_last is not None and start_idx >= self.beacon_skip_last:
            end_idx = min(start_idx + self.beacon_window, self.all_sequence_length)
            next_start_idx = end_idx
            is_full_window = False
            raw_size_to_cache = -1
            beacon_size = 0
            compression_ratio = -1

        else:
            #============================================#
            # Set compression ratio
            #============================================#
            if self.config.beacon_pos == "append":
                if is_full_window:
                    # determine the compression ratio for the current window
                    beacon_stride = self.beacon_stride
                    compression_ratio = self.set_compression_ratio(start_idx=start_idx, end_idx=end_idx)

                    if compression_ratio > 0:
                        # the stride must be evenly divisible by compression_ratio
                        beacon_size = beacon_stride // compression_ratio
                    else:
                        # the raw activations are used as beacon activations
                        beacon_size = -1

                    # move the window start forward by one stride
                    next_start_idx = start_idx + beacon_stride
                    # how many raw activations to save
                    raw_size_to_cache = end_idx - next_start_idx
                else:
                    # no stride because the current window is not full yet
                    next_start_idx = start_idx
                    # cache all raw activations
                    raw_size_to_cache = -1
                    beacon_size = 0
                    compression_ratio = 0

            elif self.config.beacon_pos == "interleave":
                # the number of raw tokens in the input_ids
                input_size = end_idx - self.end_idx
                # set the compression ratio once the previous window has finished; otherwise reuse interleave_compression_ratio because the input belongs to an unfinished window
                if self.is_full_window:
                    compression_ratio = self.set_compression_ratio(start_idx=start_idx, end_idx=end_idx)
                    self.interleave_compression_ratio = compression_ratio
                else:
                    compression_ratio = self.interleave_compression_ratio

                # the beacon size is non-zero even if the window is not full
                if compression_ratio > 0:
                    # this number of beacon tokens will be inserted among the raw tokens
                    beacon_size = (input_size + self.interleave_remainder) // compression_ratio
                else:
                    # the raw activations are used as beacon activations
                    beacon_size = -1

                if is_full_window:
                    # move forward one window
                    next_start_idx = start_idx + self.beacon_stride
                    # do not save raw activations
                    raw_size_to_cache = 0
                else:
                    # no stride because the current window is not full yet
                    next_start_idx = start_idx
                    # cache all recent raw activations to be used in the next window
                    raw_size_to_cache = -1

        #============================================#
        # Slice out input_ids (raw tokens in the current window)
        #============================================#
        input_ids = self.all_input_ids[:, self.end_idx: end_idx].to(self._device)
        attention_mask = self.all_attention_mask[:, self.end_idx: end_idx].to(self._device)
        if self.all_labels is not None:
            labels = self.all_labels[:, self.end_idx: end_idx].to(self._device)
        else:
            labels = None
        batch_size = input_ids.shape[0]

        #============================================#
        # Insert beacon tokens if necessary.
        #============================================#
        # t1 = time.time()

        if self.config.beacon_pos == "append":
            # append beacons if necessary
            if is_full_window and beacon_size > 0:
                input_ids = torch.cat([input_ids, input_ids.new_full((batch_size, beacon_size), self.beacon_token)], dim=1)
                # NOTE: prepend 1 to attention_mask because we have past_key_values
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones(batch_size, beacon_size)], dim=1)
                if labels is not None:
                    labels = torch.cat([labels, labels.new_zeros(batch_size, beacon_size) - 100], dim=1)

        elif self.config.beacon_pos == "interleave":
            input_len = input_ids.shape[1]
            if beacon_size > 0:
                # insert beacon tokens in between raw tokens
                input_ids_with_beacons = input_ids.new_full((input_ids.shape[0], input_len + beacon_size), self.beacon_token)
                raw_token_indices = torch.arange(input_ids_with_beacons.shape[1], device=input_ids.device)
                interleave_start_idx = compression_ratio - self.interleave_remainder
                raw_token_indices = raw_token_indices[raw_token_indices % (compression_ratio + 1) != interleave_start_idx].unsqueeze(0).expand_as(input_ids)
                input_ids_with_beacons = input_ids_with_beacons.scatter(dim=1, index=raw_token_indices, src=input_ids)
                input_ids = input_ids_with_beacons
                # attention mask
                attention_mask_with_beacons = attention_mask.new_full((attention_mask.shape[0], attention_mask.shape[1] + beacon_size), 1)
                attention_mask_with_beacons = attention_mask_with_beacons.scatter(dim=1, index=raw_token_indices, src=attention_mask)
                attention_mask = attention_mask_with_beacons
                # labels
                if labels is not None:
                    labels_with_beacons = labels.new_full((labels.shape[0], labels.shape[1] + beacon_size), -100)
                    labels_with_beacons = labels_with_beacons.scatter(dim=1, index=raw_token_indices, src=labels)
                    labels = labels_with_beacons

            if compression_ratio > 0:
                # update the remainder
                self.interleave_remainder = (input_len + self.interleave_remainder) % compression_ratio

        # NOTE: skip computing loss in the very first window because the beacon tokens will be used in the next window
        if self.training and self.step_idx == 0 and not (self.config.beacon_pos == 'interleave' and self.config.beacon_attn == 'full-coverage'):
            labels[:] = -100

        # t2 = time.time()

        #============================================#
        # Prepare beacon_indices: a mask where 1 indicates the beacon tokens.
        # The mask is applied to the inputs of the entire window, including the cached activations and the input_ids.
        #============================================#
        beacon_indices = (input_ids[0] == self.beacon_token).long()
        if self.is_full_window:
            self.beacon_indices = torch.tensor([], dtype=torch.long, device=input_ids.device)
        # the beacon_indices always tracks the beacon tokens in both the cached activations and the input_ids
        beacon_indices = torch.cat([self.beacon_indices, beacon_indices])
        # record the beacon_indices for the next window
        self.beacon_indices = beacon_indices
        if is_full_window and beacon_size == -1:
            # NOTE: the first beacon_stride raw tokens serve as beacon tokens
            # we use -1 to indicate these raw tokens, so that the attention mask and position ids will not be modified
            beacon_indices[:self.beacon_stride] = -1

        # t3 = time.time()

        #============================================#
        # Prepare past_key_values.
        # beacon_size: how many beacon tokens are in the input_ids
        # beacon_indices: the mask for the entire window where 1 indicates the beacon tokens (for 'append', it corresponds to the input_ids, while for 'interleave', it corresponds to the entire window including both the input_ids and the cached activations)
        #============================================#
        past_key_values = []
        for layer_idx in range(self.config.num_hidden_layers):
            if ignore_memory:
                key, value = None, None
            else:
                sink_key, sink_value = self.sink_activations[layer_idx]
                beacon_key, beacon_value = self.beacon_activations[layer_idx]
                raw_key, raw_value = self.raw_activations[layer_idx]

                key = cat_tensor([
                    sink_key, beacon_key, raw_key,
                ], dim=self.k_seq_dim)
                value = cat_tensor([
                    sink_value, beacon_value, raw_value,
                ], dim=self.v_seq_dim)

            layer_past_key_values = (key, value, beacon_size, beacon_indices)
            past_key_values.append(layer_past_key_values)

        # t4 = time.time()

        #============================================#
        # Prepare attention_mask and position_ids.
        #============================================#
        first_key = past_key_values[0][0]
        mem_size = first_key.shape[self.k_seq_dim] if first_key is not None else 0
        if mem_size > 0:
            attention_mask = torch.cat([attention_mask.new_ones(batch_size, mem_size), attention_mask], dim=1)

        input_length = input_ids.shape[1]
        position_ids = torch.arange(attention_mask.shape[-1], dtype=torch.long, device=self._device).repeat(batch_size, 1)

        if self.config._attn_implementation == "flash_attention_2":
            assert self.config.beacon_attn == "full-coverage", f"Make sure to set beacon_attn='full-coverage' when using flash attention! Found {self.config.beacon_attn}."
            if 0 in attention_mask:
                pass
            else:
                attention_mask = None
        elif self.config._attn_implementation == "sdpa" and self.config.beacon_pos == "append" and beacon_size <= 0 and (input_length == 1 or mem_size == 0):
            attention_mask = None
        else:
            attention_mask, position_ids = self._make_4d_attention_mask_and_position_ids(
                attention_mask,
                position_ids,
                mem_size,
                beacon_size,
                compression_ratio,
            )

        # t5 = time.time()
        # print(f"prepare inputs {t2-t1}, prepare indices {t3-t2}, prepare memory {t4-t3}, prepare attention mask {t5-t4}")

        #============================================#
        # Update necessary attributes.
        #============================================#
        # keep track of whether the current input fulfills a full window
        self.is_full_window = is_full_window
        # keep track of the raw_size_to_cache
        self.raw_size_to_cache = raw_size_to_cache
        # used in self.output()
        self.all_beacon_sizes.append(beacon_size)
        # update start_idx and end_idx
        # NOTE: the update of start_idx will influence self.beacon_window and self.beacon_stride in case self.beacon_skip_last is not None
        # therefore, we must make sure all calls to self.beacon_window and self.beacon_stride happen before the update of start_idx
        self.start_idx = next_start_idx
        self.end_idx = end_idx
        self.step_idx += 1

        # print(f"start_idx: {start_idx}")
        # print(f"next_start_idx: {next_start_idx}")
        # print(f"beacon_size: {beacon_size}")
        # print(f"raw_size_to_cache: {raw_size_to_cache}")
        # print(f"interleave_remainder: {self.interleave_remainder}")
        # print(f"input_ids: {input_ids}")
        # print(f"beacon_indices: {beacon_indices}")
        # print(f"position_ids: {position_ids}")
        # print(f"attention_mask:\n{attention_mask == 0}")
        # x = input()
        # if x == "s":
        #     return

        return input_ids, attention_mask, position_ids, past_key_values, labels
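
    # A minimal sketch of how _step() is typically driven (assumed caller, not part of
    # this module): the surrounding model repeatedly calls memory.step(), runs its
    # decoder layers on the returned inputs, and hands the new key/value states back
    # through update_memory() until memory.finish becomes True, e.g.:
    #
    #   memory.prepare(input_ids, attention_mask, labels)
    #   while not memory.finish:
    #       inputs = memory.step()
    #       outputs = decoder_forward(*inputs)   # hypothetical model call
    #       memory.update_memory(outputs.past_key_values)
    #       if labels is not None:
    #           memory.update_loss(outputs.batch_loss, outputs.valid_token_num)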

    def update_memory(self, past_key_values):
        """
        Accumulate beacon activations and raw activations.
        """
        for layer_idx, (key, value, beacon_size, beacon_indices) in enumerate(past_key_values):
            # NOTE: the past_key_values are incrementally returned (only the new keys and values are returned)
            previous_raw_key, previous_raw_value = self.raw_activations[layer_idx]

            if self.beacon_skip_first is not None and self.sink_activations[layer_idx][0] is None:
                assert key.shape[self.k_seq_dim] == self.beacon_skip_first
                assert value.shape[self.k_seq_dim] == self.beacon_skip_first
                self.sink_activations[layer_idx] = [
                    key,
                    value,
                ]
                # NOTE: no need to update raw activations and beacon activations as all activations are kept as sink activations
                continue

            if self.beacon_activations[layer_idx][0] is None and self.config.beacon_sink_size > 0:
                # save the sink activations
                # NOTE: we do not slice the key/value activations, which may cause duplication when beacon_ratio=-1 for the first window, but it's okay
                self.sink_activations[layer_idx] = [
                    slice_tensor(key, end=self.config.beacon_sink_size, dim=self.k_seq_dim),
                    slice_tensor(value, end=self.config.beacon_sink_size, dim=self.v_seq_dim),
                ]

            if not self.is_full_window:
                # this means the current input does not fulfill a window
                # thus, the key and value are all raw activations, and we accumulate them until the window is fulfilled
                assert self.raw_size_to_cache == -1
                raw_key = cat_tensor([
                    previous_raw_key,
                    key
                ], dim=self.k_seq_dim)
                raw_value = cat_tensor([
                    previous_raw_value,
                    value
                ], dim=self.v_seq_dim)
                self.raw_activations[layer_idx] = (raw_key, raw_value)
            else:
                # NOTE: use the correct previous_beacon_key and value!
                previous_beacon_key, previous_beacon_value = self.beacon_activations[layer_idx]

                beacon_key, beacon_value, raw_key, raw_value = self._extract_beacon_and_raw_memory(
                    key,
                    value,
                    previous_beacon_key,
                    previous_beacon_value,
                    previous_raw_key,
                    previous_raw_value,
                    beacon_indices,
                )

                self.beacon_activations[layer_idx] = (beacon_key, beacon_value)
                self.raw_activations[layer_idx] = (raw_key, raw_value)

    def update_loss(self, batch_loss, valid_token_num):
        """
        Accumulate loss for later perplexity computation and backward pass.
        """
        if self.batch_loss is None:
            # NOTE: multiply by valid_token_num because batch_loss was divided by it in advance
            self.batch_loss = batch_loss * valid_token_num
            self.valid_token_num = valid_token_num
        else:
            # NOTE: avoid in-place operations, otherwise there will be gradient errors in training
            self.batch_loss = self.batch_loss + batch_loss * valid_token_num
            self.valid_token_num = self.valid_token_num + valid_token_num

    def output(self, model_outputs):
        """
        Override loss with accumulated loss. Update the next-token logits.
        """
        # override loss
        if self.batch_loss is not None:
            # here the batch_loss is the summation of all token losses in each element
            loss = self.batch_loss.sum() / self.valid_token_num.sum()

            # NOTE: prevent nan
            batch_loss = self.batch_loss / self.valid_token_num
            if (self.valid_token_num == 0).any():
                batch_loss = batch_loss.masked_fill(self.valid_token_num == 0, 0.)

            # NOTE: we must use dict assignment to override values, otherwise the trainer cannot find the loss
            model_outputs["loss"] = loss
            model_outputs["batch_loss"] = batch_loss

        # override logits (used in generation)
        beacon_size = self.all_beacon_sizes[-1]
        # remove logits corresponding to beacon tokens
        if beacon_size > 0:
            logits = model_outputs["logits"]
            beacon_indices = self.beacon_indices[-logits.shape[1]:]
            model_outputs["logits"] = logits[:, beacon_indices == 0]

        return model_outputs

    def _make_4d_attention_mask_and_position_ids(
        self,
        attention_mask,
        position_ids,
        mem_size,
        beacon_size,
        compression_ratio,
    ):
        """
        Convert attention_mask into a causal 4D attention_mask (batch_size, head_num, query_len, key_len).
        """
        tgt_size = attention_mask.size(-1) - mem_size
        dtype = self.dtype
        min_value = self.min_value
        device = self._device
        batch_size, src_size = attention_mask.size()

        # square for memory, and lower triangular for input_ids
        causal_mask = torch.full((tgt_size, tgt_size), min_value, device=device, dtype=dtype)
        mask_cond = torch.arange(causal_mask.size(-1), device=device)
        causal_mask.masked_fill_(mask_cond < (mask_cond + 1).view(causal_mask.size(-1), -1), 0)
        causal_mask = torch.cat([torch.zeros(tgt_size, mem_size, dtype=dtype, device=device), causal_mask], dim=-1)
        causal_mask = causal_mask[None, None, ...].expand(batch_size, 1, tgt_size, src_size)
        # 1 for non-padding tokens
        expand_mask = attention_mask[:, None, None, :].expand(batch_size, 1, tgt_size, src_size)
        invert_mask = 1.0 - expand_mask
        # invert_mask = ~ expand_mask
        # print("min_value:", min_value)  # inspect the current value
        # print("dtype:", invert_mask.dtype)  # inspect the tensor dtype
        invert_mask.masked_fill_(invert_mask.bool(), min_value)

        attention_mask = causal_mask.masked_fill(invert_mask.bool(), min_value)

        if self.config.beacon_attn == "step-expansion":
            # each beacon can attend to one more sub-interval than its predecessor
            if self.config.beacon_pos == "append" and beacon_size > 0:
                window_size = self.beacon_window
                window_size_with_beacon = window_size + beacon_size
                beacon_start_idx = -beacon_size
                # batch_size, head_num, window_size
                reference_attention_mask = attention_mask[..., -beacon_size - 1, -window_size_with_beacon: -beacon_size]

                # compression_ratio, 2 * compression_ratio, ..., beacon_size * compression_ratio
                beacon_arange = torch.arange(1, beacon_size + 1, device=device) * compression_ratio
                # 0, 1, 2, ..., window_size - 1
                ordinal_arange = torch.arange(window_size, device=device)
                # beacon_size, window_size
                valid_pos = ordinal_arange.expand(beacon_size, window_size) < beacon_arange.unsqueeze(-1)
                # beacon_size, window_size
                ordinal_attention_mask = torch.where(valid_pos, 0, min_value)
                # NOTE: add the reference attention_mask so that padding tokens are considered
                ordinal_attention_mask = ordinal_attention_mask[None, None, ...] + reference_attention_mask.unsqueeze(-2)

                if self.config.beacon_attend_prev:
                    beacon_attention_mask = attention_mask.new_full((beacon_size, beacon_size), min_value).triu(1)
                    # the beacon token is next to the last ordinal token it attends to
                    ordinal_position_ids = position_ids[:, -window_size_with_beacon: -beacon_size]
                    beacon_position_ids = ordinal_position_ids[:, compression_ratio - 1::compression_ratio] + torch.arange(1, beacon_size + 1, device=device)[None]
                    position_ids[:, beacon_start_idx:] = beacon_position_ids
                else:
                    beacon_attention_mask = attention_mask.new_full((beacon_size, beacon_size), min_value).fill_diagonal_(0)
                    # the beacon token is next to the last ordinal token it attends to
                    ordinal_position_ids = position_ids[:, -window_size_with_beacon: -beacon_size]
                    beacon_position_ids = ordinal_position_ids[:, compression_ratio - 1::compression_ratio] + 1
                    position_ids[:, beacon_start_idx:] = beacon_position_ids

                attention_mask[..., beacon_start_idx:, -window_size_with_beacon: -beacon_size] = ordinal_attention_mask
                attention_mask[..., beacon_start_idx:, beacon_start_idx:] = beacon_attention_mask

            # NOTE: the attention mask should also be modified when the beacon tokens sit within the window rather than in the input_ids
            elif self.config.beacon_pos == "interleave" and (self.beacon_indices == 1).any():
                assert self.config.beacon_attend_prev == False, "Make sure beacon_attend_prev is False when using the 'interleave' beacon_pos!"

                beacon_indices = self.beacon_indices

                cur_position_ids = position_ids[:, -len(beacon_indices):]
                base_position = cur_position_ids[:, 0] - 1
                # NOTE: alternate the positions so that the positions of raw tokens are consistent
                position_template = cur_position_ids.new_ones(cur_position_ids.shape)
                position_template[:, compression_ratio + 1::compression_ratio + 1] = 0
                cur_position_ids = base_position + position_template.cumsum(-1)
                position_ids[:, -len(beacon_indices):] = cur_position_ids

                cur_input_length = len(beacon_indices)
                cur_attention_mask = attention_mask[..., -cur_input_length:, -cur_input_length:]
                # mask all beacon columns
                cur_attention_mask[..., beacon_indices] = min_value
                # beacon tokens can attend to themselves
                input_ids_attention_mask = cur_attention_mask[..., -tgt_size:, -tgt_size:]
                input_ids_attention_mask[..., range(tgt_size), range(tgt_size)] = 0

        elif self.config.beacon_attn == "segmentation":
            # each beacon can attend to its corresponding sub-interval
            if self.config.beacon_pos == "append" and beacon_size > 0:
                window_size = self.beacon_window
                window_size_with_beacon = window_size + beacon_size
                beacon_start_idx = -beacon_size
                # batch_size, head_num, window_size
                reference_attention_mask = attention_mask[..., -beacon_size - 1, -window_size_with_beacon: -beacon_size]

                # beacon_size, compression_ratio
                indices = torch.arange(compression_ratio * beacon_size, device=device).view(beacon_size, -1)
                # beacon_size, window_size
                ordinal_attention_mask = attention_mask.new_full((beacon_size, window_size), min_value)
                ordinal_attention_mask.scatter_(dim=-1, index=indices, value=0)
                # NOTE: add the reference attention_mask so that padding tokens are considered
                ordinal_attention_mask = ordinal_attention_mask[None, None, ...] + reference_attention_mask.unsqueeze(-2)

                if self.config.beacon_attend_prev:
                    beacon_attention_mask = attention_mask.new_full((beacon_size, beacon_size), min_value).triu(1)
                    # the beacon token is next to the last ordinal token it attends to
                    beacon_position_ids = position_ids.new_full((beacon_size,), fill_value=compression_ratio + mem_size)
                    beacon_position_ids = beacon_position_ids + torch.arange(beacon_size, device=device)
                    position_ids[:, beacon_start_idx:] = beacon_position_ids
                else:
                    beacon_attention_mask = attention_mask.new_full((beacon_size, beacon_size), min_value).fill_diagonal_(0)
                    # the beacon token is next to the last ordinal token it attends to
                    beacon_position_ids = position_ids.new_full((beacon_size,), fill_value=compression_ratio + mem_size)
                    position_ids[:, beacon_start_idx:] = beacon_position_ids

                attention_mask[..., beacon_start_idx:, -window_size_with_beacon: -beacon_size] = ordinal_attention_mask
                attention_mask[..., beacon_start_idx:, beacon_start_idx:] = beacon_attention_mask
                # beacons of different ratios are blind to each other
                attention_mask[..., beacon_start_idx:, -beacon_size: beacon_start_idx] = min_value

            elif self.config.beacon_pos == "interleave":
                raise NotImplementedError

        elif self.config.beacon_attn == "full-coverage":
            pass

        return attention_mask, position_ids

    def _extract_beacon_and_raw_memory(
        self,
        key,
        value,
        previous_beacon_key,
        previous_beacon_value,
        previous_raw_key,
        previous_raw_value,
        beacon_indices,
    ):
        """Extract beacon and raw memory from the returned key and value when the window is full."""
        key = cat_tensor([
            previous_raw_key,
            key
        ], dim=self.k_seq_dim)
        value = cat_tensor([
            previous_raw_value,
            value
        ], dim=self.v_seq_dim)

        # NOTE: select the beacon activations (including raw tokens marked -1 that serve as beacons)
        beacon_key = slice_tensor(key, index=torch.logical_or(beacon_indices == 1, beacon_indices == -1), dim=self.k_seq_dim)
        beacon_value = slice_tensor(value, index=torch.logical_or(beacon_indices == 1, beacon_indices == -1), dim=self.v_seq_dim)

        if self.config.beacon_accum:
            beacon_key = cat_tensor([previous_beacon_key, beacon_key], dim=self.k_seq_dim)
            beacon_value = cat_tensor([previous_beacon_value, beacon_value], dim=self.v_seq_dim)

        if self.raw_size_to_cache > 0:
            raw_key = slice_tensor(key, index=beacon_indices == 0, dim=self.k_seq_dim)
            raw_key = slice_tensor(raw_key, start=-self.raw_size_to_cache, dim=self.k_seq_dim)

            raw_value = slice_tensor(value, index=beacon_indices == 0, dim=self.v_seq_dim)
            raw_value = slice_tensor(raw_value, start=-self.raw_size_to_cache, dim=self.v_seq_dim)
        else:
            raw_key = None
            raw_value = None

        return beacon_key, beacon_value, raw_key, raw_value


def slice_tensor(x, start=None, end=None, step=None, index=None, dim=2):
    if x is None:
        return None
    if end == 0:
        return None
    if start == x.shape[dim]:
        return None
    if start is not None and start == end:
        return None
    if dim == 2:
        if index is not None:
            return x[:, :, index]
        elif start is None and end is not None:
            if step is None:
                return x[:, :, :end, ...]
            else:
                return x[:, :, :end:step, ...]
        elif start is not None and end is None:
            if step is None:
                return x[:, :, start:, ...]
            else:
                return x[:, :, start::step, ...]
        elif start is not None and end is not None:
            if step is None:
                return x[:, :, start:end, ...]
            else:
                return x[:, :, start:end:step, ...]
    elif dim == 1:
        if index is not None:
            return x[:, index]
        elif start is None and end is not None:
            if step is None:
                return x[:, :end, ...]
            else:
                return x[:, :end:step, ...]
        elif start is not None and end is None:
            if step is None:
                return x[:, start:, ...]
            else:
                return x[:, start::step, ...]
        elif start is not None and end is not None:
            if step is None:
                return x[:, start:end, ...]
            else:
                return x[:, start:end:step, ...]
    else:
        raise NotImplementedError
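
# Example (illustrative): for a key cache of shape (batch, heads, seq, head_dim),
# slice_tensor(key, end=4, dim=2) keeps the first 4 positions along the sequence axis,
# slice_tensor(key, start=-4, dim=2) keeps the last 4, and slice_tensor(key, index=mask, dim=2)
# gathers the positions where the boolean mask is True.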


def cat_tensor(list_of_tensors, dim=-1):
    list_of_tensors = [t for t in list_of_tensors if t is not None]
    if len(list_of_tensors) > 1:
        result = torch.cat(list_of_tensors, dim=dim)
    elif len(list_of_tensors) == 1:
        result = list_of_tensors[0]
    else:
        result = None
    return result


def slice_activations(activations, start=None, end=None, k_seq_dim=2, v_seq_dim=2):
    new_activations = []
    for key, value in activations:
        new_key = slice_tensor(key, start=start, end=end, dim=k_seq_dim)
        new_value = slice_tensor(value, start=start, end=end, dim=v_seq_dim)
        new_activations.append([new_key, new_value])
    return new_activations


def cat_activations(list_of_activations, k_seq_dim=2, v_seq_dim=2):
    assert all(len(x) == len(list_of_activations[0]) for x in list_of_activations), f"Make sure all activations have the same number of layers! Found {[len(x) for x in list_of_activations]}."
    new_activations = []
    for layer_idx in range(len(list_of_activations[0])):
        keys = [x[layer_idx][0] for x in list_of_activations]
        values = [x[layer_idx][1] for x in list_of_activations]
        new_key = cat_tensor(keys, dim=k_seq_dim)
        new_value = cat_tensor(values, dim=v_seq_dim)
        new_activations.append([new_key, new_value])
    return new_activations
def interleave_activations(main_activations, augment_activations, main_spans, augment_spans, k_seq_dim=2, v_seq_dim=2, device=torch.device("cuda")): | |
""" Interleave main_activations and augment_activations according to main_span and augment_span. | |
Args: | |
main_span: a list of tuples (start_idx, end_idx). when start_idx and end_idx is None, the augment_activations will be plugged in. | |
augment_span: a list of tuples (start_idx, end_idx) | |
""" | |
assert len(main_activations) == len(augment_activations) , f"Make sure main and augment activations have the same number of layers! Found {len(main_activations)} and {len(augment_activations)}!" | |
assert sum(x[0] is None and x[1] is None for x in main_spans) == len(augment_spans), f"Make sure the number of slots for augmentation (start_idx=None and end_idx=None in main_spans) matches the number of augmentations. Found {sum(x for x in main_spans if x[0] is None and x[1] is None)} slots but {len(augment_spans)} augmentations!" | |
new_activations = [] | |
for layer_idx in range(len(main_activations)): | |
main_key, main_value = main_activations[layer_idx] | |
augment_key, augment_value = augment_activations[layer_idx] | |
sliced_keys = [] | |
sliced_values = [] | |
augment_idx = 0 | |
for start, end in main_spans: | |
if start is None and end is None: | |
# this means the augment key/value should be plugged in | |
augment_start, augment_end = augment_spans[augment_idx] | |
sliced_key = slice_tensor( | |
augment_key, | |
start=augment_start, | |
end=augment_end, | |
dim=k_seq_dim | |
).to(device) | |
sliced_value = slice_tensor( | |
augment_value, | |
start=augment_start, | |
end=augment_end, | |
dim=v_seq_dim | |
).to(device) | |
else: | |
sliced_key = slice_tensor( | |
main_key, | |
start=start, | |
end=end, | |
dim=k_seq_dim | |
) | |
sliced_value = slice_tensor( | |
main_value, | |
start=start, | |
end=end, | |
dim=v_seq_dim | |
) | |
sliced_keys.append(sliced_key) | |
sliced_values.append(sliced_value) | |
new_key = cat_tensor(sliced_keys, dim=k_seq_dim) | |
new_value = cat_tensor(sliced_values, dim=v_seq_dim) | |
new_activations.append([new_key, new_value]) | |
return new_activations | |


def softmax(x: np.ndarray, axis=-1, temperature=1):
    if isinstance(x, list):
        x = np.array(x)
    x = x / temperature
    x = x - x.max(axis=axis, keepdims=True)
    y = np.exp(x)
    return y / y.sum(axis=axis, keepdims=True)


def l1_norm(x):
    sum_x = sum(x)
    x = [y / sum_x for y in x]
    return x
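

if __name__ == "__main__":
    # Minimal self-check sketch (not part of the original training/inference code):
    # it exercises the module-level helpers with dummy tensors so their behaviour is easy to inspect.
    dummy_key = torch.zeros(1, 2, 8, 4)  # (batch, heads, seq, head_dim)
    dummy_value = torch.ones(1, 2, 8, 4)

    # keep the first 3 positions, then the last 2 positions, along the sequence axis
    head = slice_tensor(dummy_key, end=3, dim=2)
    tail = slice_tensor(dummy_key, start=-2, dim=2)
    merged = cat_tensor([head, tail], dim=2)
    print(merged.shape)  # torch.Size([1, 2, 5, 4])

    # cat_tensor silently drops None entries, which is how empty memory slots are handled
    print(cat_tensor([None, dummy_value], dim=2).shape)  # torch.Size([1, 2, 8, 4])

    # softmax and l1_norm also accept plain lists
    print(softmax([1.0, 2.0, 3.0]))
    print(l1_norm([1.0, 2.0, 3.0]))  # [0.166..., 0.333..., 0.5]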