---
library_name: transformers
tags: []
---
Foundation Neural-Network Quantum State trained on the transverse-field Ising model on a chain with \\(L=100\\) sites.
The system is described by the following Hamiltonian (with periodic boundary conditions):
$$
\hat{H} = -J\sum_{i=1}^L \hat{S}_i^z \hat{S}_{i+1}^z - h \sum_{i=1}^L \hat{S}_i^x \ ,
$$
where \\(\hat{S}_i^x\\) and \\(\hat{S}_i^z\\) are spin-\\(1/2\\) operators on site \\(i\\).
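For intuition, the Hamiltonian above can be diagonalized exactly for a small chain. The following NumPy sketch (purely illustrative, not part of the released model) builds \\(\hat{H}\\) for \\(L=4\\) using spin-\\(1/2\\) operators \\(\hat{S} = \sigma/2\\) and periodic boundary conditions:

```python
import numpy as np

def tfim_hamiltonian(L, J=1.0, h=1.0):
    # Spin-1/2 operators S = sigma / 2
    sx = np.array([[0.0, 1.0], [1.0, 0.0]]) / 2
    sz = np.array([[1.0, 0.0], [0.0, -1.0]]) / 2
    eye = np.eye(2)

    def op_at(op, i):
        # Tensor product placing `op` on site i and identities elsewhere
        mats = [eye] * L
        mats[i] = op
        out = mats[0]
        for m in mats[1:]:
            out = np.kron(out, m)
        return out

    H = np.zeros((2**L, 2**L))
    for i in range(L):
        H -= J * op_at(sz, i) @ op_at(sz, (i + 1) % L)  # periodic boundaries
        H -= h * op_at(sx, i)
    return H

H = tfim_hamiltonian(L=4, J=1.0, h=1.0)
E0 = np.linalg.eigvalsh(H)[0]
print("ground-state energy per site:", E0 / 4)
```

For \\(L=100\\) this matrix would be \\(2^{100}\\)-dimensional, which is why a variational ansatz such as the one released here is needed.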
The model has been trained on \\(R=6000\\) different values of the field \\(h\\) equispaced in the interval \\(h \in [0.8, 1.2]\\),
using a total batch size of \\(M=12000\\) samples.
The computation was distributed over 4 A100-64GB GPUs and ran for a few hours.
## How to Get Started with the Model
Use the code below to get started with the model. In particular, we sample the model for a fixed value of the external field \\(h\\) using NetKet.
```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints
flax.config.update('flax_use_orbax_checkpointing', False)
lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True)
revision = "main"
h = 1.0  # fix the value of the external field
assert 0.8 <= h <= 1.2  # the model has been trained on this interval
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0)
action = nk.sampler.rules.LocalRule()
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert,
                                       rule=action,
                                       n_chains=12000,
                                       n_sweeps=lattice.n_nodes)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=partial(wf.__call__, coups=h),
                        sampler_seed=subkey,
                        n_samples=12000,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=12000)
# start from thermalized configurations
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
import time
# Sample the model
for _ in range(10):
start = time.time()
E = vstate.expect(hamiltonian)
vstate.sample()
print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time()-start)
```
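The `expect` call above returns a statistics object that already carries an error estimate, but it is worth remembering that Markov-chain samples are autocorrelated, so a naive standard error underestimates the true uncertainty. As a generic illustration (not NetKet-specific), the sketch below runs a blocking (binning) analysis on a synthetic AR(1) series standing in for a correlated sequence of local energies:

```python
import numpy as np

def binning_error(samples, n_bins=32):
    # Blocking estimate of the Monte Carlo error for a correlated series:
    # average within bins, then take the standard error of the bin means.
    m = len(samples) // n_bins
    bins = samples[: m * n_bins].reshape(n_bins, m).mean(axis=1)
    return bins.std(ddof=1) / np.sqrt(n_bins)

rng = np.random.default_rng(0)
# Synthetic correlated chain (AR(1) process) standing in for sampled energies
x = np.empty(20000)
x[0] = 0.0
for t in range(1, len(x)):
    x[t] = 0.9 * x[t - 1] + rng.normal()

naive = x.std(ddof=1) / np.sqrt(len(x))  # ignores autocorrelation
blocked = binning_error(x)               # accounts for it
print(f"naive: {naive:.4f}  blocked: {blocked:.4f}")
```

The blocked estimate is several times larger than the naive one for strongly correlated chains, which is why thermalized starting configurations (as loaded above) and sufficiently long sweeps matter.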
The time per sweep is about 3.5 s, measured on a single A100-40GB GPU.
### Extract hidden representation
The hidden representation associated with an input batch of configurations can be extracted as follows:
```python
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True, return_z=True)
z = wf(wf.params, samples, h)
```
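The exact shape and meaning of `z` depend on the model internals. Purely as an illustration of how such features might be inspected, the sketch below runs a PCA (via SVD) on a random stand-in array with an assumed shape `(n_samples, 72)` matching the embedding dimension listed below; with real features, the leading components can be plotted against the field \\(h\\):

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-in for extracted hidden features; assumed shape (n_samples, d)
z = rng.normal(size=(1000, 72))

# PCA via SVD of the centered feature matrix
zc = z - z.mean(axis=0)
U, S, Vt = np.linalg.svd(zc, full_matrices=False)
explained = S**2 / np.sum(S**2)  # fraction of variance per component
coords = zc @ Vt[:2].T           # projection onto the two leading PCs
print(coords.shape, explained[:2])
```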
#### Training Hyperparameters
- Number of layers: 6
- Embedding dimension: 72
- Hidden dimension: 144
- Number of heads: 12
- Patch size: 4
- Total number of parameters: 198288
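The hyperparameters above imply a ViT-style tokenization of the input: a 100-site configuration split into non-overlapping patches of 4 spins, each linearly embedded into 72 dimensions. The sketch below only illustrates the resulting shapes; the embedding weights here are random stand-ins for what is, in the real model, a learned map:

```python
import numpy as np

L, patch_size, d_emb = 100, 4, 72
rng = np.random.default_rng(0)
spins = rng.choice([-1, 1], size=L)
patches = spins.reshape(L // patch_size, patch_size)  # (25, 4) patch tokens
W = rng.normal(size=(patch_size, d_emb))              # random stand-in weights
tokens = patches @ W  # (25, 72) sequence fed to the transformer encoder
print(patches.shape, tokens.shape)
```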
## Model Card Contact
Riccardo Rende ([email protected])
Luciano Loris Viteritti ([email protected])