|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os, sys, os.path as osp |
|
import warnings |
|
from abc import ABC, abstractmethod |
|
|
|
import torch, logging |
|
|
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModel, |
|
AutoModelForCausalLM, |
|
AutoConfig, |
|
BitsAndBytesConfig, |
|
PretrainedConfig, |
|
PreTrainedModel, |
|
) |
|
|
|
from .constants import ( |
|
DEFAULT_IM_END_TOKEN, |
|
DEFAULT_IM_START_TOKEN, |
|
DEFAULT_IMAGE_PATCH_TOKEN, |
|
IGNORE_INDEX, |
|
IMAGE_TOKEN_INDEX, |
|
MASK_TOKEN_INDEX, |
|
) |
|
|
|
from collections import OrderedDict |
|
from .utils import get_model_config |
|
from .language_model.builder import build_llm_and_tokenizer |
|
from .multimodal_encoder.builder import build_vision_tower, build_context_provider |
|
from .multimodal_projector.builder import build_mm_projector |
|
from .configuration_llava import LlavaConfig |
|
|
|
from transformers.modeling_utils import ContextManagers, no_init_weights |
|
|
|
|
|
class LlavaMetaModel(ABC):
    """Mixin that owns the sub-modules of a LLaVA-style vision-language model.

    The sub-modules are the language model (``llm`` plus ``tokenizer``), the
    ``vision_tower``, the multimodal projector (``mm_projector``) and an optional
    ``context_provider``.

    NOTE(review): the methods read ``self.config``, ``self.device``,
    ``self.state_dict()`` and ``self.training``. This suggests the concrete class
    also derives from a Hugging Face ``PreTrainedModel``; confirm against the
    subclasses.
    """

    @staticmethod
    def _resolve_sub_configs(config):
        """Default ``config.model_dtype`` and split ``config`` into per-module configs.

        Returns:
            ``(llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg,
            context_provider_cfg)`` as produced by ``get_model_config``.

        Raises:
            ValueError: if the LLM, vision-tower or projector config is missing.
        """
        if not hasattr(config, "model_dtype"):
            warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
            config.model_dtype = "torch.float16"

        llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg = get_model_config(config)
        if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
            raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
        return llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg

    @staticmethod
    def _sub_state_dict(state_dict, match, prefix):
        """Select the entries of ``state_dict`` belonging to one sub-module.

        Keeps every entry whose key contains ``match``. Each kept key is
        re-keyed to the text after the last occurrence of ``prefix``; a key that
        lacks ``prefix`` is kept unchanged.
        """
        return OrderedDict((key.split(prefix)[-1], value) for key, value in state_dict.items() if match in key)

    def init_vlm(self, config: "PretrainedConfig" = None, *args, **kwargs):
        """Build every sub-module described by ``config`` and attach it to ``self``.

        Returns immediately if ``llm``, ``vision_tower`` or ``mm_projector`` already
        exists, so repeated calls leave existing modules untouched.

        Args:
            config: model config. If it has no ``model_dtype``, ``"torch.float16"``
                is used and a warning is emitted.
            *args, **kwargs: forwarded to ``build_llm_and_tokenizer``.

        Raises:
            ValueError: if the LLM, vision-tower or projector config is missing.
        """
        if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"):
            return

        llm_cfg, vision_tower_cfg, mm_projector_cfg, _mask_encoder_cfg, context_provider_cfg = (
            self._resolve_sub_configs(config)
        )

        self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        self.vision_tower = build_vision_tower(vision_tower_cfg, config)
        self.mm_projector = build_mm_projector(mm_projector_cfg, config)
        self.context_provider = (
            build_context_provider(context_provider_cfg, config) if context_provider_cfg is not None else None
        )

        self.post_config()
        self.is_loaded = True

        assert (
            self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
        ), "At least one of the components must be instantiated."

    @classmethod
    def load_from_config(cls, model_path_or_config, *args, **kwargs):
        """Placeholder that is not implemented: it does nothing and returns ``None``."""
        pass

    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
        """Instantiate ``cls`` and make sure all of its sub-modules are built.

        Args:
            model_path_or_config: a path/name loadable by ``AutoConfig`` or a
                ``LlavaConfig`` instance.
            *args, **kwargs: forwarded to the constructor and to
                ``build_llm_and_tokenizer``. A ``config`` kwarg is discarded.

        Returns:
            The model instance. It is returned as soon as the constructor reports
            ``is_loaded``; otherwise the sub-modules are built here.

        Raises:
            NotImplementedError: for an unsupported argument type, or when the
                config requests a mask encoder.
            ValueError: if the LLM, vision-tower or projector config is missing.
        """
        kwargs.pop("config", None)

        if isinstance(model_path_or_config, str):
            config = AutoConfig.from_pretrained(model_path_or_config)
        elif isinstance(model_path_or_config, LlavaConfig):
            config = model_path_or_config
        else:
            raise NotImplementedError(
                f"wrong type, {type(model_path_or_config)} "
                f"{isinstance(model_path_or_config, LlavaConfig)}"
            )

        llm_cfg, vision_tower_cfg, mm_projector_cfg, mask_encoder_cfg, context_provider_cfg = (
            cls._resolve_sub_configs(config)
        )

        # Skip random weight initialisation: the weights are loaded afterwards anyway.
        with ContextManagers([no_init_weights(_enable=True)]):
            vlm = cls(config, *args, **kwargs)

        # The constructor may already have built everything (e.g. via init_vlm).
        if hasattr(vlm, "llm") or hasattr(vlm, "vision_tower") or hasattr(vlm, "mm_projector"):
            if getattr(vlm, "is_loaded", False):
                return vlm

        # Reject unsupported configs before paying for the expensive builds below.
        if mask_encoder_cfg is not None:
            raise NotImplementedError("Mask encoder is not supported.")

        vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
        vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)
        vlm.context_provider = (
            build_context_provider(context_provider_cfg, config) if context_provider_cfg is not None else None
        )

        # These previously referenced an undefined ``self`` inside this classmethod.
        vlm.post_config()
        vlm.is_loaded = True

        assert (
            vlm.llm is not None or vlm.vision_tower is not None or vlm.mm_projector is not None
        ), "At least one of the components must be instantiated."
        return vlm

    def save_pretrained(self, output_dir, state_dict=None):
        """Save each sub-module to its own sub-directory, then save the top-level config.

        Layout under ``output_dir``: ``llm/`` (model and tokenizer),
        ``vision_tower/`` (skipped for towers whose class name contains "radio"),
        ``mm_projector/`` and ``context_provider/``. The top-level config is
        updated to point at the saved sub-configs.

        Args:
            output_dir: destination directory.
            state_dict: full model state dict; defaults to ``self.state_dict()``.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        llm = self.get_llm()
        if llm is not None:
            llm_dir = osp.join(output_dir, "llm")
            print(f"saving llm to {llm_dir}")
            llm.config._name_or_path = llm_dir
            llm.save_pretrained(llm_dir, state_dict=self._sub_state_dict(state_dict, "llm", "llm."))
            self.config.llm_cfg = llm.config

        vision_tower = self.get_vision_tower()
        if vision_tower is not None and "radio" not in vision_tower.__class__.__name__.lower():
            vision_tower_dir = osp.join(output_dir, "vision_tower")
            print(f"saving vision_tower to {vision_tower_dir}")
            vision_tower.config._name_or_path = vision_tower_dir
            # The wrapper holds the actual HF model in ``.vision_tower``.
            vision_tower.vision_tower.save_pretrained(
                vision_tower_dir,
                state_dict=self._sub_state_dict(state_dict, "vision_tower", "vision_tower.vision_tower."),
            )
            vision_tower.image_processor.save_pretrained(vision_tower_dir)
            self.config.vision_tower_cfg = vision_tower.config
            # Dropped after saving, so the sub-directory keeps it but the top-level config does not.
            if hasattr(self.config.vision_tower_cfg, "auto_map"):
                delattr(self.config.vision_tower_cfg, "auto_map")

        mm_projector = self.get_mm_projector()
        if mm_projector is not None:
            mm_projector_dir = osp.join(output_dir, "mm_projector")
            print(f"saving mm_projector to {mm_projector_dir}")
            mm_projector.config._name_or_path = mm_projector_dir
            mm_projector.save_pretrained(
                mm_projector_dir,
                state_dict=self._sub_state_dict(state_dict, "mm_projector", "mm_projector."),
            )
            self.config.mm_projector_cfg = mm_projector.config

        context_provider = self.get_context_provider()
        if context_provider is not None:
            context_provider_dir = osp.join(output_dir, "context_provider")
            print(f"saving context_provider to {context_provider_dir}")
            context_provider.config._name_or_path = context_provider_dir
            context_provider.save_pretrained(
                context_provider_dir,
                state_dict=self._sub_state_dict(state_dict, "context_provider", "context_provider."),
            )
            self.config.context_provider_cfg = context_provider.config

        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

    def get_llm(self):
        """Return the language model (unwrapped if stored in a list), or ``None``."""
        llm = getattr(self, "llm", None)
        if type(llm) is list:
            llm = llm[0]
        return llm

    def get_lm_head(self):
        """Return the LLM's ``lm_head`` attribute, or ``None`` if absent."""
        return getattr(self.get_llm(), "lm_head", None)

    def get_vision_tower(self):
        """Return the vision tower (unwrapped if stored in a list), or ``None``."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_mm_projector(self):
        """Return the multimodal projector (unwrapped if stored in a list), or ``None``."""
        mm_projector = getattr(self, "mm_projector", None)
        if type(mm_projector) is list:
            mm_projector = mm_projector[0]
        return mm_projector

    def get_context_provider(self):
        """Return the context provider, or ``None`` if the model has none."""
        return getattr(self, "context_provider", None)

    def post_config(self):
        """Mirror the LLM's training flag and record missing sub-configs on ``self.config``."""
        self.training = self.get_llm().training

        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.get_llm().config
        if getattr(self.config, "vision_tower_cfg", None) is None:
            self.config.vision_tower_cfg = self.get_vision_tower().config
        if getattr(self.config, "mm_projector_cfg", None) is None:
            self.config.mm_projector_cfg = self.get_mm_projector().config
        context_provider = self.get_context_provider()
        if getattr(self.config, "context_provider_cfg", None) is None and context_provider is not None:
            self.config.context_provider_cfg = context_provider.config

    def freezed_module_patch(self):
        """Put frozen sub-modules back into eval mode while the model is training.

        Hugging Face calls ``model.train()`` at every training step. To get the
        expected behaviour from modules such as dropout and batch norm, frozen
        modules must be switched back with ``eval()``. Each module is treated as
        frozen unless the matching ``tune_*`` config flag is set. A frozen LLM is
        only warned about, not switched to eval.
        """
        if not self.training:
            return
        if self.get_llm() is not None and not getattr(self.config, "tune_language_model", False):
            logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
        if self.get_vision_tower() is not None and not getattr(self.config, "tune_vision_tower", False):
            self.get_vision_tower().eval()
        if self.get_mm_projector() is not None and not getattr(self.config, "tune_mm_projector", False):
            self.get_mm_projector().eval()
        if self.get_context_provider() is not None and not getattr(self.config, "tune_context_provider", False):
            self.get_context_provider().eval()

    def encode_images(self, images):
        """Run images through the vision tower and then the multimodal projector."""
        image_features = self.get_vision_tower()(images)
        return self.get_mm_projector()(image_features)

    def encode_images_with_context(self, images):
        """Encode images using the context provider, then project the features.

        ``images[:, :4]`` and ``images[:, 4:]`` are compared element-wise, so both
        halves must have matching shapes. NOTE(review): presumably 4 channels each
        along dim 1; confirm with the data pipeline.

        Samples whose two halves differ ("cimages") get their vision features
        replaced by the context provider's output. If the provider's
        ``treat_image_as_cimage`` is set, every sample is treated this way. If its
        ``context_image_as_queries`` is set, the two halves are swapped first.

        Raises:
            NotImplementedError: for an unknown ``context_provider_type``.
        """
        context_provider = self.get_context_provider()

        cimage_mask = torch.any((images[:, :4, ...] != images[:, 4:, ...]).flatten(start_dim=1), dim=1)
        if context_provider.treat_image_as_cimage:
            cimage_mask[:] = True
        if context_provider.context_image_as_queries:
            images = torch.cat((images[:, 4:, ...], images[:, :4, ...]), dim=1)

        vision_tower = self.get_vision_tower()
        image_features = vision_tower(images[:, :4, ...]).to(self.device)

        cimage_concatenated = images[cimage_mask]
        cimage_full_features = image_features[cimage_mask]
        if context_provider.context_provider_type == "cross_attn_end_to_all":
            cimage_features = context_provider(
                cimage_full_features=cimage_full_features,
                cimage_concatenated=cimage_concatenated,
                vision_tower=vision_tower,
            ).to(self.device)
        elif context_provider.context_provider_type == "concat":
            cimage_features = context_provider(
                cimage_concatenated=cimage_concatenated,
                vision_tower=vision_tower,
            ).to(self.device)
        else:
            raise NotImplementedError(f"Context provider type {context_provider.context_provider_type} not implemented.")

        image_features[cimage_mask] = cimage_features
        return self.get_mm_projector()(image_features)

    def _temporary_reorder_cache(self, past_key_values, sorted_idx):
        """Delegate KV-cache reordering to the LLM."""
        return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)

    def get_input_embeddings(self):
        """Return the LLM's input embedding module."""
        return self.get_llm().get_input_embeddings()

    def get_output_embeddings(self):
        """Return the LLM's output embedding module."""
        return self.get_llm().get_output_embeddings()

    def resize_token_embeddings(self, embed_size):
        """Resize the LLM's token embeddings to ``embed_size`` entries."""
        self.get_llm().resize_token_embeddings(embed_size)
