Commit 6e800de
Parent(s): f81d7bd
Revert modeling_nova
modeling_nova.py CHANGED (+3 -6)
```diff
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Tuple, List, Optional
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LLAMA_ATTENTION_CLASSES, LlamaMLP, LlamaRMSNorm
 from transformers.models.llama.modeling_llama import LlamaSdpaAttention, apply_rotary_pos_emb, repeat_kv
 from transformers import logging, Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
```
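Note: the revert restores the `LLAMA_ATTENTION_CLASSES` import. In the transformers releases that export it (roughly 4.36 through 4.43; later releases removed it), it is a dict mapping the attention-implementation names "eager", "sdpa", and "flash_attention_2" to the matching Llama attention classes. A minimal sketch of how a custom decoder layer typically consumes it; `NovaDecoderLayer` here is illustrative, not this repo's actual class:

```python
import torch.nn as nn
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

class NovaDecoderLayer(nn.Module):  # hypothetical, for illustration only
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        # LLAMA_ATTENTION_CLASSES maps config._attn_implementation
        # ("eager" / "sdpa" / "flash_attention_2") to an attention class.
        attn_cls = LLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attn_cls(config=config, layer_idx=layer_idx)
```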
```diff
@@ -421,10 +421,7 @@ class NovaModel(LlamaModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-
-        past_seen_tokens = None
-
-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens, output_attentions=False)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
 
         # apply the nova attention
         if nova_attention_mask is not None:
```
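Note: the reverted-away lines pinned `past_seen_tokens = None` and appear to target the newer five-argument `_update_causal_mask(..., output_attentions=False)` signature; the restored call uses the older four-argument signature and expects `past_seen_tokens` to already be defined earlier in `forward`. In upstream `LlamaModel.forward` of the matching transformers versions, that value is typically derived from the KV cache; a sketch, assuming that convention:

```python
import torch

# Excerpt-style sketch of what typically precedes the restored call in
# forward (self, past_key_values, attention_mask, inputs_embeds come from
# the enclosing method). Assumes the four-argument signature shown above.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
```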
```diff
@@ -667,4 +664,4 @@ class NovaForCausalLM(LlamaForCausalLM, NovaGenerationMixin):
                 "no_mask_idx": kwargs.get("no_mask_idx")
             }
         )
-        return model_inputs
+        return model_inputs
```
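Note: the unchanged context here is the tail of a `prepare_inputs_for_generation` override, where `model_inputs.update({...})` is the usual way to thread custom kwargs such as `no_mask_idx` through `generate()` into `forward`. A sketch of the common shape of such an override, assuming it delegates to the parent class; the repo's actual method may differ, and the `nova_attention_mask` key is inferred from the hunk above:

```python
# Sketch of a method on NovaForCausalLM (illustrative, not the repo's code).
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    # Build the standard Llama generation inputs first...
    model_inputs = super().prepare_inputs_for_generation(
        input_ids, past_key_values=past_key_values, **kwargs
    )
    # ...then carry the Nova-specific kwargs along so forward() receives them.
    model_inputs.update(
        {
            "nova_attention_mask": kwargs.get("nova_attention_mask"),
            "no_mask_idx": kwargs.get("no_mask_idx"),
        }
    )
    return model_inputs
```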