# jaxgmg_checkpoints / load_checkpoint.py
# Uploaded by davidquarel via huggingface_hub ("Upload folder using huggingface_hub"),
# commit c763d83 (verified).
import os
import jax
from flax.training.train_state import TrainState
import optax
import orbax.checkpoint as ocp
from jaxgmg.procgen import maze_generation
from jaxgmg.environments import cheese_in_the_corner
from jaxgmg.baselines import networks
# INPUTS

# Root PRNG seed; every key used by this script is split from it.
SEED: int = 42

# Checkpoint to load: the folder holding the checkpoints and the step
# number to restore from it.
CHECKPOINT_FOLDER: str = "2ypu90e7"
CHECKPOINT_NUMBER: int = 7168

# Environment-specific settings. The network's action count and observation
# type are derived from these, so they presumably need to match the
# configuration the checkpoint was trained with (TODO confirm).
ENV = cheese_in_the_corner.Env(
    obs_level_of_detail=0,
    img_level_of_detail=1,
    penalize_time=False,
    terminate_after_cheese_and_corner=False,
)
# Only used to sample one example level, which gives the observation type
# needed to initialise the template network; the level itself is arbitrary.
ARBITRARY_LEVEL_GENERATOR = cheese_in_the_corner.LevelGenerator(
    height=15,
    width=15,
    maze_generator=maze_generation.get_generator_class_from_name("blocks")(),
    corner_size=1,
)

# Actor-critic policy config. These define the template parameter pytree
# that the checkpoint is restored into, so they presumably must match the
# architecture that was saved.
NET_CNN_TYPE: str = "large"
NET_RNN_TYPE: str = "ff"
NET_WIDTH: int = 256
# PREPARE TO LOAD THE CHECKPOINT

# Split the root key. `rng_setup` is consumed below for network and level
# initialisation; `rng_eval` is not used by this script (presumably left for
# evaluation code that builds on it).
rng = jax.random.PRNGKey(seed=SEED)
rng_setup, rng_eval = jax.random.split(rng)

# configure actor critic network architecture
net = networks.Impala(
    num_actions=ENV.num_actions,
    cnn_type=NET_CNN_TYPE,
    rnn_type=NET_RNN_TYPE,
    width=NET_WIDTH,
)

# Initialise a template network. In this script its parameters only serve
# as the target structure that the checkpoint is restored into. An example
# level is sampled because `ENV.obs_type` needs a concrete level to describe
# the network's input. `net_init_state` is not used below (presumably empty
# for the feed-forward "ff" variant; TODO confirm).
rng_model_init, rng_setup = jax.random.split(rng_setup)
rng_example_level, rng_setup = jax.random.split(rng_setup)
example_level = ARBITRARY_LEVEL_GENERATOR.sample(rng_example_level)
net_init_params, net_init_state = net.init_params_and_state(
    rng=rng_model_init,
    obs_type=ENV.obs_type(level=example_level),
)
# LOAD THE CHECKPOINT

# Resolve the checkpoint folder to an absolute directory. First look
# relative to the current working directory (the original behaviour). If
# that folder does not exist, fall back to the directory containing this
# script, so launching it from elsewhere (e.g. `python path/to/script.py`)
# still works. `__file__` is absent in interactive sessions, hence `.get`.
_checkpoint_dir = os.path.abspath(CHECKPOINT_FOLDER)
if not os.path.isdir(_checkpoint_dir) and globals().get("__file__"):
    _checkpoint_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        CHECKPOINT_FOLDER,
    )
if not os.path.isdir(_checkpoint_dir):
    raise FileNotFoundError(
        f"checkpoint folder {CHECKPOINT_FOLDER!r} not found in the current "
        f"working directory or next to this script"
    )

# Initialise the checkpointer. This script only reads checkpoints, so
# `create=False` stops orbax from silently creating an empty directory.
checkpoint_manager = ocp.CheckpointManager(
    directory=_checkpoint_dir,
    options=ocp.CheckpointManagerOptions(
        max_to_keep=None,
        save_interval_steps=1,
        create=False,
    ),
)

# Fail early with a readable message rather than an opaque orbax error.
if CHECKPOINT_NUMBER not in checkpoint_manager.all_steps():
    raise ValueError(
        f"checkpoint step {CHECKPOINT_NUMBER} not found in {_checkpoint_dir}; "
        f"available steps: {sorted(checkpoint_manager.all_steps())}"
    )

# Load the checkpoint to the default device (CPU or GPU). The template
# `net_init_params` is the restore target, and the per-leaf restore args are
# constructed from it. This assumes the checkpoint holds a params pytree
# with the same structure (TODO confirm against the saving code).
net_params = checkpoint_manager.restore(
    CHECKPOINT_NUMBER,
    args=ocp.args.PyTreeRestore(
        net_init_params,
        restore_args=ocp.checkpoint_utils.construct_restore_args(net_init_params),
    ),
)

# Wrap the restored parameters in a TrainState. SGD with a zero learning
# rate is a placeholder optimiser: this script loads weights and never
# trains.
train_state = TrainState.create(
    apply_fn=net.apply,
    params=net_params,
    tx=optax.sgd(learning_rate=0),
)
print(train_state)