Last final stable checkpoint

#1
by umarbutler - opened

Is it correct to assume that the final LR-stable (non-decay) checkpoint is ep0-ba49552-rank0.pt?

Additionally, is this the right way of loading that checkpoint?

import torch
from huggingface_hub import hf_hub_download

BASE_MODEL_NAME = "answerdotai/ModernBERT-large"
BASE_MODEL_CHECKPOINT = (
    "answerdotai/ModernBERT-large-training-checkpoints",
    "context-extension/ep0-ba49552-rank0.pt",
)
base_model_checkpoint_path = hf_hub_download(
    repo_id=BASE_MODEL_CHECKPOINT[0],
    filename=BASE_MODEL_CHECKPOINT[1],
    repo_type="model",
)
# Composer checkpoints keep the model weights under ["state"]["model"].
# weights_only=False is needed on newer PyTorch versions, where it defaults
# to True and refuses to unpickle the non-tensor objects in the checkpoint.
base_model_checkpoint = torch.load(
    base_model_checkpoint_path, map_location="cpu", weights_only=False
)["state"]
model_state = base_model_checkpoint["model"]
# Strip the Composer wrapper prefixes so the keys line up with the
# Hugging Face ModernBERT module names.
model_state_key_map = {
    k: k.replace(".bert", "")
    .replace(".encoder", "")
    .replace("model.head", "head")
    .replace("model.decoder", "decoder")
    for k in model_state
}
model_state = {model_state_key_map[k]: v for k, v in model_state.items()}
# `base_model` is assumed to be an already-instantiated ModernBERT model,
# e.g. AutoModelForMaskedLM.from_pretrained(BASE_MODEL_NAME).
base_model.load_state_dict(model_state)
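
For what it's worth, here is a minimal, self-contained sketch of what the chained `.replace()` calls do, using hypothetical key names in the Composer style (the actual keys in the checkpoint may differ):

```python
# Sketch of the key remapping above; the example keys are illustrative,
# not taken from the real checkpoint.
def remap_key(k: str) -> str:
    # Strip the Composer wrapper segments and re-root the head/decoder
    # so the key matches the HF ModernBERT module name.
    return (
        k.replace(".bert", "")
        .replace(".encoder", "")
        .replace("model.head", "head")
        .replace("model.decoder", "decoder")
    )

examples = {
    "model.bert.encoder.layers.0.attn.Wqkv.weight": "model.layers.0.attn.Wqkv.weight",
    "model.head.dense.weight": "head.dense.weight",
    "model.decoder.bias": "decoder.bias",
}

for src, expected in examples.items():
    assert remap_key(src) == expected
```

Before calling `load_state_dict`, it can also be worth diffing `model_state.keys()` against `base_model.state_dict().keys()`; the default `strict=True` will raise on any mismatch, which is a useful sanity check here.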
