|
|
|
|
|
|
|
|
|
import math |
|
import random |
|
from PIL import Image |
|
|
|
import torch |
|
from torch.nn.attention.flex_attention import or_masks, and_masks |
|
|
|
|
|
def create_sparse_mask(document_lens, split_lens, attn_modes, device): |
|
def causal_mask(b, h, q_idx, kv_idx): |
|
return q_idx >= kv_idx |
|
|
|
def full_and_noise_mask(b, h, q_idx, kv_idx): |
|
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) |
|
|
|
def remove_noise_mask(b, h, q_idx, kv_idx): |
|
return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx]))) |
|
|
|
def sample_mask(b, h, q_idx, kv_idx): |
|
return document_id[q_idx] == document_id[kv_idx] |
|
|
|
full_and_noise_tmp = [] |
|
noise_tmp = [] |
|
|
|
for i, (length, model) in enumerate(zip(split_lens, attn_modes)): |
|
value = i if model in ['full', 'noise'] else -1 |
|
full_and_noise_tmp.extend([value] * length) |
|
value_noise = i if model == 'noise' else -1 |
|
noise_tmp.extend([value_noise] * length) |
|
|
|
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device) |
|
noise_seq_id = torch.Tensor(noise_tmp).to(device) |
|
|
|
document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device) |
|
|
|
return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask) |
|
|
|
|
|
def patchify(image, patch_size):
    """
    Cut a (c, h, w) image tensor into flattened square patches.

    Args:
        image: tensor whose shape unpacks as (channels, height, width).
        patch_size: side length of each patch; must divide both height and
            width (checked with an assert).

    Returns:
        Tensor of shape ((h // p) * (w // p), p * p * c). Patches are ordered
        row-major over the patch grid; inside a patch the values are ordered
        by (row in patch, column in patch, channel).
    """
    channels, height, width = image.shape
    assert height % patch_size == 0 and width % patch_size == 0

    grid_h = height // patch_size
    grid_w = width // patch_size
    # (c, grid_h, p, grid_w, p) -> (grid_h, grid_w, p, p, c)
    blocks = image.reshape(channels, grid_h, patch_size, grid_w, patch_size)
    blocks = blocks.permute(1, 3, 2, 4, 0)
    return blocks.reshape(-1, patch_size**2 * channels)
