import io
import os
import tempfile
from collections.abc import Mapping
from urllib.parse import urlparse

import requests
import torch
import yaml
from omegaconf import OmegaConf
from safetensors.torch import load as load_safetensors
from safetensors.torch import load_file

_URL_PREFIXES = ("http://", "https://")


def _is_url(path):
    """Return True if *path* starts with an http:// or https:// scheme."""
    return path.startswith(_URL_PREFIXES)


def _download(url, timeout=None):
    """Fetch *url* with a GET request and return the response body as bytes.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content


def get_ckpt(path, key="state_dict", *, timeout=None):
    """Load a checkpoint from a local path or an HTTP(S) URL.

    Files whose extension is ``.safetensors`` are read with safetensors;
    anything else is read with ``torch.load`` mapped to CPU.

    Args:
        path: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        key: If not None and the loaded checkpoint is a mapping containing
            this key, ``checkpoint[key]`` is returned instead of the whole
            object.
        timeout: Optional ``requests`` timeout in seconds for URL downloads.
            ``None`` (the default) waits indefinitely.

    Returns:
        The loaded checkpoint object, or its ``key`` entry.

    Raises:
        requests.HTTPError: if downloading a URL returns an error status.
    """
    if _is_url(path):
        print(f"Loading checkpoint from URL: {path}")
        # Take the extension from the URL's path component so that query
        # strings such as "?download=true" do not hide ".safetensors".
        suffix = os.path.splitext(urlparse(path).path)[-1]
        data = _download(path, timeout=timeout)
        # The body is already fully in memory, so load it directly instead of
        # round-tripping through a NamedTemporaryFile (which cannot be
        # reopened by name while still open on Windows).
        if suffix == ".safetensors":
            checkpoint = load_safetensors(data)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects, so a
            # checkpoint fetched from an untrusted URL can execute code.
            checkpoint = torch.load(
                io.BytesIO(data), map_location="cpu", weights_only=False
            )
    else:
        print(f"Loading checkpoint from local path: {path}")
        suffix = os.path.splitext(path)[-1]
        if suffix == ".safetensors":
            checkpoint = load_file(path)
        else:
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load files from trusted sources.
            checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    # torch.load may return a non-mapping object (e.g. a module or tensor),
    # where "key in checkpoint" would raise or test something unrelated.
    if key is not None and isinstance(checkpoint, Mapping) and key in checkpoint:
        checkpoint = checkpoint[key]
    return checkpoint


def get_yaml_config(path, *, timeout=None):
    """Load a YAML config into an OmegaConf object from a path or HTTP(S) URL.

    Args:
        path: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Optional ``requests`` timeout in seconds for URL downloads.
            ``None`` (the default) waits indefinitely.

    Returns:
        The config created by OmegaConf from the YAML text.

    Raises:
        requests.HTTPError: if downloading a URL returns an error status.
    """
    if _is_url(path):
        # Decode explicitly as UTF-8: requests' response.text falls back to
        # ISO-8859-1 for text/* responses without a charset, which would
        # garble non-ASCII YAML.
        text = _download(path, timeout=timeout).decode("utf-8")
        return OmegaConf.create(text)
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        return OmegaConf.load(f)