import os

import jax
from flax.training.train_state import TrainState
import optax
import orbax.checkpoint as ocp

from jaxgmg.procgen import maze_generation
from jaxgmg.environments import cheese_in_the_corner
from jaxgmg.baselines import networks


# Seed used to build the root JAX PRNG key for this script.
SEED = 42

# Which checkpoint to load: the folder handed to the orbax CheckpointManager
# (made absolute with os.path.abspath) and the step number passed to its
# restore() call.
CHECKPOINT_FOLDER: str = "2ypu90e7"
CHECKPOINT_NUMBER: int = 7168


# Environment instance. In this script it supplies `num_actions` for the
# network and `obs_type(level=...)` for parameter initialisation.
# NOTE(review): the meaning of the level-of-detail values is defined by
# jaxgmg's cheese_in_the_corner.Env, not here -- confirm these settings match
# the ones the checkpoint was trained with.
ENV = cheese_in_the_corner.Env(
    obs_level_of_detail=0,
    img_level_of_detail=1,
    penalize_time=False,
    terminate_after_cheese_and_corner=False,
)

# Level generator for 15x15 levels. In the visible script it is used only to
# sample one example level, from which the observation type for network
# initialisation is obtained.
ARBITRARY_LEVEL_GENERATOR = cheese_in_the_corner.LevelGenerator(
    height=15,
    width=15,
    # Look the maze generator class up by name, then instantiate it with no
    # arguments.
    maze_generator=maze_generation.get_generator_class_from_name("blocks")(),
    corner_size=1,
)


# Architecture settings forwarded to networks.Impala.
# NOTE(review): these presumably have to match the architecture the checkpoint
# was trained with, because the freshly initialised parameters are used as the
# restore template -- confirm against the training run.
NET_CNN_TYPE: str = "large"
NET_RNN_TYPE: str = "ff"
NET_WIDTH: int = 256


# Root PRNG key derived from SEED, split into two independent keys named for
# setup and evaluation. `rng_eval` is not consumed anywhere in the visible
# part of the script.
rng = jax.random.PRNGKey(seed=SEED)
rng_setup, rng_eval = jax.random.split(rng)


# Build the Impala network object from the environment's action count and the
# NET_* architecture constants. Its parameters are created in a separate
# init_params_and_state() call.
net = networks.Impala(
    num_actions=ENV.num_actions,
    cnn_type=NET_CNN_TYPE,
    rnn_type=NET_RNN_TYPE,
    width=NET_WIDTH,
)


# Split two single-purpose keys off `rng_setup`, rebinding `rng_setup` to the
# remainder each time so no key is used twice.
rng_model_init, rng_setup = jax.random.split(rng_setup)
rng_example_level, rng_setup = jax.random.split(rng_setup)

# Sample one level so the environment can report the observation type the
# network is initialised against.
example_level = ARBITRARY_LEVEL_GENERATOR.sample(rng_example_level)

# Freshly initialised parameters and network state. The parameters are later
# passed to the checkpoint restore call; `net_init_state` is not used in the
# visible part of the script.
net_init_params, net_init_state = net.init_params_and_state(
    rng=rng_model_init,
    obs_type=ENV.obs_type(level=example_level),
)


# Checkpoint manager rooted at CHECKPOINT_FOLDER, made absolute with
# os.path.abspath (a relative folder is resolved against the current working
# directory).
# NOTE(review): max_to_keep and save_interval_steps look like save-side
# options; the visible script only restores -- confirm they are needed here.
checkpoint_manager = ocp.CheckpointManager(
    directory=os.path.abspath(CHECKPOINT_FOLDER),
    options=ocp.CheckpointManagerOptions(
        max_to_keep=None,
        save_interval_steps=1,
    ),
)


# Restore step CHECKPOINT_NUMBER. The freshly initialised parameters are
# passed both as the item given to PyTreeRestore and as the tree from which
# the restore args are constructed.
net_params = checkpoint_manager.restore(
    CHECKPOINT_NUMBER,
    args=ocp.args.PyTreeRestore(
        net_init_params,
        restore_args=ocp.checkpoint_utils.construct_restore_args(net_init_params),
    ),
)


# Wrap the restored parameters and the network's apply function in a flax
# TrainState.
# NOTE(review): the zero-learning-rate SGD looks like a placeholder optimiser,
# since nothing in the visible script trains -- confirm before reusing this
# state for training.
train_state = TrainState.create(
    apply_fn=net.apply,
    params=net_params,
    tx=optax.sgd(learning_rate=0),
)

# Show the assembled state on stdout.
print(train_state)