File size: 1,495 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
from omegaconf import OmegaConf
import torch
import tempfile
from safetensors.torch import load_file
import requests
import yaml

def get_ckpt(path, key="state_dict"):
    is_url = path.startswith("http://") or path.startswith("https://")
    suffix = os.path.splitext(path)[-1]

    if is_url:
        print(f"Loading checkpoint from URL: {path}")
        with tempfile.NamedTemporaryFile(suffix=suffix) as tmp_file:
            response = requests.get(path)
            response.raise_for_status()
            tmp_file.write(response.content)
            tmp_file.flush()
            ckpt_path = tmp_file.name

            if suffix == ".safetensors":
                checkpoint = load_file(ckpt_path)
            else:
                checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    else:
        print(f"Loading checkpoint from local path: {path}")
        if suffix == ".safetensors":
            checkpoint = load_file(path)
        else:
            checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    if key is not None and key in checkpoint:
        checkpoint = checkpoint[key]

    return checkpoint


def get_yaml_config(path):
    if path.startswith("http://") or path.startswith("https://"):
        response = requests.get(path)
        response.raise_for_status()
        config = OmegaConf.create(response.text)
    else:
        with open(path, 'r') as f:
            config = OmegaConf.load(f)
    return config