Update README.md
Browse files
README.md
CHANGED
@@ -18,6 +18,61 @@ The computation has been distributed over 4 A100-64GB GPUs for few hours.
|
|
18 |
Use the code below to get started with the model. In particular, we sample the model for a fixed value of the externa field \\(h\\) using NetKet.
|
19 |
|
20 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
```
|
23 |
|
|
|
18 |
Use the code below to get started with the model. In particular, we sample the model for a fixed value of the externa field \\(h\\) using NetKet.
|
19 |
|
20 |
```python
|
21 |
+
from functools import partial
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
import jax
|
25 |
+
import jax.numpy as jnp
|
26 |
+
import netket as nk
|
27 |
+
|
28 |
+
import flax
|
29 |
+
from flax.training import checkpoints
|
30 |
+
|
31 |
+
flax.config.update('flax_use_orbax_checkpointing', False)
|
32 |
+
|
33 |
+
lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True)
|
34 |
+
|
35 |
+
revision = "main"
|
36 |
+
h = 1.0 #* fix the value of the external field
|
37 |
+
|
38 |
+
from transformers import FlaxAutoModel
|
39 |
+
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True)
|
40 |
+
N_params = nk.jax.tree_size(wf.params)
|
41 |
+
print('Number of parameters = ', N_params, flush=True)
|
42 |
+
|
43 |
+
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
|
44 |
+
hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0)
|
45 |
+
|
46 |
+
action = nk.sampler.rules.LocalRule()
|
47 |
+
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert,
|
48 |
+
rule=action,
|
49 |
+
n_chains=12000,
|
50 |
+
n_sweeps=lattice.n_nodes)
|
51 |
+
|
52 |
+
key = jax.random.PRNGKey(0)
|
53 |
+
key, subkey = jax.random.split(key, 2)
|
54 |
+
vstate = nk.vqs.MCState(sampler=sampler,
|
55 |
+
apply_fun=partial(wf.__call__, coups=h),
|
56 |
+
sampler_seed=subkey,
|
57 |
+
n_samples=12000,
|
58 |
+
n_discard_per_chain=0,
|
59 |
+
variables=wf.params,
|
60 |
+
chunk_size=12000)
|
61 |
+
|
62 |
+
from huggingface_hub import hf_hub_download
|
63 |
+
path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision)
|
64 |
+
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
|
65 |
+
samples = jnp.array(samples, dtype='int8')
|
66 |
+
vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
|
67 |
+
|
68 |
+
import time
|
69 |
+
# Sample the model
|
70 |
+
for _ in range(10):
|
71 |
+
start = time.time()
|
72 |
+
E = vstate.expect(hamiltonian)
|
73 |
+
vstate.sample()
|
74 |
+
|
75 |
+
print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time()-start)
|
76 |
|
77 |
```
|
78 |
|