|
|
|
|
|
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """
    Return row-major position ids for the patch grid of an image.

    The patch at grid location (row, col) gets the id
    ``row * max_num_patches_per_side + col``, i.e. its index inside a
    max_num_patches_per_side-wide grid.

    Args:
        img_h: image height in pixels.
        img_w: image width in pixels.
        patch_size: patch side length; the grid has ``img_h // patch_size``
            rows and ``img_w // patch_size`` columns.
        max_num_patches_per_side: row stride used when flattening.

    Returns:
        1-D integer tensor with one id per patch.
    """
    row_idx = torch.arange(0, img_h // patch_size)
    col_idx = torch.arange(0, img_w // patch_size)
    # Broadcast a column of row offsets against a row of column indices.
    id_grid = row_idx.unsqueeze(1) * max_num_patches_per_side + col_idx
    return id_grid.flatten()
|
|
|
|
|
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """
    Return position ids obtained by bucketing fractional patch coordinates.

    Along each axis, patch ``i`` of ``n`` gets the fractional coordinate
    ``i / n``; that coordinate is bucketized against boundaries spaced
    ``1 / max_num_patches_per_side`` apart, and the (row bucket, col bucket)
    pair is flattened as ``row_bucket * max_num_patches_per_side + col_bucket``.

    Args:
        img_h: image height in pixels.
        img_w: image width in pixels.
        patch_size: patch side length; the grid has ``img_h // patch_size``
            rows and ``img_w // patch_size`` columns.
        max_num_patches_per_side: number of buckets per axis and row stride.

    Returns:
        1-D integer tensor with one id per patch, in row-major order.
    """
    bucket_width = 1 / max_num_patches_per_side
    boundaries = torch.arange(bucket_width, 1.0, bucket_width)

    def axis_buckets(num_patches):
        # Fractions 0, 1/n, 2/n, ...; the 1e-6 margin keeps 1.0 itself out.
        fractions = torch.arange(0, 1 - 1e-6, 1 / num_patches)
        return torch.bucketize(fractions, boundaries, right=True)

    row_buckets = axis_buckets(img_h // patch_size)
    col_buckets = axis_buckets(img_w // patch_size)
    id_grid = row_buckets.unsqueeze(1) * max_num_patches_per_side + col_buckets
    return id_grid.flatten()
|
|
|
|
|
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
    """
    Build the additive attention mask for one sample made of several splits.

    Args:
        split_lens: list of ints, the length of each consecutive split in the
            sample.
        attn_modes: one mode per split, each of 'causal', 'full' or 'noise'.
            'causal' splits use a lower-triangular pattern within the split;
            'full' and 'noise' splits attend to every position of their own
            split. Every split also attends to all positions before it,
            except that 'noise' positions are hidden from every query outside
            their own split.
        device: device on which the mask is created.

    Returns:
        Float tensor of shape (sum(split_lens), sum(split_lens)) holding 0.0
        where attention is allowed and -inf where it is not.

    Raises:
        AssertionError: if a mode is not one of the three supported values.
    """
    sample_len = sum(split_lens)
    allowed = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)

    # Pass 1: within-split pattern plus visibility of everything that came before.
    start = 0
    for s, attn_mode in zip(split_lens, attn_modes):
        assert attn_mode in ['causal', 'full', 'noise'], attn_mode
        end = start + s
        if attn_mode == "causal":
            # Built directly on `device` as bool so no implicit transfer/cast is needed.
            allowed[start:end, start:end] = torch.ones(
                (s, s), dtype=torch.bool, device=device
            ).tril()
        else:
            allowed[start:end, start:end] = True
        allowed[start:end, :start] = True
        start = end

    # Pass 2: hide noise columns from everyone, then re-open the noise split to itself.
    start = 0
    for s, attn_mode in zip(split_lens, attn_modes):
        end = start + s
        if attn_mode == "noise":
            allowed[:, start:end] = False
            allowed[start:end, start:end] = True
        start = end

    # Convert the boolean "allowed" map into an additive mask (0 / -inf).
    attention_mask = torch.zeros_like(allowed, dtype=torch.float).masked_fill_(
        ~allowed, float("-inf")
    )

    return attention_mask
|
|
|
|
|
def split_integer_exp_decay(S, ng_sample_decay=1.0):
    """
    Randomly split the integer S into N positive parts that sum to S.

    Args:
        S: the integer to split.
        ng_sample_decay: when exactly 1.0, the number of parts N is drawn
            uniformly from 1..S; otherwise N = k is drawn with probability
            proportional to ``ng_sample_decay ** (k - 1)``.

    Returns:
        A pair ``(result, cumsum)``: ``result`` is the list of N part sizes
        and ``cumsum`` is the list of N + 1 sorted boundaries, starting at 0
        and ending at S, whose consecutive differences are ``result``.
    """
    if ng_sample_decay == 1.0:
        num_parts = random.randint(1, S)
    else:
        # Normalising constant of the truncated geometric series.
        scale = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
        weights = [scale * math.pow(ng_sample_decay, k) for k in range(S)]
        candidates = list(range(1, S + 1))
        num_parts = random.choices(candidates, weights, k=1)[0]

    # N - 1 distinct interior cut points yield N non-empty parts.
    interior_cuts = sorted(random.sample(range(1, S), num_parts - 1))
    cumsum = [0, *interior_cuts, S]
    result = [upper - lower for lower, upper in zip(cumsum, cumsum[1:])]
    return result, cumsum
|
|
|
|
|
def pil_img2rgb(image):
    """
    Convert an image to RGB, compositing any transparency onto white.

    Args:
        image: an image object exposing ``mode``, ``info`` and ``convert``
            (a PIL image).

    Returns:
        An RGB image. If the input mode is "RGBA" or its ``info`` carries a
        non-None "transparency" entry, the image is converted to RGBA and
        pasted onto a white background using its alpha band as the paste
        mask; otherwise it is the result of ``image.convert("RGB")``.
    """
    has_transparency = (
        image.mode == "RGBA" or image.info.get("transparency", None) is not None
    )
    if not has_transparency:
        return image.convert("RGB")

    rgba = image.convert("RGBA")
    background = Image.new(mode="RGB", size=rgba.size, color=(255, 255, 255))
    alpha_band = rgba.split()[3]
    background.paste(rgba, mask=alpha_band)
    return background
|
|
|
|
|
def add_special_tokens(tokenizer):
    """
    Make sure the chat/vision marker tokens exist in the tokenizer.

    The tokens '<|im_start|>', '<|im_end|>', '<|vision_start|>' and
    '<|vision_end|>' are added (in that order) via ``tokenizer.add_tokens``
    when they are not already listed among the string or list values of
    ``tokenizer.special_tokens_map``.

    Args:
        tokenizer: object providing ``special_tokens_map``, ``add_tokens``
            and ``convert_tokens_to_ids``.

    Returns:
        A tuple ``(tokenizer, new_token_ids, num_new_tokens)`` where
        ``new_token_ids`` maps 'bos_token_id', 'eos_token_id',
        'start_of_image' and 'end_of_image' to the ids of the four tokens
        above, and ``num_new_tokens`` is whatever ``add_tokens`` returned.
    """
    # Flatten the special-token map; values may be single strings or lists.
    known_tokens = []
    for entry in tokenizer.special_tokens_map.values():
        if isinstance(entry, str):
            known_tokens.append(entry)
        elif isinstance(entry, list):
            known_tokens.extend(entry)

    # (key in the returned dict, token text), in the order they are registered.
    required = (
        ('bos_token_id', '<|im_start|>'),
        ('eos_token_id', '<|im_end|>'),
        ('start_of_image', '<|vision_start|>'),
        ('end_of_image', '<|vision_end|>'),
    )

    missing = [token for _, token in required if token not in known_tokens]
    num_new_tokens = tokenizer.add_tokens(missing)

    new_token_ids = {
        key: tokenizer.convert_tokens_to_ids(token) for key, token in required
    }
    return tokenizer, new_token_ids, num_new_tokens
|
|
|
|
|
def len2weight(x, loss_reduction='square'):
    """
    Map a length x to a loss weight according to the reduction scheme.

    Args:
        x: the length. A value equal to 0 is returned unchanged, before the
            reduction name is even inspected.
        loss_reduction: 'token' -> 1, 'sample' -> 1 / x,
            'square' -> 1 / sqrt(x).

    Returns:
        The weight for the given length.

    Raises:
        NotImplementedError: for any other ``loss_reduction`` (only when
            x is non-zero).
    """
    if x == 0:
        return x

    if loss_reduction == 'token':
        weight = 1
    elif loss_reduction == 'sample':
        weight = 1 / x
    elif loss_reduction == 'square':
        weight = 1 / (x ** 0.5)
    else:
        raise NotImplementedError(loss_reduction)
    return weight
|
|