# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is modified from https://github.com/haotian-liu/LLaVA/

import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import torch
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from llava.model.loss import soft_cross_entropy
from llava.model.utils.packing import set_seqlens_in_batch
from llava.train.sequence_parallel.globals import get_pg_manager
from llava.utils.logging import logger

from ...train.utils import calculate_loss_weight
from ..configuration_llava import LlavaConfig
from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel


class LlavaLlamaConfig(LlavaConfig):
    model_type = "llava_llama"


# FIXME: we will follow the convention to add a new class for CausalLM in the future
class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
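    """Vision-language model that couples a vision encoder with a Llama-based LLM.

    The multimodal plumbing (vision tower, projector, media embedding) comes from
    the LlavaMetaModel / LlavaMetaForCausalLM mixins; this class wires it into a
    single PreTrainedModel so it can be saved, loaded, and used via Auto classes.
    """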
config_class = LlavaLlamaConfig
main_input_name = "input_embeds"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True

    def __init__(self, config: Optional[LlavaLlamaConfig] = None, *args, **kwargs) -> None:
        super().__init__(config)
        # Build the vision tower, multimodal projector, and LLM from the config.
        self.init_vlm(config=config, *args, **kwargs)

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
        use_safetensors: Optional[bool] = None,
**kwargs,
):
        # Defer to the class's custom loader when one is defined.
        if hasattr(cls, "load_pretrained"):
return cls.load_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)
        return super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)

    def forward(
self,
input_ids: torch.LongTensor = None,
media: Optional[Dict[str, List[torch.Tensor]]] = None,
images: Optional[torch.FloatTensor] = None,
media_config: Optional[List] = None,
attention_mask: Optional[torch.Tensor] = None,
media_meta: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
packing: bool = True,
force_packing: bool = False,
seqlens_in_batch: Optional[torch.LongTensor] = None,
dpo_forward: bool = False,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
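        """Run one multimodal forward pass.

        Text tokens and media are embedded jointly, training batches are
        optionally repacked to strip padding, and the result is delegated to the
        wrapped LLM. When dpo_forward is set, returns (logits, labels) instead of
        the usual CausalLMOutputWithPast.
        """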
        # Keep modules that were frozen in eval mode (HF re-enables train mode on every step).
        self.freezed_module_patch()
if images is not None:
if media is not None:
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
logger.warning("The 'images' argument is deprecated. Please use 'media' instead.")
media = {"image": images}
if media_config is None:
media_config = defaultdict(dict)
if inputs_embeds is None:
            inputs_embeds, labels, attention_mask = self._embed(
                input_ids, media, media_config, labels, attention_mask, media_meta
            )
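        # Sequence packing: during training, concatenate variable-length samples
        # into dense sequences so attention is not wasted on padding tokens.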
if force_packing or (packing and self.training and not dpo_forward):
if seqlens_in_batch is None:
seqlens_in_batch = torch.sum(attention_mask, dim=1)
set_seqlens_in_batch(seqlens_in_batch)
(inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
inputs_embeds, attention_mask, position_ids, labels
)
outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
labels=labels,
**kwargs,
)
        # Time tokens are supervised with a soft cross-entropy (spread controlled by
        # config.soft_ce_std) rather than hard one-hot targets.
        if self.training and getattr(self.config, "time_token_ids", []):
outputs.loss = soft_cross_entropy(
outputs.logits,
labels,
soft_tokens=self.config.time_token_ids,
std=self.config.soft_ce_std,
)
        # Rescale the loss under sequence parallelism (SP) so each rank's
        # contribution reflects its share of the supervised tokens.
        if get_pg_manager() is not None:
            loss_weight = calculate_loss_weight(labels)
            outputs.loss = outputs.loss * loss_weight
        if dpo_forward:
            # DPO training computes its own loss from the raw logits.
            return outputs.logits, labels
return outputs


AutoConfig.register("llava_llama", LlavaLlamaConfig)
AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
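
# A minimal loading sketch (illustrative; the checkpoint path is a placeholder):
#
#   from transformers import AutoConfig, AutoModel
#
#   config = AutoConfig.from_pretrained("/path/to/llava_llama_checkpoint")
#   model = AutoModel.from_config(config)  # resolves to LlavaLlamaModel via the registrations above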