File size: 2,152 Bytes
c763d83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os

import jax
from flax.training.train_state import TrainState
import optax
import orbax.checkpoint as ocp

from jaxgmg.procgen import maze_generation
from jaxgmg.environments import cheese_in_the_corner
from jaxgmg.baselines import networks


# INPUTS

SEED = 42

# which checkpoint to restore: the run folder (resolved relative to the
# current working directory further down) and the saved step inside it
CHECKPOINT_FOLDER: str = "2ypu90e7"
CHECKPOINT_NUMBER: int = 7168

# environment-specific stuff: the environment itself, plus a level generator
# whose only job in this script is to produce an example level for building
# the template network
ENV = cheese_in_the_corner.Env(
    obs_level_of_detail=0,
    img_level_of_detail=1,
    penalize_time=False,
    terminate_after_cheese_and_corner=False,
)
_blocks_maze_generator = maze_generation.get_generator_class_from_name("blocks")()
ARBITRARY_LEVEL_GENERATOR = cheese_in_the_corner.LevelGenerator(
    height=15,
    width=15,
    maze_generator=_blocks_maze_generator,
    corner_size=1,
)

# actor-critic network architecture (passed to networks.Impala below);
# must match the architecture the checkpoint was trained with
NET_CNN_TYPE: str = "large"
NET_RNN_TYPE: str = "ff"
NET_WIDTH: int = 256



# PREPARE TO LOAD THE CHECKPOINT

# split the root key into a setup stream and an evaluation stream
# (rng_eval is not used further in this script)
rng = jax.random.PRNGKey(seed=SEED)
rng_setup, rng_eval = jax.random.split(rng)

# configure actor critic network architecture
net = networks.Impala(
    num_actions=ENV.num_actions,
    cnn_type=NET_CNN_TYPE,
    rnn_type=NET_RNN_TYPE,
    width=NET_WIDTH,
)

# initialise template network: the checkpoint restore below uses
# net_init_params as the target pytree, so we initialise the network on the
# observation type of an arbitrary example level to get that structure
# (net_init_state is not used further in this script)
rng_model_init, rng_setup = jax.random.split(rng_setup)
rng_example_level, rng_setup = jax.random.split(rng_setup)
example_level = ARBITRARY_LEVEL_GENERATOR.sample(rng_example_level)
net_init_params, net_init_state = net.init_params_and_state(
    rng=rng_model_init,
    obs_type=ENV.obs_type(level=example_level),
)

# initialise the checkpointer (this script only restores, never saves)
checkpoint_dir = os.path.abspath(CHECKPOINT_FOLDER)
# fail fast on a mistyped path: orbax's CheckpointManagerOptions default to
# create=True, which would silently create an empty directory here and only
# fail later with a less obvious error about the missing step
if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError(f"checkpoint directory not found: {checkpoint_dir}")
checkpoint_manager = ocp.CheckpointManager(
    directory=checkpoint_dir,
    options=ocp.CheckpointManagerOptions(
        max_to_keep=None,
        save_interval_steps=1,
    ),
)

try:
    # make sure the requested step was actually saved before restoring it
    available_steps = checkpoint_manager.all_steps()
    if CHECKPOINT_NUMBER not in available_steps:
        raise ValueError(
            f"checkpoint step {CHECKPOINT_NUMBER} not found in {checkpoint_dir}; "
            f"available steps: {sorted(available_steps)}"
        )
    # load the checkpoint to default device (CPU or GPU); the restore args are
    # derived from the template params, which define the expected pytree
    net_params = checkpoint_manager.restore(
        CHECKPOINT_NUMBER,
        args=ocp.args.PyTreeRestore(
            net_init_params,
            restore_args=ocp.checkpoint_utils.construct_restore_args(net_init_params),
        ),
    )
finally:
    # release the manager's internal resources (waits for any pending work)
    checkpoint_manager.close()

# wrap the restored parameters in a flax TrainState; the optimiser is SGD
# with a learning rate of 0, so applying gradients would leave the params
# unchanged (presumably it exists only to satisfy TrainState's interface)
frozen_tx = optax.sgd(learning_rate=0)
train_state = TrainState.create(
    apply_fn=net.apply,
    params=net_params,
    tx=frozen_tx,
)

print(train_state)