Update vit_fnqs_model.py
Browse files- vit_fnqs_model.py +4 -4
vit_fnqs_model.py
CHANGED
@@ -25,10 +25,10 @@ class ViTFNQSModel(FlaxPreTrainedModel):
|
|
25 |
transl_invariant=config.tras_inv,
|
26 |
two_dimensional=config.two_dim,
|
27 |
)
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
super().__init__(config, ViTFNQS, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
34 |
|
|
|
25 |
transl_invariant=config.tras_inv,
|
26 |
two_dimensional=config.two_dim,
|
27 |
)
|
28 |
+
if not "return_z" in kwargs:
|
29 |
+
self.return_z = False
|
30 |
+
else:
|
31 |
+
self.return_z = kwargs["return_z"]
|
32 |
|
33 |
super().__init__(config, ViTFNQS, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
34 |
|