import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union, Dict, Any

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import ModelOutput, logging
from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel

from audio_transformer import AudioTransformer

logger = logging.get_logger(__name__)
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L43
class LayerNorm(torch.nn.LayerNorm):
    """Layer norm with transpose"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = input.transpose(-2, -1)
        x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.transpose(-2, -1)
        return x
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L53
class ConvLayerBlock(torch.nn.Module):
    """Convolution unit of FeatureExtractor"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool,
        layer_norm: Optional[torch.nn.Module],
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.layer_norm = layer_norm
        self.conv = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor): Shape: ``[batch, in_channels, in_frames]``.
        Returns:
            Tensor: Shape ``[batch, out_channels, out_frames]``.
        """
        x = self.conv(x)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        x = torch.nn.functional.gelu(x)
        return x
# Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L146
class FeatureProjection(torch.nn.Module):
    """Layer that connects FeatureExtractor and Encoder

    Projects features to encoder dimension.

    Args:
        in_features (int): Input feature dim.
        out_features (int): Output feature dim.
        dropout (float): Dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dropout=0.1,
    ):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(in_features)
        self.projection = torch.nn.Linear(
            in_features,
            out_features,
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (Tensor):
                Feature Tensor. shape: ``[batch, frame, in_feature]``
        Returns:
            Tensor: Projected features. ``[batch, frame, out_feature]``.
        """
        x = self.layer_norm(x)
        x = self.projection(x)
        x = self.dropout(x)
        return x
# Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
class FeatureExtractor(torch.nn.Module):
    """Extract features from raw audio waveforms.

    Args:
        shapes (List[Tuple[int, int, int]]): ``(out_channels, kernel_size, stride)`` for each conv layer.
        bias (bool): Whether the conv layers use a bias term.
        norm_mode (str): Either ``"group_norm"`` or ``"layer_norm"``.
    """

    def __init__(
        self,
        shapes=[(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)],
        bias=False,
        norm_mode="group_norm",
    ):
        super().__init__()

        if norm_mode not in ["group_norm", "layer_norm"]:
            raise ValueError("Invalid norm mode")

        blocks = []
        in_channels = 1
        for i, (out_channels, kernel_size, stride) in enumerate(shapes):
            normalization = None
            if norm_mode == "group_norm" and i == 0:
                normalization = torch.nn.GroupNorm(
                    num_groups=out_channels,
                    num_channels=out_channels,
                    affine=True,
                )
            elif norm_mode == "layer_norm":
                normalization = LayerNorm(
                    normalized_shape=out_channels,
                    elementwise_affine=True,
                )
            blocks.append(
                ConvLayerBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    bias=bias,
                    layer_norm=normalization,
                )
            )
            in_channels = out_channels
        self.conv_layers = torch.nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor):
                Input Tensor representing a batch of audio, shape: ``[batch, time]``.
        Returns:
            Tensor:
                The resulting feature, shape: ``[batch, frame, feature]``.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")

        x = x.unsqueeze(1)  # (batch, channel==1, frame)
        for layer in self.conv_layers:
            x = layer(x)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        return x
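
# Illustrative shape check (a minimal sketch, not part of the model): with the default
# strides (5, 2, 2, 2, 2, 2, 2) the conv stack downsamples the waveform by a factor of
# 5 * 2**6 = 320, so 16 kHz audio yields roughly 50 feature frames per second.
#
#     fe = FeatureExtractor()
#     wav = torch.randn(2, 16000)   # 1 second of 16 kHz audio, batch of 2
#     feats = fe(wav)               # -> torch.Size([2, 49, 512])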
# Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
class FeatureExtractorAdapter(torch.nn.Module):
    """Downsample extracted audio features and project them to the LLM hidden size.

    Args:
        shapes (Tuple[int, int, int, int]): ``(in_channels, out_channels, kernel_size, stride)`` of the conv layer.
        hidden_size (int): Output (LLM) hidden dimension.
        bias (bool): Whether the conv layer uses a bias term.
        norm_mode (str): Either ``"group_norm"`` or ``"layer_norm"``.
    """

    def __init__(
        self,
        shapes=(512, 512, 2, 2),
        hidden_size=2048,
        bias=False,
        norm_mode="group_norm",
    ):
        super().__init__()

        if norm_mode not in ["group_norm", "layer_norm"]:
            raise ValueError("Invalid norm mode")

        in_channels, out_channels, kernel_size, stride = shapes
        normalization = LayerNorm(
            normalized_shape=out_channels,
            elementwise_affine=True,
        )
        self.conv_layers = ConvLayerBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
            layer_norm=normalization,
        )
        self.feat_proj = FeatureProjection(out_channels, hidden_size)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (Tensor):
                Input features, shape: ``[batch, frame, feature]``.
        Returns:
            Tensor:
                The projected features, shape: ``[batch, frame, hidden_size]``.
        """
        x = x.transpose(1, 2)  # (batch, feature, frame)
        x = self.conv_layers(x)  # (batch, feature, frame)
        x = x.transpose(1, 2)  # (batch, frame, feature)
        x = self.feat_proj(x)
        return x
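
# Illustrative pipeline sketch (assumptions: 16 kHz input and the default shapes above):
# the adapter halves the frame rate of the FeatureExtractor output (stride 2) and
# projects it to the LLM hidden size; this is how VoilaAudioAlphaModel turns raw
# waveforms into embeddings it can add onto the token embeddings.
#
#     fe = FeatureExtractor()                              # (batch, time) -> (batch, frames, 512)
#     adapter = FeatureExtractorAdapter(hidden_size=2048)  # (batch, frames, 512) -> (batch, frames // 2, 2048)
#     audio_embeds = adapter(fe(torch.randn(1, 16000)))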
@dataclass
class VoilaOutput(ModelOutput):
    """
    Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_outputs.py#L678

    Base class for Voila outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            The hidden state of the last attention layer.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        voila_pred (`torch.FloatTensor` of shape `(batch_size, sequence_length, 2)`, *optional*):
            Logits of the Voila prediction head, used by the autonomous model to decide when the assistant should
            start speaking.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    voila_pred: Optional[torch.FloatTensor] = None
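
# How these outputs are consumed (illustrative sketch; mirrors the generation loops
# further below): the text head is read from `logits`, while `last_hidden_state` is fed
# to the audio transformer to sample audio codebook tokens.
#
#     out = model(input_ids=ids, return_dict=True)
#     next_text_logits = out.logits[:, -1, :]
#     audio_tokens = model.audio_transformer.inference(
#         out.last_hidden_state, temperature=0.2, top_k=50,
#     )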
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
class VoilaModel(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pad_vocab_size_multiple = 64

        self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
        self.audio_transformer = AudioTransformer(config, use_sdpa=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        audio_labels: Optional[torch.LongTensor] = None,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, VoilaOutput]:
        r"""
        Args:
            input_ids: [bs, seq_len, num_codebooks]
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            assert len(inputs_embeds.shape) == 4
        if len(inputs_embeds.shape) == 4:
            inputs_embeds = inputs_embeds.mean(dim=2)

        if self.training or \
            (past_key_values is None and ref_embs is not None) or \
            (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
            ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
            ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
            # (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
            padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
            ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
            inputs_embeds = inputs_embeds + ref_embs

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return VoilaOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=hidden_states,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    def _prepare_inputs_for_generation(
        self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None and past_key_values.get_seq_length() > 0:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_cache_shape()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is None and \
            (past_key_values is None or past_key_values.get_seq_length() <= 0):
            inputs_embeds = self.model.embed_tokens(input_ids)
        if inputs_embeds is not None and \
            (past_key_values is None or past_key_values.get_seq_length() <= 0):
            model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
        else:
            model_inputs = {"input_ids": input_ids, "ref_embs": None}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        num_new_token: int = 1,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = outputs.past_key_values

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
            )
        return model_kwargs

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[Union[int, List[int]]],
    ) -> torch.LongTensor:
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)

        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
    def run_generate(
        self,
        input_ids: torch.LongTensor,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = 128,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        streamer: Optional["BaseStreamer"] = None,
        llm_audio_token_id: Optional[int] = None,
        min_audio_token_id: Optional[int] = None,
        temperature=0.2,
        top_k=50,
        audio_temperature=0.2,
        audio_top_k=50,
    ):
        assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
        assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
        assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"

        eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        # Extend input_ids with an additional num_codebooks dim
        if len(input_ids.shape) == 2:
            input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)

        this_peer_finished = False  # used by synced_gpus only
        max_length = input_ids.shape[1] + max_new_tokens
        model_kwargs = {
            "use_cache": True,
            "past_key_values": DynamicCache(),
            "attention_mask": self._prepare_attention_mask_for_generation(
                input_ids, pad_token_id, eos_token_id
            ),
        }
        # auto-regressive generation
        while True:
            # prepare model inputs
            model_inputs = self._prepare_inputs_for_generation(
                input_ids,
                ref_embs=ref_embs,
                ref_embs_mask=ref_embs_mask,
                **model_kwargs
            )

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            audio_tokens = self.audio_transformer.inference(
                outputs.last_hidden_state,
                temperature=audio_temperature,
                top_k=audio_top_k,
            )
            audio_tokens = torch.stack(
                [
                    audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                    for ci in range(self.config.num_codebooks)
                ],
                dim=2,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            # Apply temperature and top-k
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            if top_k > 0:
                top_k = min(top_k, next_token_logits.size(-1))  # Safety check
                # Remove all tokens with a probability less than the last token of the top-k
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))

            # sample
            probs = nn.functional.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # Append NUM_CODEBOOK text tokens or audio_tokens
            if len(next_tokens.shape) == 1:
                next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
            next_tokens = torch.where(next_tokens == llm_audio_token_id, audio_tokens, next_tokens)
            input_ids = torch.cat([input_ids, next_tokens], dim=1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs
            )

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
                )
                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            if input_ids.shape[1] >= max_length:
                this_peer_finished = True

            if this_peer_finished:
                break

        if streamer is not None:
            streamer.end()

        return input_ids
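
# Usage sketch (illustrative only; the checkpoint handling, tokenizer and token ids
# below are assumptions that depend on the Voila tokenizer/config, not on this file):
#
#     model = VoilaModel.from_pretrained(checkpoint).eval().cuda()
#     out_ids = model.run_generate(
#         input_ids,                              # [batch, seq_len] or [batch, seq_len, num_codebooks]
#         ref_embs=ref_embs,                      # speaker reference embedding(s), dim 256
#         ref_embs_mask=ref_embs_mask,
#         max_new_tokens=512,
#         pad_token_id=tokenizer.pad_token_id,
#         eos_token_id=tokenizer.eos_token_id,
#         llm_audio_token_id=llm_audio_token_id,  # placeholder id of the audio token
#         min_audio_token_id=min_audio_token_id,  # first id of the audio-token range
#     )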
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
class VoilaAudioAlphaModel(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pad_vocab_size_multiple = 64

        self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
        self.audio_transformer = AudioTransformer(config, use_sdpa=False)
        self.feature_extractor = FeatureExtractor()
        self.audio_feature_extractor_adapter = FeatureExtractorAdapter(hidden_size=config.hidden_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        audio_labels: Optional[torch.LongTensor] = None,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        audio_datas: Optional[torch.FloatTensor] = None,
        audio_data_masks: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, VoilaOutput]:
        r"""
        Args:
            input_ids: [bs, seq_len, num_codebooks]
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            assert len(inputs_embeds.shape) == 4
        if len(inputs_embeds.shape) == 4:
            inputs_embeds = inputs_embeds.mean(dim=2)

        if self.training or \
            (past_key_values is None and ref_embs is not None) or \
            (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
            ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
            ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
            # (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
            padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
            ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
            inputs_embeds = inputs_embeds + ref_embs

        if self.training or audio_datas is not None:
            audio_embeds = self.feature_extractor(audio_datas)
            audio_embeds = self.audio_feature_extractor_adapter(audio_embeds)
            audio_embeds = audio_embeds * audio_data_masks[..., None]
            inputs_embeds = inputs_embeds + audio_embeds

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # We shift tokens and labels in dataloader
            shift_logits = logits.contiguous()
            shift_labels = labels.contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

            if audio_labels is not None:
                au_mask = (audio_labels >= 0).all(dim=-1)
                au_hidden_states = hidden_states[au_mask]
                au_audio_labels = audio_labels[au_mask]
                if len(au_hidden_states) <= 0:
                    au_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
                    au_audio_labels = torch.zeros_like(audio_labels).reshape(-1, self.config.num_codebooks)
                    loss_weight = 0.0
                else:
                    loss_weight = 1.0
                au_logits = self.audio_transformer(au_hidden_states, au_audio_labels)
                # We shift tokens and labels in dataloader
                shift_au_logits = au_logits.contiguous()
                shift_audio_labels = au_audio_labels.contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_au_logits = shift_au_logits.view(-1, self.config.codebook_size)
                shift_audio_labels = shift_audio_labels.view(-1)
                # Enable model parallelism
                shift_audio_labels = shift_audio_labels.to(shift_au_logits.device)
                au_loss = loss_fct(shift_au_logits, shift_audio_labels)
                loss += au_loss * loss_weight
        else:
            # au_tokens = self.audio_transformer.inference(hidden_states)
            pass

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return VoilaOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=hidden_states,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    def _prepare_inputs_for_generation(
        self, input_ids, ref_embs=None, ref_embs_mask=None, audio_datas=None, audio_data_masks=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None and past_key_values.get_seq_length() > 0:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_cache_shape()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is None and \
            (past_key_values is None or past_key_values.get_seq_length() <= 0):
            inputs_embeds = self.model.embed_tokens(input_ids)
        if inputs_embeds is not None and \
            (past_key_values is None or past_key_values.get_seq_length() <= 0):
            model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask, "audio_datas": audio_datas, "audio_data_masks": audio_data_masks}
        else:
            model_inputs = {"input_ids": input_ids, "ref_embs": None, "audio_datas": None, "audio_data_masks": None}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        num_new_token: int = 1,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = outputs.past_key_values

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
            )
        return model_kwargs

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[Union[int, List[int]]],
    ) -> torch.LongTensor:
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)

        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
    def run_generate(
        self,
        input_ids: torch.LongTensor,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        audio_datas: Optional[torch.FloatTensor] = None,
        audio_data_masks: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = 128,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        streamer: Optional["BaseStreamer"] = None,
        llm_audio_token_id: Optional[int] = None,
        min_audio_token_id: Optional[int] = None,
        temperature=0.2,
        top_k=50,
        audio_temperature=0.2,
        audio_top_k=50,
    ):
        assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
        assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
        assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"

        eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        # Extend input_ids with an additional num_codebooks dim
        if len(input_ids.shape) == 2:
            input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)

        this_peer_finished = False  # used by synced_gpus only
        max_length = input_ids.shape[1] + max_new_tokens
        model_kwargs = {
            "use_cache": True,
            "past_key_values": DynamicCache(),
            "attention_mask": self._prepare_attention_mask_for_generation(
                input_ids, pad_token_id, eos_token_id
            ),
        }
        # auto-regressive generation
        while True:
            # prepare model inputs
            model_inputs = self._prepare_inputs_for_generation(
                input_ids,
                ref_embs=ref_embs,
                ref_embs_mask=ref_embs_mask,
                audio_datas=audio_datas,
                audio_data_masks=audio_data_masks,
                **model_kwargs
            )

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            audio_tokens = self.audio_transformer.inference(
                outputs.last_hidden_state,
                temperature=audio_temperature,
                top_k=audio_top_k,
            )
            audio_tokens = torch.stack(
                [
                    audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                    for ci in range(self.config.num_codebooks)
                ],
                dim=2,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            # Apply temperature and top-k
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            if top_k > 0:
                top_k = min(top_k, next_token_logits.size(-1))  # Safety check
                # Remove all tokens with a probability less than the last token of the top-k
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))

            # sample
            probs = nn.functional.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # Append NUM_CODEBOOK text tokens or audio_tokens
            if len(next_tokens.shape) == 1:
                next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
            next_tokens = torch.where(next_tokens == llm_audio_token_id, audio_tokens, next_tokens)
            input_ids = torch.cat([input_ids, next_tokens], dim=1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs
            )

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
                )
                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length
            if input_ids.shape[1] >= max_length:
                this_peer_finished = True

            if this_peer_finished:
                break

        if streamer is not None:
            streamer.end()

        return input_ids
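
# Usage sketch (illustrative only; the tokenizer, token ids and audio preprocessing are
# assumptions, not defined in this file): the audio-alpha variant additionally takes the
# raw waveform so its extracted features can be added onto the token embeddings.
#
#     out_ids = audio_alpha_model.run_generate(
#         input_ids,
#         ref_embs=ref_embs, ref_embs_mask=ref_embs_mask,
#         audio_datas=waveform[None, :],        # [batch, time] float waveform
#         audio_data_masks=audio_data_masks,    # [batch, audio_frames] 0/1 mask
#         pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
#         llm_audio_token_id=llm_audio_token_id, min_audio_token_id=min_audio_token_id,
#     )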
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
class VoilaAutonomousModel(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.pad_vocab_size_multiple = 64

        self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
        self.audio_transformer = AudioTransformer(config, use_sdpa=False)
        self.voila_predictor = nn.Sequential(
            nn.Linear(config.hidden_size, 2, bias=True),
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        audio_labels: Optional[torch.LongTensor] = None,
        voila_labels: Optional[torch.LongTensor] = None,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, VoilaOutput]:
        r"""
        Args:
            input_ids: [bs, seq_len, num_codebooks]
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            assert len(inputs_embeds.shape) == 4
        if len(inputs_embeds.shape) == 4:
            inputs_embeds = inputs_embeds.mean(dim=2)

        if self.training or \
            (past_key_values is None and ref_embs is not None) or \
            (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
            ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
            ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
            # (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
            padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
            ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
            inputs_embeds = inputs_embeds + ref_embs

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        # calc voila_predict_loss
        voila_pred = self.voila_predictor(hidden_states)
        voila_pred = voila_pred.float()

        loss = None

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return VoilaOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=hidden_states,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            voila_pred=voila_pred,
        )
    def _prepare_inputs_for_generation(
        self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values is not None and past_key_values.get_seq_length() > 0:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_cache_shape()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is None and \
            (past_key_values is None or past_key_values.get_seq_length() <= 0):
            inputs_embeds = self.model.embed_tokens(input_ids)
        if inputs_embeds is not None and \
            (past_key_values is None or past_key_values.get_seq_length() <= 0):
            model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
        else:
            model_inputs = {"input_ids": input_ids, "ref_embs": None}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        num_new_token: int = 1,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = outputs.past_key_values

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
            )
        return model_kwargs

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[Union[int, List[int]]],
    ) -> torch.LongTensor:
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)

        # Check if input is input_ids and padded -> only then is attention_mask defined
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
    def run_generate(
        self,
        input_ids: torch.LongTensor,
        input_generator,
        ref_embs: Optional[List[torch.Tensor]] = None,
        ref_embs_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = 128,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        streamer: Optional["BaseStreamer"] = None,
        llm_audio_token_id: Optional[int] = None,
        min_audio_token_id: Optional[int] = None,
        llm_assistant_token_id: Optional[int] = None,
        temperature=0.2,
        top_k=50,
        audio_temperature=0.8,
        audio_top_k=50,
    ):
        assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
        assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
        assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"

        eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        # Extend input_ids with an additional num_codebooks dim
        input_ids = input_ids.clone()
        if len(input_ids.shape) == 2:
            input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)

        this_peer_finished = False  # used by synced_gpus only
        max_length = input_ids.shape[1] + max_new_tokens
        model_kwargs = {
            "use_cache": True,
            "past_key_values": DynamicCache(),
            "attention_mask": self._prepare_attention_mask_for_generation(
                input_ids, pad_token_id, eos_token_id
            ),
        }
        speaking = False
        # auto-regressive generation
        while True:
            # prepare model inputs
            model_inputs = self._prepare_inputs_for_generation(
                input_ids,
                ref_embs=ref_embs,
                ref_embs_mask=ref_embs_mask,
                **model_kwargs
            )

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            audio_tokens = self.audio_transformer.inference(
                outputs.last_hidden_state,
                temperature=audio_temperature,
                top_k=audio_top_k,
            )
            audio_tokens = torch.stack(
                [
                    audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
                    for ci in range(self.config.num_codebooks)
                ],
                dim=2,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # voila head output
            voila_head_pred = outputs.voila_pred[:, -1, :]
            voila_head_pred = torch.argmax(voila_head_pred, dim=-1)
            voila_head_pred = voila_head_pred.cpu()[0].item()

            # pre-process distribution
            # Apply temperature and top-k
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            if top_k > 0:
                top_k = min(top_k, next_token_logits.size(-1))  # Safety check
                # Remove all tokens with a probability less than the last token of the top-k
                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))

            # sample
            probs = nn.functional.softmax(next_token_logits, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # voila head pred == 1, use assistant token
            if voila_head_pred == 1 and not speaking:
                next_tokens[0] = llm_assistant_token_id
                speaking = True
            elif next_tokens[0] == eos_token_id:
                speaking = False

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # Append NUM_CODEBOOK text tokens or audio_tokens
            if len(next_tokens.shape) == 1:
                next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
            audio_token_mask = next_tokens == llm_audio_token_id
            next_tokens = next_tokens * torch.logical_not(audio_token_mask) + audio_tokens * audio_token_mask

            if audio_token_mask[0, 0, 0].item():
                try:
                    new_input_tokens = next(input_generator)
                except StopIteration:
                    this_peer_finished = True
                    break
                new_input_tokens = new_input_tokens[None, None, :]
            else:
                new_input_tokens = next_tokens
            new_input_tokens = torch.cat([new_input_tokens, next_tokens], dim=2)

            input_ids = torch.cat([input_ids, new_input_tokens], dim=1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs
            )

            # # if eos_token was found in one sentence, set sentence to finished
            # if eos_token_id_tensor is not None:
            #     unfinished_sequences = unfinished_sequences.mul(
            #         next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
            #     )
            #     # stop when each sentence is finished
            #     if unfinished_sequences.max() == 0:
            #         this_peer_finished = True

            # stop if we exceed the maximum length
            if input_ids.shape[1] >= max_length:
                this_peer_finished = True

            if this_peer_finished:
                break

        if streamer is not None:
            streamer.end()

        return input_ids
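
# Usage sketch for the autonomous (full-duplex) variant (illustrative only; the
# generator, tokenizer and token ids below are assumptions, not defined in this file).
# The model consumes a stream of user-channel tokens from `input_generator` and uses
# the Voila prediction head to decide when to insert the assistant token and start
# speaking on its own channel.
#
#     def user_stream():
#         for chunk in incoming_user_audio_tokens:   # each chunk: LongTensor of shape [num_codebooks]
#             yield chunk
#
#     out_ids = autonomous_model.run_generate(
#         input_ids,                                  # prompt tokens; see the assert in run_generate for accepted shapes
#         input_generator=user_stream(),
#         ref_embs=ref_embs, ref_embs_mask=ref_embs_mask,
#         pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
#         llm_audio_token_id=llm_audio_token_id, min_audio_token_id=min_audio_token_id,
#         llm_assistant_token_id=llm_assistant_token_id,
#     )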