ejschwartz committed
Commit 6e800de · 1 Parent(s): f81d7bd

Revert modeling_nova

Files changed (1)
  1. modeling_nova.py +3 -6
modeling_nova.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Tuple, List, Optional
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LLAMA_ATTENTION_CLASSES, LlamaMLP, LlamaRMSNorm
 from transformers.models.llama.modeling_llama import LlamaSdpaAttention, apply_rotary_pos_emb, repeat_kv
 from transformers import logging, Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -421,10 +421,7 @@ class NovaModel(LlamaModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        if past_seen_tokens == 0:
-            past_seen_tokens = None
-
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens, output_attentions=False)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
 
         # apply the nova attention
         if nova_attention_mask is not None:
@@ -667,4 +664,4 @@ class NovaForCausalLM(LlamaForCausalLM, NovaGenerationMixin):
                 "no_mask_idx": kwargs.get("no_mask_idx")
             }
         )
-        return model_inputs
+        return model_inputs
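For context, not part of the commit: the reverted call passed past_seen_tokens together with output_attentions=False to _update_causal_mask, while the restored call passes only four arguments, matching how LlamaModel._update_causal_mask changed its signature across transformers releases. The sketch below is an assumption rather than code from this repository; the helper name call_update_causal_mask is hypothetical, and it shows one way a forward() override could tolerate either signature by inspecting it at runtime.

import inspect

# Hedged sketch (not from this repository): dispatch on the signature of
# LlamaModel._update_causal_mask. Older transformers releases accept four
# positional arguments; newer ones also take a cache object and an
# output_attentions flag. The helper itself is hypothetical.
def call_update_causal_mask(model, attention_mask, inputs_embeds,
                            cache_position, past_seen_tokens, past_key_values):
    params = inspect.signature(model._update_causal_mask).parameters
    if "output_attentions" in params:
        # Newer signature: pass the cache object and output_attentions flag.
        return model._update_causal_mask(
            attention_mask, inputs_embeds, cache_position,
            past_key_values, output_attentions=False,
        )
    # Older signature (the form this commit reverts to): four arguments.
    return model._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_seen_tokens
    )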