VTBench / src /utils.py
huaweilin's picture
update
14ce5a9
import os
import tempfile
from collections.abc import Mapping
from urllib.parse import urlparse

import requests
import torch
import yaml
from omegaconf import OmegaConf
from safetensors.torch import load_file

def _ckpt_suffix(path, is_url):
    """Return the lower-cased file extension of a checkpoint path or URL.

    For URLs only the path component is inspected, so a query string or
    fragment (e.g. ``.../model.safetensors?download=true``) does not leak
    into the extension and defeat the ``.safetensors`` check.
    """
    if is_url:
        path = urlparse(path).path
    return os.path.splitext(path)[-1].lower()


def _download_file(url, dest, timeout):
    """Stream the resource at ``url`` into the local file ``dest``."""
    # stream=True avoids holding the whole (possibly multi-GB) body in memory.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(dest, "wb") as out:
            for chunk in response.iter_content(chunk_size=1 << 20):
                out.write(chunk)


def _load_ckpt_file(ckpt_path, suffix):
    """Load a checkpoint file from disk onto the CPU, picking the reader by ``suffix``."""
    if suffix == ".safetensors":
        return load_file(ckpt_path)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints (especially
    # downloaded ones) from trusted sources.
    return torch.load(ckpt_path, map_location="cpu", weights_only=False)


def get_ckpt(path, key="state_dict", timeout=60):
    """Load a checkpoint from a local path or an http(s) URL onto the CPU.

    Files whose extension is ``.safetensors`` are read with
    ``safetensors.torch.load_file``; anything else goes through
    ``torch.load(..., map_location="cpu")``. URL sources are streamed into a
    temporary directory that is removed once loading finishes.

    Args:
        path: Local filesystem path or ``http://`` / ``https://`` URL.
        key: If the loaded checkpoint is a mapping that contains ``key``,
            return ``checkpoint[key]`` instead of the whole object.
            ``None`` disables this unwrapping.
        timeout: Seconds ``requests`` waits on connect/read for URL sources;
            ``None`` waits indefinitely.

    Returns:
        The loaded object, or its ``key`` entry when present.

    Raises:
        requests.HTTPError: If the URL responds with an error status.
    """
    is_url = path.startswith(("http://", "https://"))
    suffix = _ckpt_suffix(path, is_url)
    if is_url:
        print(f"Loading checkpoint from URL: {path}")
        # Download into a file that is closed before loading: reopening a
        # still-open NamedTemporaryFile by name fails on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt_path = os.path.join(tmp_dir, "checkpoint" + suffix)
            _download_file(path, ckpt_path, timeout)
            checkpoint = _load_ckpt_file(ckpt_path, suffix)
    else:
        print(f"Loading checkpoint from local path: {path}")
        checkpoint = _load_ckpt_file(path, suffix)
    # Only mappings can be unwrapped by key; e.g. ``key in tensor`` raises.
    if key is not None and isinstance(checkpoint, Mapping) and key in checkpoint:
        checkpoint = checkpoint[key]
    return checkpoint

def get_yaml_config(path, timeout=60):
    """Load a YAML config as an OmegaConf object from a local path or an http(s) URL.

    Args:
        path: Local filesystem path or ``http://`` / ``https://`` URL of a
            YAML file.
        timeout: Seconds ``requests`` waits on connect/read for URL sources;
            ``None`` waits indefinitely.

    Returns:
        The config built by ``OmegaConf.create`` (URL source) or
        ``OmegaConf.load`` (local file).

    Raises:
        requests.HTTPError: If the URL responds with an error status.
    """
    if path.startswith(("http://", "https://")):
        response = requests.get(path, timeout=timeout)
        response.raise_for_status()
        # Decode explicitly: ``response.text`` falls back to ISO-8859-1 for
        # text/* responses without a charset, which garbles UTF-8 YAML.
        return OmegaConf.create(response.content.decode("utf-8"))
    # Explicit UTF-8 instead of the platform-default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return OmegaConf.load(f)