# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import os
import os.path as osp
import warnings
from abc import ABC
from collections import OrderedDict, defaultdict, deque
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from hydra.utils import instantiate
from transformers import AutoConfig, GenerationConfig, LogitsProcessor, PreTrainedModel
from transformers.modeling_utils import ContextManagers, no_init_weights

from llava.constants import DEFAULT_SOUND_TOKEN, DEFAULT_SPEECH_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
from llava.mm_utils import process_image, process_images, process_sounds, process_sound_masks
from llava.model.configuration_llava import LlavaConfig, ResponseFormat
from llava.model.language_model.builder import build_llm_and_tokenizer
from llava.model.multimodal_encoder.builder import build_sound_tower
from llava.model.multimodal_projector.builder import build_speech_mm_projector, build_sound_mm_projector
from llava.model.utils import get_model_config
from llava.train.sequence_parallel import get_pg_manager
from llava.utils import distributed
from llava.utils.media import extract_media
from llava.utils.tokenizer import tokenize_conversation


class LlavaMetaModel(ABC):
    def _init_llm(self, llm_cfg, config, *args, **kwargs):
        llm, tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        return llm, tokenizer

    def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs):
        # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation.
if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "speech_tower") or hasattr(self, "sound_tower") or hasattr(self, "mm_projector") or hasattr(self, "speech_mm_projector") or hasattr(self, "sound_mm_projector"): # already initialized, skipped return model_dtype = getattr(config, "model_dtype", "torch.float16") if not hasattr(config, "model_dtype"): warnings.warn("model_dtype not found in config, defaulting to torch.float16.") config.model_dtype = model_dtype cfgs = get_model_config(config) print(cfgs) if len(cfgs) == 7: llm_cfg, vision_tower_cfg, speech_tower_cfg,sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg,sound_mm_projector_cfg = cfgs else: raise ValueError("`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config.") self.llm, self.tokenizer = self._init_llm(llm_cfg, config, *args, **kwargs) self.sound_tower = build_sound_tower(sound_tower_cfg, config) self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config) if isinstance(self.config, dict): self.vocab_size = config.llm_cfg["vocab_size"] + NUM_EXTRA_TOKENS else: self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS logging.info( f"[XGrammar] config is not a dict, loading vocab size from tokenizer {self.tokenizer.vocab_size} + {NUM_EXTRA_TOKENS} => {self.vocab_size}" ) # XGrammar tokenizer and grammar compiler # lazy init only when specified json output during inference self.grammar_compiler = None self.encoders = {} for name in ["sound"]: config = getattr(self.config, f"{name}_encoder") if isinstance(config, str): config = json.loads(config) self.encoders[name] = instantiate(config, parent=self) self.post_config() self.is_loaded = True assert ( self.llm is not None or self.vision_tower is not None or self.speech_tower is not None or self.mm_projector is not None or self.speech_mm_projector is not None ), "At least one of the components must be instantiated." 
    @classmethod
    def load_from_config(cls, model_path_or_config, *args, **kwargs):
        pass

    ## FIXME we will use this function to load the model in the future
    @classmethod
    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
        kwargs.pop("config", None)

        if isinstance(model_path_or_config, str):
            config = AutoConfig.from_pretrained(model_path_or_config)
        elif isinstance(model_path_or_config, LlavaConfig):
            config = model_path_or_config
        else:
            raise NotImplementedError(
                f"wrong type: {type(model_path_or_config)}, "
                f"isinstance of LlavaConfig: {isinstance(model_path_or_config, LlavaConfig)}"
            )

        model_dtype = getattr(config, "model_dtype", "torch.float16")
        if not hasattr(config, "model_dtype"):
            warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
            config.model_dtype = model_dtype

        cfgs = get_model_config(config)
        if len(cfgs) == 7:
            (
                llm_cfg,
                vision_tower_cfg,
                speech_tower_cfg,
                sound_tower_cfg,
                mm_projector_cfg,
                speech_mm_projector_cfg,
                sound_mm_projector_cfg,
            ) = cfgs
        else:
            raise ValueError(
                "`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` "
                "`vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config."
            )

        init_context = [
            no_init_weights(_enable=True),
        ]

        with ContextManagers(init_context):
            vlm = cls(config, *args, **kwargs)

        if (
            hasattr(vlm, "llm")
            or hasattr(vlm, "vision_tower")
            or hasattr(vlm, "speech_tower")
            or hasattr(vlm, "sound_tower")
            or hasattr(vlm, "mm_projector")
            or hasattr(vlm, "speech_mm_projector")
            or hasattr(vlm, "sound_mm_projector")
        ):
            if vlm.is_loaded:
                return vlm

        vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
        vlm.sound_tower = build_sound_tower(sound_tower_cfg, config)
        vlm.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)

        # `self` is not defined in a classmethod; configure the freshly built instance instead.
        vlm.post_config()
        vlm.is_loaded = True

        # FIXME(ligeng, yunhao): llm should never be None here.
        assert (
            vlm.llm is not None
            or vlm.vision_tower is not None
            or vlm.speech_tower is not None
            or vlm.mm_projector is not None
            or vlm.speech_mm_projector is not None
        ), "At least one of the components must be instantiated."
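        # Usage sketch (illustrative only): `load_pretrained` is meant to be called on a concrete
        # subclass of this meta class; the checkpoint path and class name below are hypothetical.
        #
        #   config = AutoConfig.from_pretrained("path/to/checkpoint")  # placeholder path
        #   vlm = ConcreteLlavaModel.load_pretrained(config)           # hypothetical subclass
        #   vlm.eval()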
        return vlm

    ## FIXME we will use this function to save the model in the future
    def save_pretrained(self, output_dir, state_dict=None):
        if state_dict is None:
            # otherwise fetch from deepspeed
            # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
            state_dict = self.state_dict()

        if getattr(self, "tokenizer", None):
            self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))

        if self.get_llm():
            print(f"saving llm to {osp.join(output_dir, 'llm')}")
            self.llm.config._name_or_path = osp.join(output_dir, "llm")
            llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
            self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
            self.config.llm_cfg = self.llm.config

        if self.get_sound_tower():
            print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}")
            self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower")
            sound_tower_state_dict = OrderedDict(
                {k.split("sound_tower.sound_tower.")[-1]: v for k, v in state_dict.items() if "sound_tower" in k}
            )
            self.sound_tower.sound_tower.save_pretrained(
                os.path.join(output_dir, "sound_tower"),
                state_dict=sound_tower_state_dict,
            )
            self.config.sound_tower_cfg = self.sound_tower.config

        if self.get_sound_mm_projector():
            print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}")
            self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector")
            sound_mm_projector_state_dict = OrderedDict(
                {k.split("sound_mm_projector.")[-1]: v for k, v in state_dict.items() if "sound_mm_projector" in k}
            )
            self.sound_mm_projector.save_pretrained(
                os.path.join(output_dir, "sound_mm_projector"),
                state_dict=sound_mm_projector_state_dict,
            )
            self.config.sound_mm_projector_cfg = self.sound_mm_projector.config

        ## update and save top-level config
        self.config._name_or_path = output_dir
        self.config.architectures = [self.__class__.__name__]
        self.config.save_pretrained(output_dir)

    def get_llm(self):
        llm = getattr(self, "llm", None)
        if type(llm) is list:
            llm = llm[0]
        return llm

    def get_lm_head(self):
        lm_head = getattr(self.get_llm(), "lm_head", None)
        return lm_head

    def get_sound_tower(self):
        sound_tower = getattr(self, "sound_tower", None)
        if type(sound_tower) is list:
            sound_tower = sound_tower[0]
        return sound_tower

    def get_sound_mm_projector(self):
        sound_mm_projector = getattr(self, "sound_mm_projector", None)
        if type(sound_mm_projector) is list:
            sound_mm_projector = sound_mm_projector[0]
        return sound_mm_projector

    def post_config(self):
        self.training = self.get_llm().training
        ## configuration
        if getattr(self.config, "llm_cfg", None) is None:
            self.config.llm_cfg = self.llm.config
        # speech_tower is not built in this sound-only variant; guard the assignment to avoid an AttributeError
        if getattr(self, "speech_tower", None) is not None:
            self.config.speech_tower_cfg = self.speech_tower.config
        if getattr(self.config, "sound_tower_cfg", None) is None:
            self.config.sound_tower_cfg = self.sound_tower.config
        if getattr(self.config, "sound_mm_projector_cfg", None) is None:
            self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
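    # Expected on-disk layout after `save_pretrained(output_dir)`, sketched from the calls above
    # (the exact files inside each subfolder come from the underlying HF `save_pretrained` methods):
    #
    #   output_dir/
    #       config.json            # top-level LlavaConfig with the updated sub-config paths
    #       llm/                   # language model weights + tokenizer
    #       sound_tower/           # audio encoder weights
    #       sound_mm_projector/    # audio-to-LLM projector weights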
""" if self.training: if self.get_llm() and not getattr(self.config, "tune_language_model", False): pass if self.get_sound_tower() and not getattr(self.config, "tune_sound_tower", False): self.get_sound_tower().eval() if self.get_sound_mm_projector() and not getattr(self.config, "tune_sound_mm_projector", False): self.get_sound_mm_projector().eval() def encode_sound(self, sounds, masks=None): sound_features = self.get_sound_tower()(sounds, masks) sound_features = self.get_sound_mm_projector()(sound_features) return sound_features ## @yunhao: is there a better way to handle function call and attributes for llm? ## support beam search def _temporary_reorder_cache(self, past_key_values, sorted_idx): return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx) def get_input_embeddings(self): return self.get_llm().get_input_embeddings() def get_output_embeddings(self): return self.get_llm().get_output_embeddings() def resize_token_embeddings(self, embed_size): self.get_llm().resize_token_embeddings(embed_size) class LlavaMetaForCausalLM(ABC): def _embed( self, input_ids: torch.Tensor, media: Dict[str, List[torch.Tensor]], media_config: Dict[str, Dict[str, Any]], labels: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor], media_meta: Dict[str, Dict[str, Any]]= None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX) attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool) PROCESS_GROUP_MANAGER = get_pg_manager() if PROCESS_GROUP_MANAGER is not None: for name in media: self.encoders[name].end_tokens = None # Extract text and media embeddings text_embeds = self.llm.model.embed_tokens(input_ids) media_embeds = self.__embed_media_tokens(media, media_config, media_meta) # This is a workaround to make sure the dummy embeddings are consumed while media_embeds.get("dummy"): dummy_embed = media_embeds["dummy"].popleft() text_embeds += torch.sum(dummy_embed) * 0 # Remove padding batch_size = labels.shape[0] # Build inverse mapping from token ID to media name media_tokens = {} for name, token_id in self.tokenizer.media_token_ids.items(): media_tokens[token_id] = name # -------------------------------- # num_audio_tokens = torch.stack(media_meta["sound_embed_masks"], dim=0).sum(-1) num_audio_tokens = torch.tensor([round(int(x) / 10) * 10 for x in num_audio_tokens]) num_audios = len(media_embeds['sound']) # length of queue is the number of audios we have in total max_audio_tokens, embed_dim = media_embeds['sound'][0].shape audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to( num_audio_tokens.device ) < num_audio_tokens.unsqueeze(1) audio_embeds = [] while media_embeds['sound']: audio_embeds.append(media_embeds['sound'].popleft()) audio_embeds = torch.stack(audio_embeds,dim=0) masked_audio_features = audio_embeds[audio_features_mask].view(-1, embed_dim) batch_size, sequence_length = input_ids.shape _left_padding = torch.any(attention_mask[:, 0] == 0) _right_padding = torch.any(attention_mask[:, -1] == 0) left_padding = True if batch_size > 1: if _left_padding and not _right_padding: left_padding = True elif not _left_padding and _right_padding: left_padding = False elif not _left_padding and not _right_padding: # both side is 1, so cannot tell left_padding = self.tokenizer.padding_side == "left" else: # invalid attention_mask raise ValueError(f"both side of attention_mask has zero, invalid. 
{attention_mask}") # 1. Create a mask to know where special audio tokens are special_audio_token_mask = input_ids == self.tokenizer.media_token_ids['sound'] #hard coded to just work with 'sound' num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1) # In case the Audio model or the Language model has been offloaded to CPU, we need to manually # set the corresponding tensors into their correct target device. target_device = text_embeds.device attention_mask = attention_mask.to(target_device) input_ids = input_ids.to(target_device) num_audio_tokens = num_audio_tokens.to(target_device) batch_indices, non_audio_indices = torch.where( (input_ids != self.tokenizer.media_token_ids['sound']) & (attention_mask == 1) ) # 2. Compute the positions where text should be written # Calculate new positions for text tokens in merged audio-text sequence. # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens. # `torch.cumsum` computes how each audio token shifts subsequent text token positions. token_placeholder_num = torch.zeros_like(input_ids) token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1 token_placeholder_num = token_placeholder_num + 1 new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1 max_token_num = token_placeholder_num.sum(-1).max() nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1] if left_padding: new_token_positions += nb_audio_pad[:, None] # offset for left padding text_to_overwrite = new_token_positions[batch_indices, non_audio_indices] batch_indices, non_audio_indices, text_to_overwrite = ( batch_indices.to(target_device), non_audio_indices.to(target_device), text_to_overwrite.to(target_device), ) # 3. Create the full embedding, already padded to the maximum position final_embedding = torch.zeros( batch_size, max_token_num, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device ) final_attention_mask = torch.zeros( batch_size, max_token_num, dtype=attention_mask.dtype, device=text_embeds.device ) final_input_ids = torch.full( (batch_size, max_token_num), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=text_embeds.device ) # 4. Fill the embeddings based on the mask. If we have ["hey" "