rrende committed
Commit f5f03c4 · verified · 1 Parent(s): 358fd2d

Upload model

Files changed (3):
  1. model.safetensors +1 -1
  2. transformer.py +8 -6
  3. vitnqs_model.py +5 -1
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:73bcef74adf67486945e05ed20e67eb48d071fbe22c8b3de8aed501b2f417df7
+oid sha256:1e327504c22eaeca2ba25d074c771e63462500f1301ee258b9afef233455a82a
 size 3490136
transformer.py CHANGED
@@ -101,12 +101,14 @@ class OuputHead(nn.Module):
         self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
         self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
 
-    def __call__(self, x):
+    def __call__(self, x, return_z=False):
 
-        x = self.out_layer_norm(x.sum(axis=1))
+        z = self.out_layer_norm(x.sum(axis=1))
+        if return_z:
+            return z
 
-        amp = self.norm2(self.output_layer0(x))
-        sign = self.norm3(self.output_layer1(x))
+        amp = self.norm2(self.output_layer0(z))
+        sign = self.norm3(self.output_layer1(z))
 
         z = amp + 1j*sign
 
@@ -129,13 +131,13 @@ class ViT(nn.Module):
         self.output = OuputHead(self.d_model)
 
 
-    def __call__(self, spins):
+    def __call__(self, spins, return_z=False):
         x = jnp.atleast_2d(spins)
 
         x = self.patches_and_embed(x)
 
         x = self.encoder(x)
 
-        z = self.output(x)
+        z = self.output(x, return_z=return_z)
 
         return z
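The net effect in transformer.py: OuputHead.__call__ (the spelling is the repo's own identifier) gains a return_z flag that short-circuits right after pooling and layer-norming the token axis, and ViT.__call__ forwards that flag. Below is a minimal self-contained sketch of the new behaviour. The LayerNorm/Dense configurations, the sketch's class name, and whatever the repo does with z after amp + 1j*sign are outside the diff, so those parts are assumptions rather than the repo's code:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class OutputHeadSketch(nn.Module):
        d_model: int

        def setup(self):
            # Field names follow the diff; their exact configuration in the
            # repo is not shown, so Flax defaults are assumed here.
            self.out_layer_norm = nn.LayerNorm()
            self.norm2 = nn.LayerNorm()
            self.norm3 = nn.LayerNorm()
            self.output_layer0 = nn.Dense(self.d_model)
            self.output_layer1 = nn.Dense(self.d_model)

        def __call__(self, x, return_z=False):
            # Pool over the token axis and normalize: this is the shared
            # representation z.
            z = self.out_layer_norm(x.sum(axis=1))
            if return_z:
                # New in this commit: expose the pooled features directly.
                return z
            # Otherwise combine the two real-valued heads into one complex
            # output (amplitude as real part, sign/phase as imaginary part).
            amp = self.norm2(self.output_layer0(z))
            sign = self.norm3(self.output_layer1(z))
            return amp + 1j * sign

    head = OutputHeadSketch(d_model=8)
    x = jnp.ones((2, 5, 8))                     # (batch, tokens, d_model)
    params = head.init(jax.random.PRNGKey(0), x)
    z = head.apply(params, x, return_z=True)    # pooled features, shape (2, 8)
    psi = head.apply(params, x)                 # complex output, shape (2, 8)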
vitnqs_model.py CHANGED
@@ -24,11 +24,15 @@ class ViTNQSModel(FlaxPreTrainedModel):
             transl_invariant=config.tras_inv,
             two_dimensional=config.two_dim,
         )
+        if not "return_z" in kwargs:
+            self.return_z = False
+        else:
+            self.return_z = kwargs["return_z"]
 
         super().__init__(config, ViT, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
     def __call__(self, params, spins):
-        return self.model.apply(params, spins)
+        return self.model.apply(params, spins, self.return_z)
 
     def init_weights(self, rng, input_shape):
         return self.model.init(rng, input_shape)
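Taken together, the commit threads an optional return_z keyword from the Hugging Face wrapper down to the Flax module, so a caller can retrieve the pooled encoder features instead of the complex amp + 1j*sign output. (The four-line kwargs branch is equivalent to the more concise self.return_z = kwargs.get("return_z", False).) A hypothetical usage sketch follows; the config object and the (1, 36) input shape are placeholders, not values from the repo:

    import jax
    import jax.numpy as jnp

    # config and input_shape are placeholder assumptions.
    model = ViTNQSModel(config, input_shape=(1, 36), return_z=True)
    params = model.init_weights(jax.random.PRNGKey(0), jnp.ones((1, 36)))

    spins = jnp.ones((4, 36))   # a batch of spin configurations
    z = model(params, spins)    # pooled encoder features, since return_z=True

    model.return_z = False
    psi = model(params, spins)  # complex amp + 1j*sign output instead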