---
library_name: transformers
license: apache-2.0
---
Pretrained Vision Transformer Neural Quantum State for the \\(J_1\\)-\\(J_2\\) Heisenberg model on a \\(10\times10\\) square lattice.
The frustration ratio is set to \\(J_2/J_1=0.5\\).
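For reference, the model studied is the standard \\(J_1\\)-\\(J_2\\) Heisenberg Hamiltonian,

\\[ \hat{H} = J_1 \sum_{\langle i,j \rangle} \hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j + J_2 \sum_{\langle\langle i,j \rangle\rangle} \hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j, \\]

where the first sum runs over nearest-neighbor pairs and the second over next-nearest-neighbor pairs of the square lattice.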

|     Revision    | Variational energy | Time per sweep |                           Description                           |
|:---------------:|:------------------:|:--------------:|:---------------------------------------------------------------:|
|       main      |    -0.497505103    |      41s       |       Plain ViT with translation invariance among patches       |
|      symm_t     |     -0.49760546    |      166s      |                 ViT with translational symmetry                 |
| symm_trxy_ising |  **-0.497676335**  |      3317s     | ViT with translational, point group and sz inversion symmetries |

The time per sweep is evaluated on a single A100-40GB GPU.

The model was trained by distributing the computation over 40 A100-64GB GPUs for about four days.

## Citation

https://www.nature.com/articles/s42005-024-01732-4

## How to Get Started with the Model

Use the code below to get started with the model; the wavefunction is sampled using NetKet.

```python
import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints
flax.config.update('flax_use_orbax_checkpointing', False)
# Load the model from HuggingFace
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)
lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert,
                                     graph=lattice,
                                     J=[1.0, 0.5],
                                     sign_rule=[False, False]).to_jax_operator() # No Marshall sign rule
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=16384,
                                        sweep_size=lattice.n_nodes)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler, 
                        apply_fun=wf.__call__, 
                        sampler_seed=subkey,
                        n_samples=16384, 
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=16384)
# Overwrite samples with already thermalized ones
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/j1j2_square_10x10", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
# Sample the model and print the energy per site
for _ in range(10):
    E = vstate.expect(hamiltonian)
    # The factor 1/4 converts from NetKet's Pauli-matrix convention to spin-1/2 operators
    print("Mean: ", E.mean.real / lattice.n_nodes / 4)
    vstate.sample()
```

The expected output is:

> Number of parameters =  434760  
> Mean:  -0.4975034481394982  
> Mean:  -0.4975697817150899  
> Mean:  -0.49753878662981793  
> Mean:  -0.49749150331671876  
> Mean:  -0.4975093308123018  
> Mean:  -0.49755810175173776  
> Mean:  -0.49753726455462444  
> Mean:  -0.49748956161946795  
> Mean:  -0.497479875901942  
> Mean:  -0.49752966071413424  
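Each value above is a Monte Carlo estimate over a single sweep. As a quick, illustrative check (not a rigorous analysis, since consecutive sweeps are correlated), the ten printed energies per site can be averaged with a naive standard error:

```python
import numpy as np

# Energies per site printed by the ten sweeps above
means = np.array([
    -0.4975034481394982, -0.4975697817150899, -0.49753878662981793,
    -0.49749150331671876, -0.4975093308123018, -0.49755810175173776,
    -0.49753726455462444, -0.49748956161946795, -0.497479875901942,
    -0.49752966071413424,
])
avg = means.mean()
# Naive standard error; correlated sweeps would call for binning analysis
err = means.std(ddof=1) / np.sqrt(len(means))
print(f"E/N = {avg:.6f} +/- {err:.6f}")
```

The resulting average is consistent with the variational energy reported in the revision table above.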

The fully translation-invariant wavefunction can also be downloaded using:

```python
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, revision="symm_t")
```

Use `revision="symm_trxy_ising"` for a wavefunction that also includes the point group and \\(S_z\\) inversion symmetries.

### Extract hidden representation

The hidden representation associated with an input batch of configurations can be extracted as follows:

```python
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, return_z=True)

# `samples` is a batch of spin configurations, e.g. obtained with `vstate.sample()` as above
z = wf(wf.params, samples)
```

Starting from the vector \\(z\\), a fully connected network can be trained to *fine-tune* the model at a different value of the ratio \\(J_2/J_1\\).
See https://doi.org/10.1103/PhysRevResearch.6.023057 for more information.

Note: the hidden representation is well defined only for the non-symmetrized model.
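As an illustrative sketch only (the fine-tuning setup used in the paper may differ), a small fully connected head on top of \\(z\\) could be written in plain JAX as below; the function names and layer sizes here are hypothetical, with the input dimension 72 matching the embedding dimension of the pretrained ViT:

```python
import jax
import jax.numpy as jnp

def init_head(key, d_in=72, d_hidden=64):
    """Initialize a hypothetical two-layer fully connected head."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_in, d_hidden)) / jnp.sqrt(d_in),
        "b1": jnp.zeros(d_hidden),
        "W2": jax.random.normal(k2, (d_hidden, 1)) / jnp.sqrt(d_hidden),
        "b2": jnp.zeros(1),
    }

def apply_head(params, z):
    """Map a batch of hidden vectors z to one output per configuration."""
    h = jax.nn.gelu(z @ params["W1"] + params["b1"])
    return (h @ params["W2"] + params["b2"]).squeeze(-1)

# Dummy batch standing in for the extracted hidden representation
z = jnp.ones((4, 72))
params = init_head(jax.random.PRNGKey(0))
out = apply_head(params, z)
print(out.shape)  # (4,)
```

Only the head parameters would then be optimized at the new value of \\(J_2/J_1\\), keeping the pretrained backbone fixed.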

#### Training Hyperparameters

- Number of layers: 8
- Embedding dimension: 72
- Hidden dimension: 288
- Number of heads: 12

Total number of parameters: 434760

## Model Card Contact

Riccardo Rende ([email protected])  
Luciano Loris Viteritti ([email protected])