rrende commited on
Commit
d30db4c
·
verified ·
1 Parent(s): 615bc8e

Upload model

Browse files
Files changed (7) hide show
  1. README.md +199 -0
  2. attentions.py +54 -0
  3. config.json +18 -0
  4. model.safetensors +3 -0
  5. transformer.py +141 -0
  6. vitnqs_config.py +26 -0
  7. vitnqs_model.py +34 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
attentions.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ from flax import linen as nn
5
+ import jax.numpy as jnp
6
+
7
+ from einops import rearrange
8
+
9
+ def roll(J, shift, axis=-1):
10
+ return jnp.roll(J, shift, axis=axis)
11
+
12
+ from functools import partial
13
+ @partial(jax.vmap, in_axes=(None, 0, None), out_axes=1)
14
+ @partial(jax.vmap, in_axes=(None, None, 0), out_axes=1)
15
+ def roll2d(spins, i, j):
16
+ side = int(spins.shape[-1]**0.5)
17
+ spins = spins.reshape(spins.shape[0], side, side)
18
+ spins = jnp.roll(jnp.roll(spins, i, axis=-2), j, axis=-1)
19
+ return spins.reshape(spins.shape[0], -1)
20
+
21
+ class FMHA(nn.Module):
22
+ d_model : int
23
+ h: int
24
+ L_eff: int
25
+ transl_invariant: bool = True
26
+ two_dimensional: bool = False
27
+
28
+ def setup(self):
29
+ self.v = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)
30
+ if self.transl_invariant:
31
+ self.J = self.param("J", nn.initializers.xavier_uniform(), (self.h, self.L_eff), jnp.float64)
32
+ if self.two_dimensional:
33
+ sq_L_eff = int(self.L_eff**0.5)
34
+ assert sq_L_eff * sq_L_eff == self.L_eff
35
+ self.J = roll2d(self.J, jnp.arange(sq_L_eff), jnp.arange(sq_L_eff))
36
+ self.J = self.J.reshape(self.h, -1, self.L_eff)
37
+ else:
38
+ self.J = jax.vmap(roll, (None, 0), out_axes=1)(self.J, jnp.arange(self.L_eff))
39
+ else:
40
+ self.J = self.param("J", nn.initializers.xavier_uniform(), (self.h, self.L_eff, self.L_eff), jnp.float64)
41
+
42
+ self.W = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)
43
+
44
+ def __call__(self, x):
45
+ v = self.v(x)
46
+ v = rearrange(v, 'batch L_eff (h d_eff) -> batch L_eff h d_eff', h=self.h)
47
+ v = rearrange(v, 'batch L_eff h d_eff -> batch h L_eff d_eff')
48
+ x = jnp.matmul(self.J, v)
49
+ x = rearrange(x, 'batch h L_eff d_eff -> batch L_eff h d_eff')
50
+ x = rearrange(x, 'batch L_eff h d_eff -> batch L_eff (h d_eff)')
51
+
52
+ x = self.W(x)
53
+
54
+ return x
config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "L_eff": 25,
3
+ "architectures": [
4
+ "QSModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "vitnqs_config.ViTNQSConfig",
8
+ "FlaxAutoModel": "vitnqs_model.ViTNQSModel"
9
+ },
10
+ "b": 2,
11
+ "d_model": 72,
12
+ "heads": 12,
13
+ "model_type": "vit_nqs",
14
+ "num_layers": 8,
15
+ "transformers_version": "4.48.0",
16
+ "tras_inv": true,
17
+ "two_dim": true
18
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73bcef74adf67486945e05ed20e67eb48d071fbe22c8b3de8aed501b2f417df7
3
+ size 3490136
transformer.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+ from flax import linen as nn
5
+ import jax.numpy as jnp
6
+
7
+ from einops import rearrange
8
+
9
+ from .attentions import FMHA
10
+
11
+ def log_cosh(x):
12
+ sgn_x = -2 * jnp.signbit(x.real) + 1
13
+ x = x * sgn_x
14
+ return x + jnp.log1p(jnp.exp(-2.0 * x)) - jnp.log(2.0)
15
+
16
+ def extract_patches1d(x, b):
17
+ return rearrange(x, 'batch (L_eff b) -> batch L_eff b', b=b)
18
+
19
+ def extract_patches2d(x, b):
20
+ batch = x.shape[0]
21
+ L_eff = int((x.shape[1] // b**2)**0.5)
22
+ x = x.reshape(batch, L_eff, b, L_eff, b) # [L_eff, b, L_eff, b]
23
+ x = x.transpose(0, 1, 3, 2, 4) # [L_eff, L_eff, b, b]
24
+ # flatten the patches
25
+ x = x.reshape(batch, L_eff, L_eff, -1) # [L_eff, L_eff, b*b]
26
+ x = x.reshape(batch, L_eff*L_eff, -1) # [L_eff*L_eff, b*b]
27
+ return x
28
+
29
+ class Embed(nn.Module):
30
+ d_model : int
31
+ b: int
32
+ two_dimensional: bool = False
33
+
34
+ def setup(self):
35
+ if self.two_dimensional:
36
+ self.extract_patches = extract_patches2d
37
+ else:
38
+ self.extract_patches = extract_patches1d
39
+
40
+ self.embed = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)
41
+
42
+ def __call__(self, x):
43
+ x = self.extract_patches(x, self.b)
44
+ x = self.embed(x)
45
+
46
+ return x
47
+
48
+ class EncoderBlock(nn.Module):
49
+ d_model : int
50
+ h: int
51
+ L_eff: int
52
+ transl_invariant: bool = True
53
+ two_dimensional: bool = False
54
+
55
+ def setup(self):
56
+ self.attn = FMHA(d_model=self.d_model, h=self.h, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional)
57
+
58
+ self.layer_norm_1 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
59
+ self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
60
+
61
+ self.ff = nn.Sequential([
62
+ nn.Dense(4*self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
63
+ nn.gelu,
64
+ nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
65
+ ])
66
+
67
+
68
+ def __call__(self, x):
69
+ x = x + self.attn(self.layer_norm_1(x))
70
+
71
+ x = x + self.ff( self.layer_norm_2(x) )
72
+ return x
73
+
74
+ class Encoder(nn.Module):
75
+ num_layers: int
76
+ d_model : int
77
+ h: int
78
+ L_eff: int
79
+ transl_invariant: bool = True
80
+ two_dimensional: bool = False
81
+
82
+ def setup(self):
83
+ self.layers = [EncoderBlock(d_model=self.d_model, h=self.h, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional) for _ in range(self.num_layers)]
84
+
85
+ def __call__(self, x):
86
+
87
+ for l in self.layers:
88
+ x = l(x)
89
+
90
+ return x
91
+
92
+ class OuputHead(nn.Module):
93
+ d_model : int
94
+
95
+ def setup(self):
96
+ self.out_layer_norm = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
97
+
98
+ self.norm2 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)
99
+ self.norm3 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)
100
+
101
+ self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
102
+ self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
103
+
104
+ def __call__(self, x):
105
+
106
+ x = self.out_layer_norm(x.sum(axis=1))
107
+
108
+ amp = self.norm2(self.output_layer0(x))
109
+ sign = self.norm3(self.output_layer1(x))
110
+
111
+ z = amp + 1j*sign
112
+
113
+ return jnp.sum(log_cosh(z), axis=-1)
114
+
115
+ class ViT(nn.Module):
116
+ num_layers: int
117
+ d_model : int
118
+ heads: int
119
+ L_eff: int
120
+ b: int
121
+ transl_invariant: bool = True
122
+ two_dimensional: bool = False
123
+
124
+ def setup(self):
125
+ self.patches_and_embed = Embed(self.d_model, self.b, two_dimensional=self.two_dimensional)
126
+
127
+ self.encoder = Encoder(num_layers=self.num_layers, d_model=self.d_model, h=self.heads, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional)
128
+
129
+ self.output = OuputHead(self.d_model)
130
+
131
+
132
+ def __call__(self, spins):
133
+ x = jnp.atleast_2d(spins)
134
+
135
+ x = self.patches_and_embed(x)
136
+
137
+ x = self.encoder(x)
138
+
139
+ z = self.output(x)
140
+
141
+ return z
vitnqs_config.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+
5
+ class ViTNQSConfig(PretrainedConfig):
6
+ model_type = "vit_nqs"
7
+
8
+ def __init__(
9
+ self,
10
+ L_eff=25,
11
+ num_layers = 8,
12
+ d_model = 72,
13
+ heads = 12,
14
+ b = 2,
15
+ tras_inv = True,
16
+ two_dim = True,
17
+ **kwargs,
18
+ ):
19
+ self.L_eff = L_eff
20
+ self.num_layers = num_layers
21
+ self.d_model = d_model
22
+ self.heads = heads
23
+ self.b = b
24
+ self.tras_inv = tras_inv
25
+ self.two_dim = two_dim
26
+ super().__init__(**kwargs)
vitnqs_model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import FlaxPreTrainedModel
2
+ import jax.numpy as jnp
3
+ from .transformer import ViT
4
+ from .vitnqs_config import ViTNQSConfig
5
+
6
+
7
+ class ViTNQSModel(FlaxPreTrainedModel):
8
+ config_class = ViTNQSConfig
9
+
10
+ def __init__(
11
+ self,
12
+ config: ViTNQSConfig,
13
+ input_shape = jnp.zeros((1, 100)),
14
+ seed: int = 0,
15
+ dtype: jnp.dtype = jnp.float64,
16
+ _do_init: bool = True,
17
+ **kwargs,
18
+ ):
19
+ self.model = ViT(L_eff=config.L_eff,
20
+ num_layers=config.num_layers,
21
+ d_model=config.d_model,
22
+ heads=config.heads,
23
+ b=config.b,
24
+ transl_invariant=config.tras_inv,
25
+ two_dimensional=config.two_dim,
26
+ )
27
+
28
+ super().__init__(config, ViT, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
29
+
30
+ def __call__(self, params, spins):
31
+ return self.model.apply(params, spins)
32
+
33
+ def init_weights(self, rng, input_shape):
34
+ return self.model.init(rng, input_shape)