---
library_name: transformers
license: apache-2.0
---
A pretrained Vision Transformer neural quantum state for the \\(J_1\\)-\\(J_2\\) Heisenberg model on a \\(10\times10\\) square lattice.
The frustration ratio is set to \\(J_2/J_1=0.5\\).
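The Hamiltonian reads

$$
\hat{H} = J_1 \sum_{\langle i,j \rangle} \hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j + J_2 \sum_{\langle\langle i,j \rangle\rangle} \hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j ,
$$

where \\(\hat{\mathbf{S}}_i\\) are spin-1/2 operators and \\(\langle i,j \rangle\\), \\(\langle\langle i,j \rangle\rangle\\) denote nearest- and next-nearest-neighbour pairs of sites on the square lattice.
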
| Revision | Variational energy per site | Time per sweep | Description |
|:---------------:|:------------------:|:--------------:|:---------------------------------------------------------------:|
| main | -0.497505103 | 41s | Plain ViT with translation invariance among patches |
| symm_t | -0.49760546 | 166s | ViT with translational symmetry |
| symm_trxy_ising | **-0.497676335** | 3317s | ViT with translational, point group and sz inversion symmetries |
The time per sweep is measured on a single A100-40GB GPU.
The model was trained by distributing the computation over 40 A100-64GB GPUs for about four days.
## Citation
https://www.nature.com/articles/s42005-024-01732-4
## How to Get Started with the Model
Use the code below to get started with the model; the wavefunction is sampled using NetKet.
```python
import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints
flax.config.update('flax_use_orbax_checkpointing', False)
# Load the model from HuggingFace
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)
lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert,
                                     graph=lattice,
                                     J=[1.0, 0.5],
                                     sign_rule=[False, False]).to_jax_operator()  # No Marshall sign rule
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=16384,
                                        sweep_size=lattice.n_nodes)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=wf.__call__,
                        sampler_seed=subkey,
                        n_samples=16384,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=16384)
# Overwrite samples with already thermalized ones
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/j1j2_square_10x10", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)
# Sample the model and estimate the energy
for _ in range(10):
    E = vstate.expect(hamiltonian)
    # Divide by the number of sites and by 4 (Pauli-matrix convention of nk.operator.Heisenberg)
    # to obtain the energy per site in spin-1/2 units
    print("Mean: ", E.mean.real / lattice.n_nodes / 4)
    vstate.sample()
```
The expected output is:
```
Number of parameters = 434760
Mean: -0.4975034481394982
Mean: -0.4975697817150899
Mean: -0.49753878662981793
Mean: -0.49749150331671876
Mean: -0.4975093308123018
Mean: -0.49755810175173776
Mean: -0.49753726455462444
Mean: -0.49748956161946795
Mean: -0.497479875901942
Mean: -0.49752966071413424
```
The fully translationally invariant wavefunction can also be downloaded using:
```python
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, revision="symm_t")
```
Use `revision="symm_trxy_ising"` for a wavefunction that also includes the point-group and \\(S^z\\)-inversion symmetries.
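For example:
```python
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, revision="symm_trxy_ising")
```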
### Extract hidden representation
The hidden representation associated with an input batch of configurations can be extracted as follows:
```python
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, return_z=True)
z = wf(wf.params, samples)
```
Starting from the vector \\(z\\), a fully connected network can be trained to *fine-tune* the model at a different value of the ratio \\(J_2/J_1\\).
See https://doi.org/10.1103/PhysRevResearch.6.023057 for more information.
Note: the hidden representation is well defined only for the non-symmetrized model.
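As a rough illustration of such a fine-tuning head (not part of the released model: the module name, layer sizes, and the assumed `(batch, embedding_dim)` shape of `z` are placeholders), a minimal Flax sketch could look like:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class FineTuningHead(nn.Module):
    """Hypothetical fully connected network acting on the hidden representation z."""
    hidden_dim: int = 288  # illustrative width

    @nn.compact
    def __call__(self, z):
        x = nn.gelu(nn.Dense(self.hidden_dim)(z))
        # One output per configuration in the batch, e.g. a correction to the log-amplitude
        return nn.Dense(1)(x).squeeze(-1)

head = FineTuningHead()
z_dummy = jnp.ones((16, 72))  # z assumed to have shape (batch, embedding_dim)
head_params = head.init(jax.random.PRNGKey(0), z_dummy)
out = head.apply(head_params, z_dummy)
```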
#### Training Hyperparameters
- Number of layers: 8
- Embedding dimension: 72
- Hidden dimension: 288
- Number of heads: 12
- Total number of parameters: 434760
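These values can be collected in a small configuration object; the following dataclass is a hypothetical bookkeeping helper, not part of the repository:
```python
from dataclasses import dataclass

@dataclass
class ViTConfig:
    num_layers: int = 8
    embedding_dim: int = 72
    hidden_dim: int = 288  # MLP width, 4x the embedding dimension
    num_heads: int = 12    # per-head dimension = embedding_dim // num_heads = 6

config = ViTConfig()
```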
## Model Card Contact
- Riccardo Rende ([email protected])
- Luciano Loris Viteritti ([email protected])