Last final stable checkpoint
#1 opened by umarbutler
Is it correct to assume that the final LR-stable (non-decay) checkpoint is `ep0-ba49552-rank0.pt`?

Additionally, is this the right way of loading that checkpoint?
```python
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForMaskedLM

BASE_MODEL_NAME = "answerdotai/ModernBERT-large"
BASE_MODEL_CHECKPOINT = (
    "answerdotai/ModernBERT-large-training-checkpoints",
    "context-extension/ep0-ba49552-rank0.pt",
)

# Assuming the MLM variant is the right target for the checkpoint's head/decoder weights.
base_model = AutoModelForMaskedLM.from_pretrained(BASE_MODEL_NAME)

# Download the raw Composer checkpoint from the Hub.
base_model_checkpoint_path = hf_hub_download(
    repo_id=BASE_MODEL_CHECKPOINT[0],
    filename=BASE_MODEL_CHECKPOINT[1],
    repo_type="model",
)
base_model_checkpoint = torch.load(base_model_checkpoint_path, map_location="cpu")["state"]
model_state = base_model_checkpoint["model"]

# Rename Composer-style keys to the Hugging Face layout.
model_state_key_map = {
    k: k.replace(".bert", "")
    .replace(".encoder", "")
    .replace("model.head", "head")
    .replace("model.decoder", "decoder")
    for k in model_state
}
model_state = {model_state_key_map[k]: v for k, v in model_state.items()}
base_model.load_state_dict(model_state)
```
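For reference, the key remapping above is a pure string transformation. A minimal sketch of what it does, applied to a couple of hypothetical Composer-style key names (the example keys are assumptions for illustration, not pulled from the actual checkpoint):

```python
def remap_composer_key(k: str) -> str:
    # Same chained replacements as in the loading snippet above.
    return (
        k.replace(".bert", "")
        .replace(".encoder", "")
        .replace("model.head", "head")
        .replace("model.decoder", "decoder")
    )

# Hypothetical example keys, shown only to illustrate the mapping:
print(remap_composer_key("model.bert.encoder.layers.0.attn.Wqkv.weight"))
# → model.layers.0.attn.Wqkv.weight
print(remap_composer_key("model.head.dense.weight"))
# → head.dense.weight
```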