rrende commited on
Commit
a01c7fd
·
verified ·
1 Parent(s): e5f9751

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -18,7 +18,7 @@ from flax.training import checkpoints
18
  flax.config.update('flax_use_orbax_checkpointing', False)
19
  # Load the model from HuggingFace
20
  from transformers import FlaxAutoModel
21
- wf = FlaxAutoModel.from_pretrained("rrende/j1j2_square_10x10", trust_remote_code=True)
22
  N_params = nk.jax.tree_size(wf.params)
23
  print('Number of parameters = ', N_params, flush=True)
24
  lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)
@@ -43,7 +43,7 @@ vstate = nk.vqs.MCState(sampler=sampler,
43
  chunk_size=16384)
44
  # Overwrite samples with already thermalized ones
45
  from huggingface_hub import hf_hub_download
46
- path = hf_hub_download(repo_id="rrende/j1j2_square_10x10", filename="spins")
47
  samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
48
  samples = jnp.array(samples, dtype='int8')
49
  vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
 
18
  flax.config.update('flax_use_orbax_checkpointing', False)
19
  # Load the model from HuggingFace
20
  from transformers import FlaxAutoModel
21
+ wf = FlaxAutoModel.from_pretrained("nqs-models/j1j2_square_10x10", trust_remote_code=True)
22
  N_params = nk.jax.tree_size(wf.params)
23
  print('Number of parameters = ', N_params, flush=True)
24
  lattice = nk.graph.Hypercube(length=10, n_dim=2, pbc=True, max_neighbor_order=2)
 
43
  chunk_size=16384)
44
  # Overwrite samples with already thermalized ones
45
  from huggingface_hub import hf_hub_download
46
+ path = hf_hub_download(repo_id="nqs-models/j1j2_square_10x10", filename="spins")
47
  samples = checkpoints.restore_checkpoint(ckpt_dir=path, prefix="spins", target=None)
48
  samples = jnp.array(samples, dtype='int8')
49
  vstate.sampler_state = vstate.sampler_state.replace(σ = samples)