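Foundation Neural-Network Quantum State trained on the two-dimensional $J_1$-$J_2$ Heisenberg model on a square lattice. The system is described by the following Hamiltonian (with periodic boundary conditions):

$$
\hat{H} = J_1 \sum_{\langle i,j \rangle} \hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j + J_2 \sum_{\langle\langle i,j \rangle\rangle} \hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j
$$

where $\langle i,j \rangle$ and $\langle\langle i,j \rangle\rangle$ denote nearest-neighbour and next-nearest-neighbour pairs of sites, respectively.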
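The two sums run over the two neighbour shells that `max_neighbor_order=2` selects in the code below. As a minimal numpy sketch, independent of NetKet and using a row-major site labelling `x * L + y` chosen only for illustration, the bonds of a periodic L x L lattice and the diagonal $S^z_i S^z_j$ part of $\hat{H}$ for a single configuration can be written as follows. The helpers `square_bonds` and `diagonal_energy` are hypothetical names.

```python
import numpy as np

def square_bonds(L):
    """Nearest-neighbour (J1) and next-nearest-neighbour (J2) bonds of an
    L x L periodic square lattice; site (x, y) is labelled x * L + y."""
    idx = lambda x, y: (x % L) * L + (y % L)
    nn, nnn = [], []
    for x in range(L):
        for y in range(L):
            i = idx(x, y)
            # each bond listed once: the "right" and "up" neighbours
            nn.append((i, idx(x + 1, y)))
            nn.append((i, idx(x, y + 1)))
            # the two diagonal neighbours
            nnn.append((i, idx(x + 1, y + 1)))
            nnn.append((i, idx(x + 1, y - 1)))
    return nn, nnn

def diagonal_energy(spins, nn, nnn, J1=1.0, J2=0.5):
    """S^z S^z part of the J1-J2 Hamiltonian for spins with S^z = +-1/2."""
    s = np.asarray(spins, dtype=float)
    e_nn = sum(s[i] * s[j] for i, j in nn)
    e_nnn = sum(s[i] * s[j] for i, j in nnn)
    return J1 * e_nn + J2 * e_nnn

# Neel configuration on the 10 x 10 lattice
L = 10
nn, nnn = square_bonds(L)
neel = [0.5 * (-1) ** (x + y) for x in range(L) for y in range(L)]
print(diagonal_energy(neel, nn, nnn, J2=0.5) / L**2)  # -0.25
```

For the Néel state the diagonal energy per site is $-J_1/2 + J_2/2$: every nearest-neighbour bond is satisfied while every diagonal bond is frustrated.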
The architecture has been trained on different values of the coupling $J_2$, equispaced in the interval $[0.4, 0.6]$, using a total batch size of samples.
The computation was distributed over 4 A100-64GB GPUs and took a few hours.
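Concretely, "equispaced in the interval" corresponds to a grid such as `np.linspace(0.4, 0.6, n)`. The sketch below builds that grid and, as an assumption not stated in this card, splits a total batch evenly across the couplings so that every sample carries its own $J_2$ label. The helpers `coupling_grid` and `coupling_labels` are hypothetical, and the number of couplings and batch size in the usage lines are placeholders, not the values used in training.

```python
import numpy as np

def coupling_grid(n_couplings, j2_min=0.4, j2_max=0.6):
    """Equispaced J2 values in [j2_min, j2_max], endpoints included."""
    return np.linspace(j2_min, j2_max, n_couplings)

def coupling_labels(grid, total_batch):
    """One J2 label per sample, with the batch split evenly across couplings.
    The even split is an illustrative assumption, not the documented scheme."""
    per_coupling, rest = divmod(total_batch, len(grid))
    if rest:
        raise ValueError("total_batch must be a multiple of the number of couplings")
    return np.repeat(grid, per_coupling)

# Placeholder sizes, chosen only for illustration
grid = coupling_grid(5)
labels = coupling_labels(grid, total_batch=20)
print(labels.shape)  # (20,)
```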
How to Get Started with the Model
Use the code below to get started with the model. In particular, we sample the model at a fixed value of the coupling $J_2$ using NetKet.
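The sampler used below, NetKet's `MetropolisExchange`, proposes moves that swap the spins on two sites within graph distance `d_max` and accepts them with probability $\min(1, |\psi(\sigma')/\psi(\sigma)|^2)$. Because a swap never changes the total magnetization, the chains stay in the `total_sz=0` sector. The numpy sketch below illustrates one such move for a generic log-amplitude; it is not NetKet's implementation, and `exchange_step`, the toy `log_psi` and the ring of site pairs are hypothetical.

```python
import numpy as np

def exchange_step(sigma, log_psi, pairs, rng):
    """One Metropolis exchange move on a configuration of +-1 spins.
    Returns the (possibly updated) configuration and whether it changed."""
    i, j = pairs[rng.integers(len(pairs))]
    if sigma[i] == sigma[j]:
        return sigma, False  # swapping equal spins leaves sigma unchanged
    proposal = sigma.copy()
    proposal[i], proposal[j] = sigma[j], sigma[i]
    # log of |psi(proposal) / psi(sigma)|^2
    log_ratio = 2.0 * np.real(log_psi(proposal) - log_psi(sigma))
    if np.log(rng.uniform()) < log_ratio:
        return proposal, True
    return sigma, False

# Toy usage: an 8-site ring and a hypothetical Jastrow-like log-amplitude
rng = np.random.default_rng(0)
pairs = [(k, (k + 1) % 8) for k in range(8)]
log_psi = lambda s: -0.3 * np.sum(s * np.roll(s, 1))
sigma = np.array([1, -1] * 4)
for _ in range(100):
    sigma, accepted = exchange_step(sigma, log_psi, pairs, rng)
print(sigma, sigma.sum())  # total magnetization stays 0
```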
```python
from functools import partial
import time

import numpy as np
import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints
from huggingface_hub import hf_hub_download
from transformers import FlaxAutoModel

# Use Flax's legacy (non-Orbax) checkpointing to restore the stored samples
flax.config.update('flax_use_orbax_checkpointing', False)

# 10x10 square lattice with periodic boundaries; max_neighbor_order=2 adds the
# next-nearest-neighbour (diagonal) bonds needed for the J2 term
lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)

J2 = 0.5
assert 0.4 <= J2 <= 0.6  # the model has been trained on this interval

# Load the pretrained Foundation NQS from the Hugging Face Hub
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_fnqs", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

# Spin-1/2 Hilbert space restricted to the total S^z = 0 sector
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert,
                                     graph=lattice,
                                     J=[1.0, J2],  # [J1, J2] for nearest and next-nearest neighbours
                                     sign_rule=[False, False]).to_jax_operator()  # No Marshall sign rule

# Exchange moves conserve total S^z, so the chains stay in the total_sz=0 sector
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=16000,
                                        sweep_size=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=partial(wf.__call__, coups=J2),  # fix the coupling fed to the model
                        sampler_seed=subkey,
                        n_samples=16000,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=16000)

# Overwrite the sampler state with already thermalized configurations
path = hf_hub_download(repo_id="nqs-models/j1j2_square_fnqs", filename="spins")
samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)

# Sample the model
for _ in range(10):
    start = time.time()
    E = vstate.expect(hamiltonian)
    vstate.sample()
    # Energy per site; the division by 4 converts Pauli matrices to spin-1/2 operators
    print("Mean: ", E.mean.real / lattice.n_nodes / 4, "\t time=", time.time()-start)
```
The time per sweep is 21 s on a single A100-40GB GPU.
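The factor of 4 in the printed energy comes from the operator convention: NetKet's `Heisenberg` operator is written in terms of Pauli matrices $\boldsymbol{\sigma}_i$, whereas energies are quoted here for spin-1/2 operators $\hat{\mathbf{S}}_i = \boldsymbol{\sigma}_i/2$, so each bond satisfies $\boldsymbol{\sigma}_i \cdot \boldsymbol{\sigma}_j = 4\,\hat{\mathbf{S}}_i \cdot \hat{\mathbf{S}}_j$. A quick numpy check on a single bond:

```python
import numpy as np

# Pauli matrices
sx = np.array([[0, 1], [1, 0]], dtype=complex)
sy = np.array([[0, -1j], [1j, 0]], dtype=complex)
sz = np.array([[1, 0], [0, -1]], dtype=complex)

def bond(ops):
    """Two-site operator sum_a O^a (x) O^a for single-site operators O^a."""
    return sum(np.kron(o, o) for o in ops)

pauli_bond = bond([sx, sy, sz])            # sigma_i . sigma_j
spin_bond = bond([sx / 2, sy / 2, sz / 2])  # S_i . S_j

ev_pauli = np.linalg.eigvalsh(pauli_bond)  # singlet -3, triplet +1 (three times)
ev_spin = np.linalg.eigvalsh(spin_bond)    # singlet -3/4, triplet +1/4 (three times)
print(ev_pauli, ev_spin)
```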
Extract hidden representation
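The hidden representation associated with an input batch of configurations can be extracted as follows:
```python
# return_z=True makes the forward pass return the hidden representation z;
# `samples` and `J2` are defined in the previous snippet
wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_fnqs", trust_remote_code=True, return_z=True)
z = wf(wf.params, samples, J2)
```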
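The exact shape of `z` is not documented here. Assuming it can be flattened to one vector per configuration, one common way to inspect such representations is a principal component analysis computed via the SVD. In the minimal numpy sketch below, `pca_project` is a hypothetical helper, and the random 72-dimensional input only stands in for `np.asarray(z)`; that `z` has this shape is an assumption.

```python
import numpy as np

def pca_project(z, n_components=2):
    """Project per-sample representations onto their leading principal components.
    z: array of shape (n_samples, ...); trailing axes are flattened.
    Returns the projections and the fraction of variance explained by each component."""
    z = np.asarray(z, dtype=float).reshape(len(z), -1)
    zc = z - z.mean(axis=0)  # centre the data
    _, s, vt = np.linalg.svd(zc, full_matrices=False)  # rows of vt: principal directions
    var = s ** 2
    return zc @ vt[:n_components].T, var[:n_components] / var.sum()

# Stand-in for np.asarray(z): 50 samples with a 72-dimensional representation
rng = np.random.default_rng(0)
fake_z = rng.normal(size=(50, 72))
proj, frac = pca_project(fake_z)
print(proj.shape, frac)
```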
Training Hyperparameters
Number of layers: 4
Embedding dimension: 72
Hidden dimension: 288
Number of heads: 12
Patch size: 2x2
Total number of parameters: 223,104
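With a 2x2 patch size on the 10x10 lattice used above, each configuration is split into 25 patches of 4 spins, which a transformer-style model then maps to vectors of the embedding dimension (72 here), presumably through a learned linear map as in Vision Transformers. The numpy sketch below shows only the patching step; `extract_patches` is a hypothetical helper, and the row-major site ordering it assumes may differ from the one used inside the model.

```python
import numpy as np

def extract_patches(spins, L=10, b=2):
    """Split a batch of L*L configurations into non-overlapping b x b patches.
    Input shape: (batch, L*L); output shape: (batch, (L // b) ** 2, b * b).
    Assumes site index = row * L + col."""
    x = np.asarray(spins).reshape(-1, L // b, b, L // b, b)  # (batch, X, i, Y, j)
    x = x.transpose(0, 1, 3, 2, 4)                           # (batch, X, Y, i, j)
    return x.reshape(x.shape[0], (L // b) ** 2, b * b)

# Use site labels instead of +-1 values to make the grouping visible
spins = np.arange(100).reshape(1, 100)
patches = extract_patches(spins)
print(patches.shape)            # (1, 25, 4)
print(patches[0, 0].tolist())   # [0, 1, 10, 11]
```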
Model Card Contact
Riccardo Rende
Luciano Loris Viteritti