# describe-anything / dam/model/model_utils.py
import os
from typing import List, Optional, Tuple, Union

import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from .language_model.llava_llama import LlavaLlamaConfig

# TODO: we may move LlavaConfig to configuration_llava.py
# from model.configuration_llava import LlavaConfig


class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
    """LLaVA-style vision-language model wrapping an LLM, a vision tower, and a
    multimodal projector (plus an optional context provider)."""

    config_class = LlavaLlamaConfig
    main_input_name = "input_embeds"
    supports_gradient_checkpointing = True

    def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
        super().__init__(config)
        self.init_vlm(config=config, *args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        **kwargs,
    ):
        # Prefer the custom loader provided by the Llava meta classes when available.
        if hasattr(cls, "load_pretrained"):
            return cls.load_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                config=config,
                cache_dir=cache_dir,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        # Otherwise fall back to the standard Hugging Face loader.
        return super(LlavaLlamaModel, cls).from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # Re-apply freezing to any modules that must stay frozen during training.
        self.freezed_module_patch()
        if inputs_embeds is None:
            # Merge image features with text embeddings and align labels/masks.
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )

        # Note (kentang-mit@): we have a unit test for this function.
        if self.training:
            # Repack variable-length multimodal sequences for efficient training.
            (
                _,
                new_position_ids,
                new_attention_mask,
                _,
                new_inputs_embeds,
                new_labels,
                sorted_seqlens_in_batch,
            ) = self.repack_multimodal_data(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            )
            new_input_ids = None
            past_key_values = None
        else:
            new_attention_mask = attention_mask
            new_position_ids = position_ids
            new_inputs_embeds = inputs_embeds
            new_labels = labels
            sorted_seqlens_in_batch = attention_mask.sum(-1).int()
            new_input_ids = input_ids

        outputs = self.llm.forward(
            input_ids=new_input_ids,
            attention_mask=new_attention_mask,
            position_ids=new_position_ids,
            past_key_values=past_key_values,
            inputs_embeds=new_inputs_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            seqlens_in_batch=sorted_seqlens_in_batch,
        )
        return outputs
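
    # Illustrative training-time call of forward() (a sketch; the dataloader, batch
    # keys, and tensor shapes are assumptions, not part of this file):
    #
    #   out = model(
    #       input_ids=batch["input_ids"],
    #       images=batch["images"],
    #       attention_mask=batch["attention_mask"],
    #       labels=batch["labels"],
    #   )
    #   out.loss.backward()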

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        images: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generation_kwargs,
    ):
        if images is not None:
            # Fold image features into the input embeddings before decoding.
            (
                _,
                _,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids, None, attention_mask, None, None, images
            )
        else:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            inputs_embeds = inputs_embeds.to(self.dtype)

        outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
        return outputs
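
# A minimal sketch of how generate() is typically driven (hypothetical variable
# names; `image_tensor` is assumed to be preprocessed by the vision tower's image
# processor, and the prompt/decoding settings are illustrative only):
#
#   output_ids = model.generate(
#       input_ids=input_ids,                        # tokenized prompt, shape (1, seq_len)
#       images=image_tensor,                        # preprocessed image batch (assumed)
#       attention_mask=torch.ones_like(input_ids),
#       max_new_tokens=256,
#       do_sample=False,
#   )
#   text = tokenizer.decode(output_ids[0], skip_special_tokens=True)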


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
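
# Typical call order (a sketch): skipping default init is safe only because the
# pretrained weights loaded right afterwards overwrite every parameter anyway.
#
#   disable_torch_init()
#   tokenizer, model, image_processor, context_len = load_pretrained_model(...)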


def load_pretrained_model(
    model_path,
    model_name,
    model_base=None,
    load_8bit=False,
    load_4bit=False,
    device_map="auto",
    device="cuda",
    **kwargs,
):
    """Build a LlavaLlamaModel from `model_path` and return
    (tokenizer, model, image_processor, context_len)."""
    kwargs = {"device_map": device_map, **kwargs}
    if device != "cuda":
        kwargs["device_map"] = {"": device}

    if load_8bit:
        kwargs["load_in_8bit"] = True
    elif load_4bit:
        kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.float16

    config = AutoConfig.from_pretrained(model_path)
    config.resume_path = model_path
    prepare_config_for_eval(config, kwargs)

    model = LlavaLlamaModel(
        config=config,
        low_cpu_mem_usage=True,
        **kwargs,
    )
    tokenizer = model.tokenizer
    model.eval()

    # mm_use_im_start_end = getattr(
    #     model.config, "mm_use_im_start_end", False)
    # mm_use_im_patch_token = getattr(
    #     model.config, "mm_use_im_patch_token", True)
    # if mm_use_im_patch_token:
    #     tokenizer.add_tokens(
    #         [DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    # if mm_use_im_start_end:
    #     tokenizer.add_tokens(
    #         [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
    #     )
    model.resize_token_embeddings(len(tokenizer))

    # Move the vision tower, projector, and optional context provider to the target device in fp16.
    vision_tower = model.get_vision_tower()
    vision_tower.to(device=device, dtype=torch.float16)
    mm_projector = model.get_mm_projector()
    mm_projector.to(device=device, dtype=torch.float16)
    context_provider = model.get_context_provider()
    if context_provider is not None:
        context_provider.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    if hasattr(model.llm.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048
    return tokenizer, model, image_processor, context_len
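
# Hedged usage sketch for load_pretrained_model(); the checkpoint path is a
# placeholder and the 4-bit flag simply exercises the BitsAndBytesConfig branch above:
#
#   tokenizer, model, image_processor, context_len = load_pretrained_model(
#       model_path="path/to/describe-anything-checkpoint",    # placeholder
#       model_name="llava_llama",
#       load_4bit=True,                                        # NF4 weights, fp16 compute
#   )
#   inputs = image_processor(images=pil_image, return_tensors="pt")  # pil_image: assumed PIL.Image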


def parse_model_name_or_path(config: PretrainedConfig, model_name="llm", suffix="_cfg"):
    """Resolve a sub-model identifier (e.g. `llm_cfg`, `vision_tower_cfg`) from the top-level config."""
    target_model = f"{model_name}{suffix}"
    target_cfg = getattr(config, target_model, None)

    if isinstance(target_cfg, str):
        return target_cfg
    elif isinstance(target_cfg, dict):
        return target_cfg["architectures"][0]
    else:
        raise ValueError(f"Invalid {target_model} configuration!")


def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
    """Patch a loaded config for inference and adjust `kwargs` in place before model construction."""
    try:
        # compatible with deprecated config convention
        if getattr(config, "vision_tower_cfg", None) is None:
            config.vision_tower_cfg = config.mm_vision_tower
    except AttributeError:
        raise ValueError(
            f"Invalid configuration! Cannot find vision_tower in config:\n{config}")

    # Default to fp16 when `torch_dtype` was not set (e.g. the 4-bit / 8-bit loading paths).
    config.model_dtype = str(kwargs.pop("torch_dtype", torch.float16))

    # siglip does not support device_map = "auto"
    vision_tower_name = parse_model_name_or_path(config, "vision_tower")
    if "siglip" in vision_tower_name.lower():
        kwargs["device_map"] = "cuda"


AutoConfig.register("llava_llama", LlavaLlamaConfig)
AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
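
# With the registrations above, the generic Auto classes can resolve checkpoints whose
# config declares `"model_type": "llava_llama"`. A sketch (the path is a placeholder):
#
#   config = AutoConfig.from_pretrained("path/to/checkpoint")  # -> LlavaLlamaConfig
#   model = AutoModel.from_config(config)                      # -> LlavaLlamaModel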