|
|
--- |
|
|
library_name: transformers |
|
|
tags: [] |
|
|
--- |
|
|
|
|
|
|
|
Foundation Neural-Network Quantum State trained on the two-dimensional \\(J_1\\)-\\(J_2\\) Heisenberg model on a \\(10\times 10\\) square lattice.
|
|
The system is described by the following Hamiltonian (with periodic boundary conditions), where \\(\langle \cdot,\cdot \rangle\\) and \\(\langle\langle \cdot,\cdot \rangle\rangle\\) denote nearest- and next-nearest-neighbor pairs of sites:
|
|
|
|
|
$$
\hat{H} = J_1\!\!\sum_{\langle {\boldsymbol{r}},{\boldsymbol{r'}} \rangle} \hat{\boldsymbol{S}}_{\boldsymbol{r}}\cdot\hat{\boldsymbol{S}}_{\boldsymbol{r'}} + J_2 \!\!\!\!\sum_{\langle \langle {\boldsymbol{r}},{\boldsymbol{r'}} \rangle \rangle} \!\!\!\hat{\boldsymbol{S}}_{\boldsymbol{r}}\cdot\hat{\boldsymbol{S}}_{\boldsymbol{r'}} \ .
$$
|
|
|
|
|
The architecture has been trained on \\(R=100\\) different couplings \\(J_2\\), equispaced in the interval \\(J_2 \in [0.4, 0.6]\\), using a total batch size of \\(M=16000\\) samples.
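For reference, the grid of training couplings can be generated as follows (a minimal sketch; that the grid is a uniform `linspace` is an assumption based on the description above):

```python
import numpy as np

# R = 100 equispaced couplings in [0.4, 0.6] (assumed training grid)
J2_values = np.linspace(0.4, 0.6, 100)
```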
|
|
|
|
|
The computation was distributed over 4 A100-64GB GPUs and took a few hours.
|
|
|
|
|
|
|
|
## How to Get Started with the Model |
|
|
|
|
|
Use the code below to get started with the model. In particular, we show how to sample the wave function for a fixed value of the coupling \\(J_2\\) using NetKet.
|
|
|
|
|
```python
import time
from functools import partial

import jax
import jax.numpy as jnp
import netket as nk
from huggingface_hub import hf_hub_download

import flax
from flax.training import checkpoints

from transformers import FlaxAutoModel

# Use the legacy (non-Orbax) Flax checkpointing backend to load the samples below
flax.config.update('flax_use_orbax_checkpointing', False)

# 10x10 square lattice with periodic boundary conditions;
# max_neighbor_order=2 also generates the next-nearest-neighbor (J2) bonds
lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)

J2 = 0.5
assert 0.4 <= J2 <= 0.6  # the model has been trained on this interval

wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_fnqs", trust_remote_code=True)

N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

# Hilbert space restricted to the zero-magnetization sector
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert,
                                     graph=lattice,
                                     J=[1.0, J2],
                                     sign_rule=[False, False]).to_jax_operator()  # no Marshall sign rule

sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=16000,
                                        sweep_size=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=partial(wf.__call__, coups=J2),
                        sampler_seed=subkey,
                        n_samples=16000,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=16000)

# Overwrite the initial configurations with already thermalized ones
path = hf_hub_download(repo_id="nqs-models/j1j2_square_fnqs", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)

# Sample the model and measure the energy
for _ in range(10):
    start = time.time()
    E = vstate.expect(hamiltonian)
    vstate.sample()

    # Energy per site; the factor 1/4 converts from Pauli to spin-1/2 operators
    print("Mean: ", E.mean.real / lattice.n_nodes / 4, "\t time=", time.time() - start)
```
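Since `expect` returns full Monte Carlo statistics, the statistical error bar on the energy per site can be read off as well (same normalization as in the loop above):

```python
# Statistical error of the Monte Carlo estimate, per site and in spin units
print("Error: ", E.error_of_mean / lattice.n_nodes / 4)
```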
|
|
|
|
|
The time per sweep is about 21 s, measured on a single A100-40GB GPU.
|
|
|
|
|
### Extract hidden representation |
|
|
|
|
|
The hidden representation associated with an input batch of configurations can be extracted as follows:
|
|
|
|
|
```python
# Reload the model so that it returns the hidden representation z
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_fnqs", trust_remote_code=True, return_z=True)

# `samples` is a batch of spin configurations, e.g. the thermalized ones loaded above
z = wf(wf.params, samples, J2)
```
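As a quick sanity check, one can inspect the returned array and pool it into per-configuration features (a sketch; that the second axis of `z` indexes tokens is an assumption about the model's internals):

```python
import jax.numpy as jnp

# Average over the (assumed) token axis to get one feature vector per configuration
features = jnp.mean(z, axis=1)
print(z.shape, features.shape)
```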
|
|
|
|
|
#### Training Hyperparameters |
|
|
|
|
|
- Number of layers: 4
- Embedding dimension: 72
- Hidden dimension: 288
- Number of heads: 12
- Patch size: 2x2

Total number of parameters: 223,104
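As a rough illustration of the patching step (a sketch, not the model's actual embedding code; the patch ordering is an assumption), each \\(10\times 10\\) configuration is split into \\(25\\) patches of \\(4\\) spins, which are then embedded into the 72-dimensional token space:

```python
import jax.numpy as jnp

def to_patches(spins):
    # spins: (batch, 100) flattened 10x10 configurations
    x = spins.reshape(-1, 10, 10)
    # cut into a 5x5 grid of 2x2 patches, then flatten each patch
    x = x.reshape(-1, 5, 2, 5, 2).transpose(0, 1, 3, 2, 4)
    return x.reshape(-1, 25, 4)  # (batch, n_patches=25, patch_size=4)
```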
|
|
|
|
|
|
|
|
## Model Card Contact |
|
|
|
|
|
Riccardo Rende ([email protected]) |
|
|
Luciano Loris Viteritti ([email protected]) |