Upload model

Files changed:
- model.safetensors  +1 -1
- transformer.py     +8 -6
- vitnqs_model.py    +5 -1
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:1e327504c22eaeca2ba25d074c771e63462500f1301ee258b9afef233455a82a
 size 3490136
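Only the pointer's oid changes; the size stays 3490136 bytes, so the new weights serialize to exactly the same byte count. Since a Git LFS oid is simply the SHA-256 of the file contents, a downloaded copy can be checked against the pointer. A minimal sketch (the local filename is an assumption):

import hashlib

# Hash the downloaded weights and compare with the pointer's oid above.
with open("model.safetensors", "rb") as f:  # assumed local path
    digest = hashlib.sha256(f.read()).hexdigest()

assert digest == "1e327504c22eaeca2ba25d074c771e63462500f1301ee258b9afef233455a82a"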
transformer.py
CHANGED
@@ -101,12 +101,14 @@ class OuputHead(nn.Module):
         self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
         self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
 
-    def __call__(self, x):
+    def __call__(self, x, return_z=False):
 
-
+        z = self.out_layer_norm(x.sum(axis=1))
+        if return_z:
+            return z
 
-        amp = self.norm2(self.output_layer0(
-        sign = self.norm3(self.output_layer1(
+        amp = self.norm2(self.output_layer0(z))
+        sign = self.norm3(self.output_layer1(z))
 
         z = amp + 1j*sign
 
@@ -129,13 +131,13 @@ class ViT(nn.Module):
         self.output = OuputHead(self.d_model)
 
 
-    def __call__(self, spins):
+    def __call__(self, spins, return_z=False):
         x = jnp.atleast_2d(spins)
 
         x = self.patches_and_embed(x)
 
         x = self.encoder(x)
 
-        z = self.output(x)
+        z = self.output(x, return_z=return_z)
 
         return z
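Putting the two hunks together: the head now pools the encoder output over the patch axis, layer-normalises it, and can short-circuit to return that pooled embedding z when return_z=True; otherwise it projects z twice and combines the projections into a complex value. A minimal self-contained sketch of the new behaviour, assuming out_layer_norm, norm2 and norm3 are nn.LayerNorm modules defined in the part of setup() the hunk does not show, and simplifying whatever follows z = amp + 1j*sign:

import jax.numpy as jnp
import flax.linen as nn

class OuputHeadSketch(nn.Module):
    d_model: int

    def setup(self):
        # Assumed: the real OuputHead defines these norms in setup() too.
        self.out_layer_norm = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.norm3 = nn.LayerNorm()
        self.output_layer0 = nn.Dense(self.d_model)
        self.output_layer1 = nn.Dense(self.d_model)

    def __call__(self, x, return_z=False):
        # x: (batch, patches, d_model) from the encoder.
        z = self.out_layer_norm(jnp.sum(x, axis=1))  # pool over patches
        if return_z:
            return z  # new in this commit: expose the pooled embedding
        amp = self.norm2(self.output_layer0(z))
        sign = self.norm3(self.output_layer1(z))
        return amp + 1j * sign  # complex output from the two projections

ViT.__call__ just threads return_z through to the head, so applying the module with return_z=True yields the embedding instead of the complex output.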
vitnqs_model.py
CHANGED
@@ -24,11 +24,15 @@ class ViTNQSModel(FlaxPreTrainedModel):
             transl_invariant=config.tras_inv,
             two_dimensional=config.two_dim,
         )
+        if not "return_z" in kwargs:
+            self.return_z = False
+        else:
+            self.return_z = kwargs["return_z"]
 
         super().__init__(config, ViT, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
     def __call__(self, params, spins):
-        return self.model.apply(params, spins)
+        return self.model.apply(params, spins, self.return_z)
 
     def init_weights(self, rng, input_shape):
         return self.model.init(rng, input_shape)
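The wrapper caches the flag at construction time rather than per call: return_z is read from **kwargs in __init__ (the four added lines are equivalent to self.return_z = kwargs.get("return_z", False)) and then passed positionally to ViT.__call__ on every apply. A hedged usage sketch, assuming the constructor forwards **kwargs as the hunk implies; the config object and its fields are outside this diff:

# Assumed: `config` carries tras_inv, two_dim, etc. (not shown here).
model = ViTNQSModel(config, return_z=True)
z = model(params, spins)    # pooled embedding z from OuputHead

model.return_z = False      # toggle back on the instance
psi = model(params, spins)  # complex output amp + 1j*sign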