---
library_name: transformers
tags: []
---

Foundation Neural-Network Quantum State trained on the transverse-field Ising model on a chain of \\(L=100\\) sites.
The system is described by the following Hamiltonian (with periodic boundary conditions):

$$
    \hat{H} = -J\sum_{i=1}^L \hat{S}_i^z \hat{S}_{i+1}^z - h \sum_{i=1}^L \hat{S}_i^x \ ,
$$

where \\(\hat{S}_i^x\\) and \\(\hat{S}_i^z\\) are spin-\\(1/2\\) operators on site \\(i\\).
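
To make the convention explicit, the Hamiltonian above can be built densely for a small chain with plain NumPy (a sketch for illustration only; the chain length \\(N=4\\) and the `site_op`/`ising_hamiltonian` helpers are assumptions, not part of the model's code):

```python
import numpy as np

# Spin-1/2 operators S = sigma / 2, matching the Hamiltonian above.
Sx = 0.5 * np.array([[0.0, 1.0], [1.0, 0.0]])
Sz = 0.5 * np.array([[1.0, 0.0], [0.0, -1.0]])
I2 = np.eye(2)

def site_op(op, i, N):
    """Embed a single-site operator at site i of an N-site chain."""
    mats = [op if j == i else I2 for j in range(N)]
    out = mats[0]
    for m in mats[1:]:
        out = np.kron(out, m)
    return out

def ising_hamiltonian(N=4, J=1.0, h=1.0):
    """Dense H = -J sum_i S^z_i S^z_{i+1} - h sum_i S^x_i with PBC."""
    H = np.zeros((2**N, 2**N))
    for i in range(N):
        H -= J * site_op(Sz, i, N) @ site_op(Sz, (i + 1) % N, N)
        H -= h * site_op(Sx, i, N)
    return H

H = ising_hamiltonian()              # 16 x 16 for N=4
E0 = np.linalg.eigvalsh(H)[0]        # exact ground-state energy of the small chain
```

For the \\(L=100\\) chain a dense matrix is of course intractable, which is where the variational ansatz below comes in.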


The model has been trained on \\(R=6000\\) different values of the field \\(h\\), equispaced in the interval \\(h \in [0.8, 1.2]\\),
using a total batch size of \\(M=12000\\) samples.

The computation was distributed over 4 A100-64GB GPUs and took a few hours.
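
Based on the description above, the grid of training fields can be reproduced as follows (a sketch; the exact grid used in training is an assumption inferred from the stated interval and \\(R\\)):

```python
import numpy as np

# R = 6000 equispaced values of h in [0.8, 1.2], as stated above.
R = 6000
fields = np.linspace(0.8, 1.2, R)
```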


## How to Get Started with the Model

Use the code below to get started with the model. In particular, we sample the model for a fixed value of the external field \\(h\\) using NetKet.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
import netket as nk

import flax
from flax.training import checkpoints

from huggingface_hub import hf_hub_download
from transformers import FlaxAutoModel

flax.config.update('flax_use_orbax_checkpointing', False)

lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True)

revision = "main"
h = 1.0  # fix the value of the external field
assert 0.8 <= h <= 1.2  # the model has been trained on this interval

wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0)

action = nk.sampler.rules.LocalRule()
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert,
                                       rule=action,
                                       n_chains=12000,
                                       n_sweeps=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=partial(wf.__call__, coups=h),
                        sampler_seed=subkey,
                        n_samples=12000,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=12000)

# Start from thermalized configurations
path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)

# Sample the model and estimate the energy per site
for _ in range(10):
    start = time.time()
    E = vstate.expect(hamiltonian)
    vstate.sample()

    print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time() - start)
```

The time per sweep is 3.5s, evaluated on a single A100-40GB GPU.

### Extract hidden representation

The hidden representation associated with an input batch of configurations can be extracted as follows:

```python
# reload the model with return_z=True to expose the hidden representation
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True, return_z=True)

# `samples` and `h` are defined in the previous snippet
z = wf(wf.params, samples, h)
```

## Training Hyperparameters

Number of layers: 6  
Embedding dimension: 72   
Hidden dimension: 144  
Number of heads: 12  
Patch size: 4

Total number of parameters: 198288 
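
Assuming a ViT-style tokenization of the chain (an assumption, not stated explicitly in this card), the patching arithmetic implied by these hyperparameters is:

```python
# Hypothetical sketch of the input tokenization implied by the
# hyperparameters above (variable names are illustrative).
L = 100                       # chain length
patch_size = 4                # hyperparameter above
n_patches = L // patch_size   # number of input tokens for the transformer
embed_dim = 72                # each token is embedded in a 72-dimensional vector
```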


## Model Card Contact

Riccardo Rende ([email protected])  
Luciano Loris Viteritti ([email protected])