File size: 1,495 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import io
import os
import tempfile
from collections.abc import Mapping
from urllib.parse import urlparse

import requests
import torch
import yaml
from omegaconf import OmegaConf
from safetensors.torch import load as load_safetensors
from safetensors.torch import load_file
def get_ckpt(path, key="state_dict", *, timeout=60):
    """Load a checkpoint from a local path or an HTTP(S) URL.

    Files whose extension is ``.safetensors`` (case-insensitive) are read with
    safetensors; anything else is read with ``torch.load`` onto the CPU.

    Args:
        path: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``. For URLs the extension is taken from the URL's path
            component, so a query string such as ``?download=true`` does not
            hide a ``.safetensors`` extension.
        key: If not None and the loaded checkpoint is a mapping that contains
            this key, ``checkpoint[key]`` is returned instead of the whole
            checkpoint. Non-mapping checkpoints (e.g. a bare tensor) are
            returned unchanged.
        timeout: Seconds passed to ``requests.get`` for URL downloads, so a
            stalled server cannot block the caller indefinitely.

    Returns:
        The loaded checkpoint object, or its ``key`` entry as described above.

    Raises:
        requests.HTTPError: If the URL download returns an error status.

    Warning:
        Non-safetensors files are loaded with ``weights_only=False``, which
        unpickles arbitrary objects and can execute code. Only load such
        checkpoints from trusted sources, especially when ``path`` is a URL.
    """
    is_url = path.startswith(("http://", "https://"))
    # For URLs, strip query string / fragment before looking at the extension.
    name = urlparse(path).path if is_url else path
    is_safetensors = os.path.splitext(name)[-1].lower() == ".safetensors"

    if is_url:
        print(f"Loading checkpoint from URL: {path}")
        response = requests.get(path, timeout=timeout)
        response.raise_for_status()
        # Load straight from the downloaded bytes (already fully in memory).
        # This avoids reopening a still-open NamedTemporaryFile by name, which
        # fails on Windows.
        if is_safetensors:
            checkpoint = load_safetensors(response.content)
        else:
            # SECURITY(review): weights_only=False unpickles data fetched from
            # the network; a malicious checkpoint can execute arbitrary code.
            checkpoint = torch.load(
                io.BytesIO(response.content), map_location="cpu", weights_only=False
            )
    else:
        print(f"Loading checkpoint from local path: {path}")
        if is_safetensors:
            checkpoint = load_file(path)
        else:
            checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    # Only mappings can be looked up by key. `key in tensor` raises, and
    # `key in list` followed by list[key] would raise as well.
    if key is not None and isinstance(checkpoint, Mapping) and key in checkpoint:
        checkpoint = checkpoint[key]
    return checkpoint
def get_yaml_config(path, *, timeout=60):
    """Load a YAML document into an OmegaConf config from a path or URL.

    Args:
        path: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds passed to ``requests.get`` for URL fetches, so a
            stalled server cannot block the caller indefinitely.

    Returns:
        The config built by ``OmegaConf.create`` (URL) or ``OmegaConf.load``
        (local file).

    Raises:
        requests.HTTPError: If the URL fetch returns an error status.
    """
    if path.startswith(("http://", "https://")):
        response = requests.get(path, timeout=timeout)
        response.raise_for_status()
        # Decode explicitly as UTF-8. `response.text` falls back to
        # ISO-8859-1 for text/* responses that omit a charset, which would
        # garble non-ASCII characters in the YAML.
        return OmegaConf.create(response.content.decode("utf-8"))
    # Pin UTF-8 rather than relying on the locale-dependent default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return OmegaConf.load(f)
|