---
library_name: transformers
license: apache-2.0
---
Pretrained Vision Transformer Neural Quantum State for the \\(J_1\\)-\\(J_2\\) Heisenberg model on a \\(10\times10\\) square lattice.
The frustration ratio is set to \\(J_2/J_1=0.5\\).
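
For reference, here is the standard form of the Hamiltonian. The notation is ours, not the original card's: \\(\langle i,j\rangle\\) and \\(\langle\langle i,j\rangle\rangle\\) denote nearest- and next-nearest-neighbor pairs on the periodic lattice, and \\(\hat{\mathbf{S}}_i\\) are spin-1/2 operators.

$$
\hat{H} = J_1 \sum_{\langle i,j\rangle} \hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j + J_2 \sum_{\langle\langle i,j\rangle\rangle} \hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j , \qquad \frac{J_2}{J_1} = 0.5 .
$$

In this convention, the variational energies reported for each revision are energies per site, \\(\langle \hat{H} \rangle / N\\) with \\(N = 100\\), in units of \\(J_1\\). This matches the normalization applied in the sampling example further down.
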
| Revision | Variational energy | Time per sweep | Description |
|:---------------:|:------------------:|:--------------:|:---------------------------------------------------------------:|
| main | -0.497505103 | 41s | Plain ViT with translation invariance among patches |
| symm_t | -0.49760546 | 166s | ViT with translational symmetry |
| symm_trxy_ising | **-0.497676335** | 3317s | ViT with translational, point group and sz inversion symmetries |
The time per sweep was measured on a single A100-40GB GPU.
The model was trained by distributing the computation over 40 A100-64GB GPUs for about four days.
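
The `main` revision is described as a ViT that acts on patches of spins. As an illustration only, the NumPy sketch below splits a flattened \\(10\times10\\) configuration into non-overlapping square patches. The \\(2\times2\\) patch size, the row-major site ordering and the helper name `to_patches` are assumptions made for this example. They are not values read from the model's configuration.

```python
import numpy as np

L = 10  # linear size of the square lattice
P = 2   # hypothetical patch size (assumption, not read from the model config)


def to_patches(sigma: np.ndarray, L: int = L, P: int = P) -> np.ndarray:
    """Split a flat, row-major L*L configuration into (L//P)**2 patches of P*P spins.

    Site (r, c) is assumed to sit at flat index r * L + c. Patch k = pr * (L // P) + pc
    holds the spins in rows P*pr .. P*pr+P-1 and columns P*pc .. P*pc+P-1.
    """
    n = L // P
    grid = sigma.reshape(n, P, n, P)   # (patch_row, row_in_patch, patch_col, col_in_patch)
    grid = grid.transpose(0, 2, 1, 3)  # (patch_row, patch_col, row_in_patch, col_in_patch)
    return grid.reshape(n * n, P * P)  # one row of P*P spins per patch


# Example: a random configuration of +/-1 spins
rng = np.random.default_rng(0)
sigma = rng.choice(np.array([-1, 1], dtype=np.int8), size=L * L)
patches = to_patches(sigma)
print(patches.shape)  # (25, 4)
```

With this layout, translating the configuration by a multiple of the patch size only reorders the patches. A translation by a single site changes their content instead.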
## Citation
https://www.nature.com/articles/s42005-024-01732-4
## How to Get Started with the Model
Use the code below to get started with the model. In particular, we sample the model using NetKet.
```python
import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints
flax.config.update('flax_use_orbax_checkpointing', False)
# Load the model from HuggingFace
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)
lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
# J = [J1, J2]: couplings for nearest and next-nearest neighbors
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert,
                                     graph=lattice,
                                     J=[1.0, 0.5],
                                     sign_rule=[False, False]).to_jax_operator()  # No Marshall sign rule
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=16384,
                                        sweep_size=lattice.n_nodes)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=wf.__call__,
                        sampler_seed=subkey,
                        n_samples=16384,
                        n_discard_per_chain=0,  # chains are started from pre-thermalized samples below
                        variables=wf.params,
                        chunk_size=16384)
# Overwrite samples with already thermalized ones
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/j1j2_square_10x10", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
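# Sample the model
for _ in range(10):
    E = vstate.expect(hamiltonian)
    # NetKet's Heisenberg operator uses Pauli matrices: divide by 4 for spin-1/2 units and by N for the energy per site
    print("Mean: ", E.mean.real / lattice.n_nodes / 4)
    vstate.sample()  # draw new samples for the next iteration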
```
The expected output is:
```text
Number of parameters = 434760
Mean: -0.4975034481394982
Mean: -0.4975697817150899
Mean: -0.49753878662981793
Mean: -0.49749150331671876
Mean: -0.4975093308123018
Mean: -0.49755810175173776
Mean: -0.49753726455462444
Mean: -0.49748956161946795
Mean: -0.497479875901942
Mean: -0.49752966071413424
```
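
Two small checks help when reading these numbers. First, the loop divides NetKet's estimate by \\(4N\\). The factor 4 appears because NetKet's `Heisenberg` operator is written with Pauli matrices, and \\(\hat{\mathbf{S}} = \hat{\boldsymbol{\sigma}}/2\\) gives \\(\hat{\boldsymbol{\sigma}}_i\cdot\hat{\boldsymbol{\sigma}}_j = 4\,\hat{\mathbf{S}}_i\cdot\hat{\mathbf{S}}_j\\). Second, the ten values are successive Monte Carlo estimates, which can be summarized by their mean and a standard error. The NumPy sketch below, independent of NetKet, does both. Its standard error is naive: it treats the estimates as uncorrelated, but consecutive rounds reuse the same Markov chains, so it may underestimate the true uncertainty.

```python
import numpy as np

# 1) Why the loop divides by 4: compare one bond in the Pauli and spin-1/2 conventions.
sx = np.array([[0, 1], [1, 0]], dtype=complex)
sy = np.array([[0, -1j], [1j, 0]], dtype=complex)
sz = np.array([[1, 0], [0, -1]], dtype=complex)


def bond(scale: float) -> np.ndarray:
    """Two-site exchange sum_a O^a (x) O^a with O^a = scale * sigma^a."""
    return sum(np.kron(scale * s, scale * s) for s in (sx, sy, sz))


e_pauli = np.linalg.eigvalsh(bond(1.0))  # eigenvalues of sigma_i . sigma_j
e_spin = np.linalg.eigvalsh(bond(0.5))   # eigenvalues of S_i . S_j, with S = sigma / 2
# singlet: -3 (Pauli) vs -0.75 (spin-1/2); triplet: 1 vs 0.25
print(e_pauli, e_spin)

# 2) Summarize the ten per-site estimates copied from the expected output above.
estimates = np.array([
    -0.4975034481394982,
    -0.4975697817150899,
    -0.49753878662981793,
    -0.49749150331671876,
    -0.4975093308123018,
    -0.49755810175173776,
    -0.49753726455462444,
    -0.49748956161946795,
    -0.497479875901942,
    -0.49752966071413424,
])
mean = estimates.mean()
# Naive standard error of the mean: assumes the ten estimates are uncorrelated
sem = estimates.std(ddof=1) / np.sqrt(estimates.size)
print(f"E/N = {mean:.6f} +/- {sem:.6f}")  # E/N = -0.497521 +/- 0.000010
```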
The fully translationally invariant wavefunction can also be downloaded with:
```python
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, revision="symm_t")
```
Use `revision="symm_trxy_ising"` for a wavefunction that also includes the point-group and \\(S^z\\)-inversion symmetries.
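
A standard way to enforce a symmetry group \\(G\\) on a variational state is to project the amplitude onto the symmetric sector, \\(\psi_G(\sigma) = \frac{1}{|G|}\sum_{g\in G}\psi(g\sigma)\\). The NumPy sketch below applies this projection to the 100 translations of the periodic \\(10\times10\\) lattice. It is a generic illustration, not necessarily how the `symm_t` or `symm_trxy_ising` revisions are implemented. The toy amplitude `toy_log_psi`, the helper names and the row-major site ordering are hypothetical.

```python
import numpy as np

L = 10  # linear size of the periodic square lattice


def translations(sigma: np.ndarray, L: int = L):
    """Yield all L*L periodic translations of a flat, row-major L x L configuration."""
    grid = sigma.reshape(L, L)
    for dx in range(L):
        for dy in range(L):
            yield np.roll(grid, shift=(dx, dy), axis=(0, 1)).reshape(-1)


def symmetrized_log_psi(log_psi, sigma: np.ndarray) -> complex:
    """log of (1/|G|) * sum_g psi(g sigma) over the translation group G.

    Works in log space with a max shift so that exp(log_psi) cannot overflow.
    """
    logs = np.array([log_psi(t) for t in translations(sigma)], dtype=complex)
    shift = logs.real.max()
    return shift + np.log(np.mean(np.exp(logs - shift)))


# Hypothetical, non-symmetric toy log-amplitude used only for illustration
weights = np.random.default_rng(1).normal(size=L * L)


def toy_log_psi(s: np.ndarray) -> complex:
    return 0.1 * np.dot(weights, s) + 0.05j * s[0] * s[1]


sigma = np.random.default_rng(2).choice([-1, 1], size=L * L)
shifted = np.roll(sigma.reshape(L, L), shift=(3, 7), axis=(0, 1)).reshape(-1)
# The projected amplitude is the same for a configuration and its translate
print(np.isclose(symmetrized_log_psi(toy_log_psi, sigma),
                 symmetrized_log_psi(toy_log_psi, shifted)))  # True
```

In this sketch every symmetrized amplitude costs one network evaluation per group element. The cost of a sweep therefore grows with the size of the group, which is consistent with the longer sweep times reported for the symmetrized revisions.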
### Extract hidden representation
The hidden representation associated with a batch of input configurations can be extracted as follows:
```python
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True, return_z=True)
z = wf(wf.params, samples)
```
Starting from the vector \\(z\\), a fully connected network can be trained to *fine-tune* the model on a different value of the ratio \\(J_2/J_1\\).
See https://doi.org/10.1103/PhysRevResearch.6.023057 for more information.
Note: the hidden representation is well defined only for the non-symmetrized model (the `main` revision).
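
The NumPy sketch below gives a rough picture of the fine-tuning idea: a small fully connected head maps a batch of hidden vectors to one complex log-amplitude per configuration. Everything in it is hypothetical. That includes the shape of `z` (assumed to be `(batch, 72)`, matching the embedding dimension), the head width, the `tanh` activation and the random initialization. Optimizing the head, for example variationally with NetKet at the new \\(J_2/J_1\\), is not shown.

```python
import numpy as np

rng = np.random.default_rng(0)

D_Z = 72       # assumed size of each hidden vector z
D_HIDDEN = 64  # hypothetical width of the fine-tuning head


def init_head(rng: np.random.Generator, d_z: int = D_Z, d_hidden: int = D_HIDDEN) -> dict:
    """Randomly initialize a one-hidden-layer head (hypothetical parameterization)."""
    return {
        "W1": rng.normal(scale=1 / np.sqrt(d_z), size=(d_z, d_hidden)),
        "b1": np.zeros(d_hidden),
        # complex output weights, so the head can represent both modulus and phase
        "w2": (rng.normal(size=d_hidden) + 1j * rng.normal(size=d_hidden)) / np.sqrt(d_hidden),
    }


def head(params: dict, z: np.ndarray) -> np.ndarray:
    """Map z of shape (batch, d_z) to complex log-amplitudes of shape (batch,)."""
    h = np.tanh(z @ params["W1"] + params["b1"])
    return h @ params["w2"]


# Random stand-in for the hidden representation extracted with return_z=True
z = rng.normal(size=(8, D_Z))
params = init_head(rng)
log_amp = head(params, z)
print(log_amp.shape, log_amp.dtype)  # (8,) complex128
```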
## Model Hyperparameters
- Number of layers: 8
- Embedding dimension: 72
- Hidden dimension: 288
- Number of heads: 12
- Total number of parameters: 434760
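
The listed values fix a few derived quantities. The sketch below collects them in a small Python dataclass. The class and field names are hypothetical and are not the keys of the repository's configuration. It also makes two assumptions: the usual multi-head attention convention, in which the embedding is split evenly across heads, and that "hidden dimension" refers to the feed-forward layer of each block.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ViTHyperparameters:
    """Values listed in this card; field names are hypothetical."""
    num_layers: int = 8
    embedding_dim: int = 72
    hidden_dim: int = 288
    num_heads: int = 12

    @property
    def head_dim(self) -> int:
        # Assumes the embedding is split evenly across attention heads
        assert self.embedding_dim % self.num_heads == 0
        return self.embedding_dim // self.num_heads

    @property
    def mlp_ratio(self) -> float:
        # Assumes hidden_dim is the feed-forward width of each block
        return self.hidden_dim / self.embedding_dim


hp = ViTHyperparameters()
print(hp.head_dim, hp.mlp_ratio)  # 6 4.0
```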
## Model Card Contact
- Riccardo Rende ([email protected])
- Luciano Loris Viteritti ([email protected])