---
library_name: transformers
tags: []
---
Foundation Neural-Network Quantum State trained on the Ising model in a disordered transverse field, defined on a chain with \\(L\\) sites. The Hamiltonian (assuming periodic boundary conditions) is given by:
$$
\hat{H} = -J\sum_{i=1}^L \hat{S}_i^z \hat{S}_{i+1}^z - \sum_{i=1}^L h_i \hat{S}_i^x \ ,
$$
where \\(h_i\\) is the on-site transverse magnetic field at the \\(i\\)-th site.
In the disordered case, \\(h_i\\) varies randomly along the chain, drawn independently and identically from the uniform distribution on the interval \\([0, h_0]\\).
Several values of the external field intensity \\(h_0\\) are available (check the different revisions).
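The available revisions can also be listed programmatically, for instance with `huggingface_hub`; this is a minimal sketch, assuming the branch names follow the `L={L}_h={h0}` pattern used in the example below:
```python
from huggingface_hub import list_repo_refs

# Each branch of the repository corresponds to one (L, h0) combination, e.g. "L=32_h=1.0"
refs = list_repo_refs("nqs-models/ising_disorder_fnqs")
print([branch.name for branch in refs.branches])
```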
The architecture has been trained on \\(R=2000\\) different disorder realizations for a fixed value of \\(h_0\\), using a total batch size of \\(M=10000\\) samples.
The training was distributed over 4 A100-64GB GPUs and took about two hours.
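As an illustration of the training setup described above, a batch of disorder realizations can be generated as follows; this is only a sketch, with \\(R\\), \\(L\\) and \\(h_0\\) set to the values quoted in this card:
```python
import jax

# One row per disorder realization; each on-site field is drawn
# independently and uniformly from [0, h0].
R, L, h0 = 2000, 32, 1.0
key = jax.random.key(0)
fields = h0 * jax.random.uniform(key, shape=(R, L))  # shape (R, L)
```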
## How to Get Started with the Model
Use the code below to get started with the model. In particular, we sample the architecture for a fixed disorder realization using NetKet.
```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import netket as nk
import math
import flax
from flax.training import checkpoints
from netket.operator.spin import sigmax,sigmaz
flax.config.update('flax_use_orbax_checkpointing', False)
h0 = 1.0  # fix the value of the external field intensity
L = 32    # number of sites in the chain
revision = f"L={L}_h={h0}"  # check the revisions for the available values of h0 and L
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_disorder_fnqs",
                                   trust_remote_code=True,
                                   revision=revision)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)
# 1D chain with periodic boundary conditions
lattice = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)
J = -1.0/math.e  # nearest-neighbour coupling used in the sigma^z sigma^z term below

# Draw one disorder realization: the on-site fields h[i]*h0 are i.i.d. uniform in [0, h0]
key = jax.random.key(0)
h = jax.random.uniform(key, shape=(L,))

hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)

# Transverse-field term plus nearest-neighbour coupling
hamiltonian = sum([(-h[i]*h0)*sigmax(hilbert,i) for i in range(L)])
hamiltonian += sum([J*sigmaz(hilbert,i)*sigmaz(hilbert,(i+1)%L) for i in range(L)])
# Metropolis sampler with local (single spin-flip) update rule
action = nk.sampler.rules.LocalRule()
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert,
                                       rule=action,
                                       n_chains=10000,
                                       n_sweeps=lattice.n_nodes)
# Variational state wrapping the pretrained wave function for this disorder realization
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=partial(wf.__call__, coups=h),
                        sampler_seed=subkey,
                        n_samples=10000,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=10000)
# Download the stored spin configurations and use them to initialize the Markov chains
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/ising_disorder_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')  # some NetKet versions may require a float dtype here instead
vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
import time

# Sample the model
for _ in range(10):
    start = time.time()
    E = vstate.expect(hamiltonian)
    vstate.sample()
    print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time()-start)
```
The time per sweep is 0.5s, evaluated on a single A100-40GB GPU.
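Besides the mean, the statistics object returned by `vstate.expect` also exposes the Monte Carlo error of the estimate, which can be reported per site as well:
```python
# E is the statistics object returned by vstate.expect(hamiltonian) in the loop above;
# error_of_mean is the estimated standard error of the energy estimate.
energy_per_site = E.mean.real / lattice.n_nodes
error_per_site = E.error_of_mean / lattice.n_nodes
print(f"E/L = {energy_per_site:.6f} +/- {error_per_site:.6f}")
```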
### Extract hidden representation
The hidden representation associated with an input batch of configurations can be extracted as follows:
```python
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_disorder_fnqs", trust_remote_code=True, return_z=True)
z = wf(wf.params, samples, h)
```
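If useful, the returned object can be inspected directly; assuming `z` is a single array, its leading axis runs over the batch of input configurations:
```python
# Quick shape check on the extracted hidden representation (assumes z is a single array)
print(z.shape)
```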
## Training Hyperparameters
- Number of layers: 6
- Embedding dimension: 72
- Hidden dimension: 288
- Number of heads: 12
- Patch size: 4
- Total number of parameters: 326124
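As a quick consistency check, the total parameter count listed above can be compared against the loaded weights, reusing `wf` and `nk` from the snippets above:
```python
# The pretrained checkpoint should contain the number of parameters reported above
print(nk.jax.tree_size(wf.params))  # expected: 326124
```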
## Contacts
- Riccardo Rende ([email protected])
- Luciano Loris Viteritti ([email protected])