update config & modeling
Browse files- config.json +1 -1
- modeling_grok1.py +6 -4
config.json
CHANGED
|
@@ -28,6 +28,6 @@
|
|
| 28 |
"num_experts": 8,
|
| 29 |
"output_router_logits": false,
|
| 30 |
"router_aux_loss_coef": 0.001,
|
| 31 |
-
"torch_dtype": "
|
| 32 |
"transformers_version": "4.35.0"
|
| 33 |
}
|
|
|
|
| 28 |
"num_experts": 8,
|
| 29 |
"output_router_logits": false,
|
| 30 |
"router_aux_loss_coef": 0.001,
|
| 31 |
+
"torch_dtype": "bfloat16",
|
| 32 |
"transformers_version": "4.35.0"
|
| 33 |
}
|
modeling_grok1.py
CHANGED
|
@@ -7,14 +7,16 @@ from transformers.modeling_utils import PreTrainedModel
|
|
| 7 |
from transformers.utils import logging
|
| 8 |
|
| 9 |
try:
|
| 10 |
-
from transformers.modeling_attn_mask_utils import
|
|
|
|
| 11 |
|
| 12 |
HAS_MASK_UTILS = True
|
| 13 |
except ImportError:
|
| 14 |
HAS_MASK_UTILS = False
|
| 15 |
|
| 16 |
from .configuration_grok1 import Grok1Config
|
| 17 |
-
from .modeling_grok1_outputs import MoeCausalLMOutputWithPast,
|
|
|
|
| 18 |
|
| 19 |
logger = logging.get_logger(__name__)
|
| 20 |
|
|
@@ -549,7 +551,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|
| 549 |
|
| 550 |
|
| 551 |
class Grok1Model(Grok1PretrainedModel):
|
| 552 |
-
def __init__(self, config: Grok1Config) -> None:
|
| 553 |
super().__init__(config)
|
| 554 |
self.padding_idx = config.pad_token_id
|
| 555 |
self.vocab_size = config.vocab_size
|
|
@@ -787,7 +789,7 @@ class Grok1Model(Grok1PretrainedModel):
|
|
| 787 |
class Grok1ModelForCausalLM(Grok1PretrainedModel):
|
| 788 |
_tied_weights_keys = ["lm_head.weight"]
|
| 789 |
|
| 790 |
-
def __init__(self, config: Grok1Config):
|
| 791 |
super().__init__(config)
|
| 792 |
self.model = Grok1Model(config)
|
| 793 |
self.vocab_size = config.vocab_size
|
|
|
|
| 7 |
from transformers.utils import logging
|
| 8 |
|
| 9 |
try:
|
| 10 |
+
from transformers.modeling_attn_mask_utils import \
|
| 11 |
+
_prepare_4d_causal_attention_mask
|
| 12 |
|
| 13 |
HAS_MASK_UTILS = True
|
| 14 |
except ImportError:
|
| 15 |
HAS_MASK_UTILS = False
|
| 16 |
|
| 17 |
from .configuration_grok1 import Grok1Config
|
| 18 |
+
from .modeling_grok1_outputs import (MoeCausalLMOutputWithPast,
|
| 19 |
+
MoeModelOutputWithPast)
|
| 20 |
|
| 21 |
logger = logging.get_logger(__name__)
|
| 22 |
|
|
|
|
| 551 |
|
| 552 |
|
| 553 |
class Grok1Model(Grok1PretrainedModel):
|
| 554 |
+
def __init__(self, config: Grok1Config, **kwargs) -> None:
|
| 555 |
super().__init__(config)
|
| 556 |
self.padding_idx = config.pad_token_id
|
| 557 |
self.vocab_size = config.vocab_size
|
|
|
|
| 789 |
class Grok1ModelForCausalLM(Grok1PretrainedModel):
|
| 790 |
_tied_weights_keys = ["lm_head.weight"]
|
| 791 |
|
| 792 |
+
def __init__(self, config: Grok1Config, **kwargs):
|
| 793 |
super().__init__(config)
|
| 794 |
self.model = Grok1Model(config)
|
| 795 |
self.vocab_size = config.vocab_size
|