# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import List, Optional, Tuple, Union

import torch
import torchaudio
from torch import nn
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    Cache,
    Gemma3Config,
    PreTrainedModel,
    PretrainedConfig,
    StaticCache,
    HybridCache,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3CausalLMOutputWithPast,
    Gemma3ForConditionalGeneration,
    Gemma3RMSNorm,
)
from transformers.utils import is_torchdynamo_compiling, logging

from .speech_conformer_encoder import ConformerEncoder

logger = logging.get_logger(__name__)


class Gemma3AudioProjectorConfig(PretrainedConfig):
    model_type = "gemma3_audio"

    def __init__(
        self,
        hidden_size: int = 1024,
        num_hidden_layers: int = 24,
        sample_rate: int = 16_000,
        n_mels: int = 80,
        audio_token_id: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.audio_token_id = audio_token_id


class Gemma3AudioProjector(PreTrainedModel):
    """Conformer-based audio encoder whose output is projected to the LM hidden size."""

    config_class = Gemma3AudioProjectorConfig
    base_model_prefix = "audio_projector"

    def __init__(self, config: Gemma3AudioProjectorConfig):
        super().__init__(config)
        # encoder_config = config.audio_processor.get("config", None)
        encoder_config = {
            "activation": "swish",
            "activation_checkpointing": {
                "interval": 1,
                "module": "transformer",
                "offload": False,
            },
            "attention_dim": 1024,
            "attention_heads": 16,
            "batch_norm": False,
            "bias_in_glu": True,
            "causal": True,
            "chunk_size": -1,
            "cnn_layer_norm": True,
            "conv_activation": "swish",
            "conv_glu_type": "swish",
            "depthwise_multiplier": 1,
            "depthwise_seperable_out_channel": 1024,
            "dropout_rate": 0.0,
            "encoder_embedding_config": {"input_size": 80},
            "ext_pw_kernel_size": 1,
            "ext_pw_out_channel": 1024,
            "input_layer": "nemo_conv",
            "input_size": 80,
            "kernel_size": 3,
            "left_chunk": 18,
            "linear_units": 1536,
            "nemo_conv_settings": {"conv_channels": 1024},
            "num_blocks": 24,
            "relative_attention_bias_args": {
                "t5_bias_max_distance": 500,
                "type": "t5",
            },
            "time_reduction": 8,
        }
        self.encoder = ConformerEncoder(**encoder_config)
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate, n_mels=config.n_mels
        )
        self.proj = nn.Linear(1024, config.hidden_size, bias=False)
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.post_init()

    # ---------- helpers ----------
    def wav2mel(self, wav: torch.Tensor) -> torch.Tensor:
        # (B, T) waveform -> (B, frames, n_mels) log-mel features
        return self.mel(wav).clamp(min=1e-5).log().transpose(1, 2)

    # ---------- forward ----------
    @torch.no_grad()
    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Accept (B, T) or (B, 1, T) waveforms
        if wav.dim() == 3:
            wav = wav.squeeze(1)
        mel = self.wav2mel(wav)
        lengths = torch.full(
            (mel.size(0),), mel.size(1), dtype=torch.long, device=mel.device
        )
        hidden = self.encoder(mel, lengths)
        hidden = self.proj(hidden)
        return self.layer_norm(hidden)


# ──────────────────────────────────────────────────────────────────────────────
# Vision projector (identical to the original; only the dtype handling changes)
# ──────────────────────────────────────────────────────────────────────────────
class Gemma3VisionProjector(nn.Module):
    def __init__(self, config: Gemma3Config):
        super().__init__()
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
        )
        self.mm_soft_emb_norm = Gemma3RMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )
        self.patches_per_image = config.vision_config.image_size // config.vision_config.patch_size
        self.tokens_per_side = int(config.mm_tokens_per_image ** 0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        b, _, seq_len = vision_outputs.shape
        x = vision_outputs.transpose(1, 2).reshape(
            b, seq_len, self.patches_per_image, self.patches_per_image
        )
        x = self.avg_pool(x).flatten(2).transpose(1, 2)
        x = self.mm_soft_emb_norm(x)
        return torch.matmul(x, self.mm_input_projection_weight).type_as(vision_outputs)


# ──────────────────────────────────────────────────────────────────────────────
# Gemma-3 multimodal wrapper
# ──────────────────────────────────────────────────────────────────────────────
class Gemma3OmniForConditionalGeneration(Gemma3ForConditionalGeneration):
    """Gemma-3 Omni: vision + audio + text causal LM."""

    def __init__(self, config: Gemma3Config):
        super().__init__(config)

        # ---- sub-modules
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3VisionProjector(config)
        self.audio_projector = Gemma3AudioProjector(
            Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size)
        )
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModelForCausalLM.from_config(config=config.text_config)
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
        self.language_model = language_model

        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )
        self.post_init()

    # ---------- helper ----------
    def get_audio_features(self, audio_values: torch.Tensor) -> torch.Tensor:
        return self.audio_projector(audio_values)

    def _update_causal_mask(
        self,
        attention_mask,
        token_type_ids,
        past_key_values,
        cache_position,
        input_tensor,
        is_training: bool = False,
    ):
        if self.config.text_config._attn_implementation == "flash_attention_2":
            return attention_mask

        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted
            # form and requires no inversion or slicing.
            return attention_mask

        using_static_cache = isinstance(past_key_values, StaticCache)
        min_dtype = torch.finfo(self.dtype).min
        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[0] + sequence_length + 1
            )

        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=self.dtype,
            device=cache_position.device,
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix.
        # Training-specific attention for the prefix is handled below.
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)

        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)

        # Apply a bidirectional mask on image spans when token type ids are provided
        if token_type_ids is not None and sequence_length != 1:
            token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
            token_type_mask[token_type_ids == 0] = False  # text tokens are left unchanged
            token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
            causal_mask = causal_mask.clone()
            causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
                token_type_mask, 0.0
            )

        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]

            # Then apply the padding mask (masks out pad tokens)
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        return causal_mask

    # ---------- forward ----------
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        audio_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
        # === input validation ===
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Exactly one of input_ids or inputs_embeds must be provided")

        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        is_training = token_type_ids is not None and labels is not None

        # Out-of-vocabulary image token → replace with a pad id for the embedding lookup
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        # cache_position
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # === merge image features ===
        if pixel_values is not None:
            image_feat = self.get_image_features(pixel_values)

            special_image_mask = (
                (
                    inputs_embeds
                    == self.get_input_embeddings()(
                        torch.tensor(self.config.image_token_id, device=inputs_embeds.device)
                    )
                )
                if input_ids is None
                else (
                    input_ids == self.config.image_token_id
                ).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            )
            if (
                not is_torchdynamo_compiling()
                and inputs_embeds[special_image_mask].numel() != image_feat.numel()
            ):
                raise ValueError("#image tokens ≠ #embedding slots")
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_feat.to(inputs_embeds)
            )

        # === merge audio features ===
        if audio_values is not None:
            audio_feat = self.get_audio_features(audio_values)
            # special_audio_mask = (
            #     (
            #         inputs_embeds
            #         == self.get_input_embeddings()(
            #             torch.tensor(self.config.audio_token_id, device=inputs_embeds.device)
            #         )
            #     )
            #     if input_ids is None
            #     else (
            #         input_ids == self.config.audio_token_id
            #     ).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            # )
            # if (
            #     not is_torchdynamo_compiling()
            #     and inputs_embeds[special_audio_mask].numel() != audio_feat.numel()
            # ):
            #     raise ValueError("#audio tokens ≠ #embedding slots")
            # inputs_embeds = inputs_embeds.masked_scatter(
            #     special_audio_mask, audio_feat.to(inputs_embeds)
            # )
            logger.debug(
                "audio_feat shape: %s, inputs_embeds shape: %s",
                audio_feat.shape,
                inputs_embeds.shape,
            )
            # For now, audio embeddings are prepended to the text sequence instead of being
            # scattered into dedicated audio-token slots (see the commented-out path above).
            inputs_embeds = torch.cat([audio_feat.to(inputs_embeds), inputs_embeds], dim=1)

        # === label masking ===
        if labels is not None and self.pad_token_id in labels:
            logger.warning_once(
                "`labels` contains `pad_token_id`; those positions will be masked out of the loss."
            )
            labels = torch.where(
                input_ids == self.pad_token_id, self.config.ignore_index, labels
            )

        causal_mask = self._update_causal_mask(
            attention_mask,
            token_type_ids,
            past_key_values,
            cache_position,
            inputs_embeds,
            is_training,
        )

        outputs: CausalLMOutputWithPast = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **lm_kwargs,
        )

        # === loss ===
        logits = outputs.logits
        loss = None
        if labels is not None:
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(
                    logits.device
                )
                shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask != 0].contiguous()
            else:
                # Slicing can leave these tensors non-contiguous, which would break .view() below.
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            loss = nn.CrossEntropyLoss()(
                shift_logits.view(-1, self.config.text_config.vocab_size),
                shift_labels.view(-1),
            )

        return Gemma3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_feat if pixel_values is not None else None,
        )


# ──────────────────────────────────────────────────────────────────────────────
# exports
# ──────────────────────────────────────────────────────────────────────────────
__all__ = [
    "Gemma3AudioProjectorConfig",
    "Gemma3AudioProjector",
    "Gemma3VisionProjector",
    "Gemma3OmniForConditionalGeneration",
]
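

# ──────────────────────────────────────────────────────────────────────────────
# usage sketch (illustrative)
# ──────────────────────────────────────────────────────────────────────────────
# Minimal text-only smoke test, kept as a hedged sketch rather than a definitive
# recipe: the "google/gemma-3-4b-it" id (used only to fetch a config) and the
# layer-count overrides are assumptions made to keep the randomly initialised
# model small; they are not part of this module.
if __name__ == "__main__":
    cfg = Gemma3Config.from_pretrained("google/gemma-3-4b-it")  # assumed checkpoint id
    cfg.text_config.num_hidden_layers = 2    # shrink for a quick local check
    cfg.vision_config.num_hidden_layers = 2  # (weights stay randomly initialised)

    model = Gemma3OmniForConditionalGeneration(cfg).eval()
    input_ids = torch.randint(0, cfg.text_config.vocab_size, (1, 8))

    with torch.no_grad():
        out = model(input_ids=input_ids)
    # Expect logits of shape (1, 8, vocab_size) from the text-only path.
    print(out.logits.shape)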