rrende commited on
Commit
13d207b
·
verified ·
1 Parent(s): f3b907f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -0
README.md CHANGED
@@ -18,6 +18,61 @@ The computation has been distributed over 4 A100-64GB GPUs for few hours.
18
  Use the code below to get started with the model. In particular, we sample the model for a fixed value of the externa field \\(h\\) using NetKet.
19
 
20
  ```python
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  ```
23
 
 
18
  Use the code below to get started with the model. In particular, we sample the model for a fixed value of the externa field \\(h\\) using NetKet.
19
 
20
  ```python
21
+ from functools import partial
22
+ import numpy as np
23
+
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import netket as nk
27
+
28
+ import flax
29
+ from flax.training import checkpoints
30
+
31
+ flax.config.update('flax_use_orbax_checkpointing', False)
32
+
33
+ lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True)
34
+
35
+ revision = "main"
36
+ h = 1.0 #* fix the value of the external field
37
+
38
+ from transformers import FlaxAutoModel
39
+ wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True)
40
+ N_params = nk.jax.tree_size(wf.params)
41
+ print('Number of parameters = ', N_params, flush=True)
42
+
43
+ hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
44
+ hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0)
45
+
46
+ action = nk.sampler.rules.LocalRule()
47
+ sampler = nk.sampler.MetropolisSampler(hilbert=hilbert,
48
+ rule=action,
49
+ n_chains=12000,
50
+ n_sweeps=lattice.n_nodes)
51
+
52
+ key = jax.random.PRNGKey(0)
53
+ key, subkey = jax.random.split(key, 2)
54
+ vstate = nk.vqs.MCState(sampler=sampler,
55
+ apply_fun=partial(wf.__call__, coups=h),
56
+ sampler_seed=subkey,
57
+ n_samples=12000,
58
+ n_discard_per_chain=0,
59
+ variables=wf.params,
60
+ chunk_size=12000)
61
+
62
+ from huggingface_hub import hf_hub_download
63
+ path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision)
64
+ samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
65
+ samples = jnp.array(samples, dtype='int8')
66
+ vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
67
+
68
+ import time
69
+ # Sample the model
70
+ for _ in range(10):
71
+ start = time.time()
72
+ E = vstate.expect(hamiltonian)
73
+ vstate.sample()
74
+
75
+ print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time()-start)
76
 
77
  ```
78