import torch
from torch import nn
from transformers import LlamaForCausalLM, Cache
from transformers.activations import GELUActivation
from transformers.utils import logging

from .configuration_live_llama import LiveLlamaConfig
from ..modeling_live import build_live, LiveMixin

logger = logging.get_logger(__name__)


class LiveLlamaForCausalLM(LlamaForCausalLM, LiveMixin):
    config_class = LiveLlamaConfig
    _keys_to_ignore_on_load_missing = ['vision_encoder', 'connector']
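    # 'vision_encoder' and 'connector' do not exist in the base Llama checkpoint (they are presumably
    # attached and initialized by build_live / LiveMixin), so their missing keys are ignored on load.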

    def __init__(self, config: LiveLlamaConfig):
        super().__init__(config)
        # Two-layer MLP that projects vision-encoder features into the LLM embedding space.
        self.connector = torch.nn.Sequential(
            torch.nn.Linear(config.vision_hidden_size, config.hidden_size, bias=True),
            GELUActivation(),
            torch.nn.Linear(config.hidden_size, config.hidden_size, bias=True),
        )
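        # Shape sketch (an assumption about how LiveMixin calls this connector): vision features of
        # shape (batch, num_vision_tokens, vision_hidden_size) come out as
        # (batch, num_vision_tokens, hidden_size), ready to be interleaved with the text embeddings.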

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        frames: torch.FloatTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: list[torch.FloatTensor] = None,
        inputs_embeds: torch.FloatTensor = None,
        labels: torch.LongTensor = None,
        use_cache: bool = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        cache_position: torch.LongTensor = None,
        **kwargs,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.joint_embed(input_ids, frames)
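            # joint_embed comes from LiveMixin; presumably it embeds input_ids and scatters the
            # connector-projected frame features into the frame-placeholder positions.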
        outputs = super().forward(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            # labels are intentionally not passed; the stream-weighted loss is computed below
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        loss = None
        if labels is not None:
            logits = outputs[0]
            # Weight each position: frame-placeholder tokens get stream_loss_weight, text tokens get 1,
            # then average over positions with a valid (non-negative) label.
            v_mask = input_ids.flatten(0, 1) == self.config.v_placeholder_id
            weight = v_mask * self.config.stream_loss_weight + ~v_mask
            loss = nn.functional.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction='none') * weight
            loss = loss.sum() / (labels >= 0).sum()
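            # Example: with stream_loss_weight = 0.5, loss at frame-placeholder positions counts half
            # as much as loss at text positions before the average; with 1.0 the weighting is a no-op.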

        if not return_dict:
            return (loss,) + outputs[1:] if loss is not None else outputs

        outputs.loss = loss
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # NOTE: slice from past_length so position_ids cover every not-yet-cached position,
                # not just the trailing input_ids tokens as in the stock Llama implementation.
                position_ids = position_ids[:, past_length:]

        # NOTE: inputs_embeds are re-used whenever they still contain uncached positions (e.g. newly
        # appended frame embeddings), not only on the first generation step.
        if inputs_embeds is not None and past_length < inputs_embeds.size(1):
            model_inputs = {"inputs_embeds": inputs_embeds[:, past_length:]}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,  # covers the new inputs only, starting from past_length
                "cache_position": cache_position,  # positions that have not been cached yet
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,  # length of cache + new inputs
            }
        )
        return model_inputs


def build_live_llama(**kwargs):
    return build_live(config_class=LiveLlamaConfig, model_class=LiveLlamaForCausalLM, **kwargs)


if __name__ == '__main__':
    from ..arguments_live import LiveOnePlusTrainingArguments
    print(LiveOnePlusTrainingArguments().to_dict())
    model, tokenizer = build_live_llama(is_training=True, **LiveOnePlusTrainingArguments().to_dict())
    print(model.config, tokenizer)
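
    # A minimal, self-contained sketch (not part of the model): reproduce the stream-weighted loss
    # from `forward` on dummy tensors. The placeholder id, weight, and vocab size below are made up
    # purely for illustration.
    v_placeholder_id, stream_loss_weight, vocab_size = 9, 0.5, 10
    dummy_ids = torch.tensor([[1, v_placeholder_id, 2, v_placeholder_id]])
    dummy_labels = torch.tensor([[2, 3, 4, -100]])  # -100 marks an ignored position
    dummy_logits = torch.randn(1, 4, vocab_size)
    v_mask = dummy_ids.flatten(0, 1) == v_placeholder_id
    weight = v_mask * stream_loss_weight + ~v_mask
    per_token = nn.functional.cross_entropy(dummy_logits.flatten(0, 1), dummy_labels.flatten(), reduction='none') * weight
    print('stream-weighted loss:', (per_token.sum() / (dummy_labels >= 0).sum()).item())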