Commit: model_type on top
File changed: configuration_bert.py (+2 −2)
|
Before:

@@ -40,6 +40,8 @@ class BertConfig(TransformersBertConfig):

  40
  41
  42  class FlexBertConfig(TransformersBertConfig):
  43      def __init__(
  44          self,
  45          attention_layer: str = "base",

@@ -97,7 +99,6 @@ class FlexBertConfig(TransformersBertConfig):

  97          pad_logits: bool = False,
  98          compile_model: bool = False,
  99          masked_prediction: bool = False,
 100 -        model_type: str = "flex_bert",
 101          **kwargs,
 102      ):
 103          """

@@ -214,7 +215,6 @@ class FlexBertConfig(TransformersBertConfig):

 214          self.pad_logits = pad_logits
 215          self.compile_model = compile_model
 216          self.masked_prediction = masked_prediction
 217 -        self.model_type = model_type
 218
 219          if loss_kwargs.get("return_z_loss", False):
 220              if loss_function != "fa_cross_entropy":
|
|
After:

  40
  41
  42  class FlexBertConfig(TransformersBertConfig):
  43 +    model_type = "flex_bert"
  44 +
  45      def __init__(
  46          self,
  47          attention_layer: str = "base",

  99          pad_logits: bool = False,
 100          compile_model: bool = False,
 101          masked_prediction: bool = False,
 102          **kwargs,
 103      ):
 104          """

 215          self.pad_logits = pad_logits
 216          self.compile_model = compile_model
 217          self.masked_prediction = masked_prediction
 218
 219          if loss_kwargs.get("return_z_loss", False):
 220              if loss_function != "fa_cross_entropy":