# 779abe8 — blob hash from file-viewer chrome (original file size: 208 bytes).
# The surrounding banner and gutter line numbers were extraction artifacts,
# not part of the source.
"""Unwrap a DeepSpeed checkpoint's ``module.`` key prefix.

Loads ``mp_rank_00_model_states.pt`` onto the CPU and rewrites the parameter
names under the checkpoint's ``"module"`` entry so the weights can be loaded
into the bare (unwrapped) model.
"""
import torch

# NOTE(review): torch.load unpickles arbitrary objects — only run this on
# trusted checkpoints. weights_only=True is not used here because DeepSpeed
# state files typically contain non-tensor metadata; confirm before hardening.
state_dict = torch.load(
    "cruise_logs/zephyr_freeze_ift/mp_rank_00_model_states.pt", map_location="cpu"
)

# Strip only the *leading* "module." wrapper prefix. A plain
# str.replace("module.", "") would also mangle keys that merely contain
# "module." somewhere in the middle (e.g. "encoder.module.weight").
_PREFIX = "module."
state_dict = {
    (k[len(_PREFIX):] if k.startswith(_PREFIX) else k): v
    for k, v in state_dict["module"].items()
}