Update README.md
README.md CHANGED
@@ -18,7 +18,7 @@ from flax.training import checkpoints
 flax.config.update('flax_use_orbax_checkpointing', False)
 # Load the model from HuggingFace
 from transformers import FlaxAutoModel
-wf = FlaxAutoModel.from_pretrained("
+wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True)
 N_params = nk.jax.tree_size(wf.params)
 print('Number of parameters = ', N_params, flush=True)
 lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)
@@ -43,7 +43,7 @@ vstate = nk.vqs.MCState(sampler=sampler,
 chunk_size=16384)
 # Overwrite samples with already thermalized ones
 from huggingface_hub import hf_hub_download
-path = hf_hub_download(repo_id="
+path = hf_hub_download(repo_id="nqs-models/j1j2_square_10x10", filename="spins")
 samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
 samples = jnp.array(samples, dtype='int8')
 vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
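For context, a minimal usage sketch of where this snippet leads (not part of the commit): once the sampler state has been seeded with the pre-thermalized configurations, the variational energy can be estimated with a single `expect` call. The Hilbert space and Hamiltonian below are assumptions for illustration only; the repository name suggests a J1-J2 Heisenberg model on the 10x10 square lattice (the coupling ratio J2/J1 = 0.5 and the absence of a `total_sz` constraint are guesses), and the elided README lines presumably define the corresponding objects already.

```python
import netket as nk

# Assumed for illustration: spin-1/2 Hilbert space on the lattice built above.
# The README's elided lines likely define the equivalent object used by the sampler.
hi = nk.hilbert.Spin(s=1 / 2, N=lattice.n_nodes)

# Assumed J1-J2 Heisenberg Hamiltonian: the two couplings map onto the first- and
# second-neighbour edges produced by max_neighbor_order=2 (ratio 0.5 is a guess).
H = nk.operator.Heisenberg(hilbert=hi, graph=lattice, J=[1.0, 0.5])

# With the sampler state seeded by the thermalized spins, this returns a
# Monte Carlo estimate of <H>, evaluated in chunks of 16384 configurations.
E = vstate.expect(H)
print('Energy per site =', E.mean.real / lattice.n_nodes)
```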