|
|
|
|
|
|
|
class LlavaMetaForCausalLM(ABC):
    """This class is originally implemented by the LLaVA team and
    modified by Haotian Tang and Jason Lu based on Ji Lin's implementation
    to support multiple images and input packing.

    NOTE(review): the methods use ``self.get_vision_tower()``, ``self.encode_images``,
    ``self.llm``, ``self.config`` and ``self.device``. These are presumably supplied
    by a class that also mixes in ``LlavaMetaModel``; confirm.
    """

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images
    ):
        """Splice image features into the text embeddings at image-token positions.

        Returns:
            ``(input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``.

            On the pass-through path (no vision tower, no images, or a single-token
            step) ``input_ids`` is returned unchanged and ``inputs_embeds`` is
            ``None``. For a single-token step with a cache, the attention mask is
            first extended to cover the cached positions.

            Otherwise ``input_ids`` is ``None`` and ``inputs_embeds`` is the padded
            batch of merged embeddings. ``position_ids``, ``attention_mask`` and
            ``labels`` come back as ``None`` if they were passed as ``None``.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if (
                past_key_values is not None
                and vision_tower is not None
                and images is not None
                and input_ids.shape[1] == 1
            ):
                # Incremental decoding: extend the mask to cover the cached length + 1
                # and position the new token right after the last attended one.
                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat(
                    (
                        attention_mask,
                        torch.ones(
                            (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                            dtype=attention_mask.dtype,
                            device=attention_mask.device,
                        ),
                    ),
                    dim=1,
                )
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if isinstance(images, list):
            images = torch.cat(images, dim=0)
        elif images.ndim == 5:
            # Fold the two leading dims into a single image batch.
            images = images.flatten(0, 1)

        if getattr(self, "context_provider", None) is not None:
            image_features = self.encode_images_with_context(images)
        else:
            assert images.shape[1] <= 4, "images have more than 4 channels, but context provider is not included"
            image_features = self.encode_images(images).to(self.device)

        # NOTE(review): "turn_mm_projector" looks like a typo for "tune_mm_projector";
        # it is kept unchanged so the set of configs that raise here stays the same.
        if getattr(self.config, "turn_mm_projector", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        # Remember the caller's originals so None inputs are returned as None.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Embed the whole batch in one call. Image placeholder ids are replaced
        # with 0 so they can be looked up; those rows are dropped further down.
        input_ids_copy = input_ids.clone()
        input_ids_copy[input_ids_copy == IMAGE_TOKEN_INDEX] = 0
        input_embeds = self.llm.model.embed_tokens(input_ids_copy)

        # Remove padded positions from every sample.
        input_ids = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
        input_embeds_1 = [embeds[mask] for embeds, mask in zip(input_embeds, attention_mask)]
        labels = [lab[mask] for lab, mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # Text-only sample: appending an empty slice of image features
                # presumably keeps the vision branch in the autograd graph.
                cur_image_features = image_features[0]
                cur_input_embeds = torch.cat([input_embeds_1[batch_idx], cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                continue

            cur_input_embeds = input_embeds_1[batch_idx]
            cur_labels = labels[batch_idx]
            # Sentinels -1 and len() make every text span sit between two boundaries.
            image_token_indices = (
                [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_labels_noim = []
            cur_input_embeds_no_im = []
            for start, end in zip(image_token_indices[:-1], image_token_indices[1:]):
                cur_labels_noim.append(cur_labels[start + 1 : end])
                cur_input_embeds_no_im.append(cur_input_embeds[start + 1 : end])

            # Interleave: text span, image features, text span, ..., text span.
            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # Image positions never contribute to the loss.
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )

            new_input_embeds.append(torch.cat(cur_new_input_embeds))
            new_labels.append(torch.cat(cur_new_labels))

        tokenizer_model_max_length = getattr(self.llm.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            if any(len(x) > tokenizer_model_max_length for x in new_input_embeds):
                warnings.warn("Inputs truncated!")
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            IGNORE_INDEX,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        pad_left = getattr(self.llm.config, "tokenizer_padding_side", "right") == "left"
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros(
                (max_len - cur_len, cur_new_embed.shape[1]),
                dtype=cur_new_embed.dtype,
                device=cur_new_embed.device,
            )
            # cur_len > 0 guards are needed because a ``-0:`` slice would select everything.
            if pad_left:
                new_input_embeds_padded.append(torch.cat((padding, cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(
                        0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                    )

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded
        attention_mask = None if _attention_mask is None else attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def repack_multimodal_data(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        inputs_embeds,
        labels,
    ):
        """Greedily pack samples into fewer rows of at most ``inputs_embeds.shape[1]`` tokens.

        Samples are visited longest-first. A sample is appended to the current row
        while it fits; otherwise a new row is started. Position ids restart at 0
        for every packed sample. Padded positions get position id -1, which is
        also how the returned attention mask is derived.

        ``input_ids`` and ``position_ids`` are accepted for interface symmetry and
        are not used.

        Returns:
            ``(None, position_ids, attention_mask, past_key_values, inputs_embeds,
            labels, sorted_seqlens_in_batch)``.
        """
        # Masking below requires a boolean mask: an integer 0/1 mask would be
        # treated as gather indices and select rows 0/1 instead of valid tokens.
        attention_mask = attention_mask.bool()
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        sorted_seqlens_in_batch, sorted_idx = torch.sort(seqlens_in_batch, descending=True)
        max_seqlen = inputs_embeds.shape[1]

        new_inputs_embeds = []
        new_position_ids = []
        new_labels = []

        cur_inputs_embeds = []
        cur_position_ids = []
        cur_labels = []
        cur_batch_len = 0

        for cur_seqlen, sample_idx in zip(sorted_seqlens_in_batch.tolist(), sorted_idx.tolist()):
            if cur_seqlen + cur_batch_len > max_seqlen and cur_inputs_embeds:
                # Current row is full: close it and start a new one.
                new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
                new_position_ids.append(torch.cat(cur_position_ids, 0))
                new_labels.append(torch.cat(cur_labels, 0))
                cur_inputs_embeds, cur_position_ids, cur_labels = [], [], []
                cur_batch_len = 0

            sample_mask = attention_mask[sample_idx]
            sample_embeds = inputs_embeds[sample_idx][sample_mask]
            cur_inputs_embeds.append(sample_embeds)
            cur_position_ids.append(torch.arange(sample_embeds.shape[0], device=sample_embeds.device))
            cur_labels.append(labels[sample_idx][sample_mask])
            cur_batch_len += cur_seqlen

        if cur_inputs_embeds:
            new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
            new_position_ids.append(torch.cat(cur_position_ids, 0))
            new_labels.append(torch.cat(cur_labels, 0))

        new_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
            new_inputs_embeds, batch_first=True, padding_value=self.llm.pad_token_id
        )
        new_position_ids = torch.nn.utils.rnn.pad_sequence(new_position_ids, batch_first=True, padding_value=-1)
        new_labels = torch.nn.utils.rnn.pad_sequence(new_labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Padding is marked by position id -1.
        new_attention_mask = new_position_ids.ne(-1)
        assert new_attention_mask.sum() == attention_mask.sum()

        return (
            None,
            new_position_ids,
            new_attention_mask,
            past_key_values,
            new_inputs_embeds,
            new_labels,
            sorted_seqlens_in_batch,
        )

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add the image special tokens requested by ``model_args`` and resize embeddings.

        When ``mm_use_im_start_end`` is set, the start and end tokens are added.
        Any newly created embedding rows are initialised to the mean of the
        existing rows. If ``pretrain_mm_mlp_adapter`` names a checkpoint, the
        new input-embedding rows are overwritten from it.

        Raises:
            ValueError: if the checkpoint's ``model.embed_tokens.weight`` has an
                unexpected shape.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
                # checkpoints here (consider weights_only=True).
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
                    )
        elif model_args.mm_use_im_patch_token:
            if model_args.mm_projector:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
|
|