"""Load a trained actor-critic checkpoint and wrap it in a flax TrainState.

Rebuilds the Impala network and samples one example cheese-in-the-corner
level so that freshly initialised parameters can serve as a template tree
for orbax. It then restores the parameters saved at step
``CHECKPOINT_NUMBER`` in ``CHECKPOINT_FOLDER`` and prints the resulting
``TrainState``.

Because the restore uses the template's tree structure, the environment
and network configuration below is assumed to match the configuration used
when the checkpoint was saved (NOTE(review): confirm against the training
run).
"""

import os

import jax
from flax.training.train_state import TrainState
import optax
import orbax.checkpoint as ocp

from jaxgmg.procgen import maze_generation
from jaxgmg.environments import cheese_in_the_corner
from jaxgmg.baselines import networks


# INPUTS

SEED = 42

# checkpoint to load
CHECKPOINT_FOLDER: str = "2ypu90e7"
CHECKPOINT_NUMBER: int = 7168

# environment-specific stuff
ENV = cheese_in_the_corner.Env(
    obs_level_of_detail=0,
    img_level_of_detail=1,
    penalize_time=False,
    terminate_after_cheese_and_corner=False,
)
# only used below to sample one example level, whose observation type is
# needed to initialise the template network
ARBITRARY_LEVEL_GENERATOR = cheese_in_the_corner.LevelGenerator(
    height=15,
    width=15,
    maze_generator=maze_generation.get_generator_class_from_name("blocks")(),
    corner_size=1,
)

# actor critic policy config
NET_CNN_TYPE: str = "large"
NET_RNN_TYPE: str = "ff"
NET_WIDTH: int = 256


# PREPARE TO LOAD THE CHECKPOINT

rng = jax.random.PRNGKey(seed=SEED)
# rng_eval is unused here; the split is kept so the rng_setup stream (and
# therefore the sampled example level) is unchanged
rng_setup, rng_eval = jax.random.split(rng)

# configure actor critic network architecture
net = networks.Impala(
    num_actions=ENV.num_actions,
    cnn_type=NET_CNN_TYPE,
    rnn_type=NET_RNN_TYPE,
    width=NET_WIDTH,
)

# initialise template network: these parameters only provide the target
# tree structure for the restore below and are replaced by the restored ones
rng_model_init, rng_setup = jax.random.split(rng_setup)
rng_example_level, rng_setup = jax.random.split(rng_setup)
example_level = ARBITRARY_LEVEL_GENERATOR.sample(rng_example_level)
net_init_params, net_init_state = net.init_params_and_state(
    rng=rng_model_init,
    obs_type=ENV.obs_type(level=example_level),
)

# fail early with a clear message on a mistyped folder; with default options
# orbax's CheckpointManager may otherwise create an empty directory there
checkpoint_dir = os.path.abspath(CHECKPOINT_FOLDER)
if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError(f"checkpoint folder not found: {checkpoint_dir}")

# initialise the checkpointer
checkpoint_manager = ocp.CheckpointManager(
    directory=checkpoint_dir,
    options=ocp.CheckpointManagerOptions(
        max_to_keep=None,
        save_interval_steps=1,
    ),
)
try:
    # give a readable error instead of an opaque orbax failure when the
    # requested step was never saved
    available_steps = checkpoint_manager.all_steps()
    if CHECKPOINT_NUMBER not in available_steps:
        raise ValueError(
            f"no checkpoint {CHECKPOINT_NUMBER} in {checkpoint_dir}; "
            f"available steps: {sorted(available_steps)}"
        )
    # load the checkpoint to default device (CPU or GPU)
    net_params = checkpoint_manager.restore(
        CHECKPOINT_NUMBER,
        args=ocp.args.PyTreeRestore(
            net_init_params,
            restore_args=ocp.checkpoint_utils.construct_restore_args(
                net_init_params,
            ),
        ),
    )
finally:
    # release any resources held by the manager once loading is done
    checkpoint_manager.close()

# initialise train state
# TrainState.create requires an optimizer; SGD with learning rate 0 is a
# placeholder under which (finite-gradient) updates leave params unchanged
train_state = TrainState.create(
    apply_fn=net.apply,
    params=net_params,
    tx=optax.sgd(learning_rate=0),
)

print(train_state)