|
from pathlib import Path |
|
|
|
import wandb |
|
|
|
|
|
def version_to_int(artifact) -> int:
    """Return the integer encoded in an artifact's ``version`` string.

    The version is expected to look like ``"v<N>"``: the first character is
    dropped and the remainder is parsed with ``int``, so ``"v12"`` -> ``12``.
    """
    numeric_part = artifact.version[1:]  # strip the leading "v"
    return int(numeric_part)
|
|
|
|
|
def download_checkpoint(
    run_id: str,
    download_dir: Path,
    version: str | None,
) -> Path:
    """Download a committed ``model`` artifact logged by a W&B run.

    Args:
        run_id: Run path handed to ``wandb.Api().run`` (e.g. ``"project/run"``).
        download_dir: Base directory; the artifact is downloaded into
            ``download_dir / run_id``.
        version: Artifact version to fetch (e.g. ``"v3"``). ``None`` selects
            the committed model artifact with the highest version number.

    Returns:
        Path to ``model.ckpt`` inside the downloaded artifact directory.

    Raises:
        ValueError: If the run has no committed ``model`` artifact matching
            ``version`` (previously this surfaced as an opaque
            ``AttributeError`` on ``None``).
    """
    api = wandb.Api()
    run = api.run(run_id)

    # Lazy generator so that, when a specific version is requested, we stop
    # iterating the run's artifacts as soon as it is found.
    committed_models = (
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    )

    if version is None:
        # max() returns the first maximal element, matching the original
        # strict ">" comparison on ties.
        chosen = max(committed_models, key=version_to_int, default=None)
    else:
        chosen = next(
            (artifact for artifact in committed_models if artifact.version == version),
            None,
        )

    if chosen is None:
        wanted = "any version" if version is None else f"version {version!r}"
        raise ValueError(
            f"No committed 'model' artifact with {wanted} found in run {run_id!r}"
        )

    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_id
    chosen.download(root=root)
    return root / "model.ckpt"
|
|
|
|
|
def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: |
|
if path is None: |
|
return None |
|
|
|
if not str(path).startswith("wandb://"): |
|
return Path(path) |
|
|
|
run_id, *version = path[len("wandb://") :].split(":") |
|
if len(version) == 0: |
|
version = None |
|
elif len(version) == 1: |
|
version = version[0] |
|
else: |
|
raise ValueError("Invalid version specifier!") |
|
|
|
project = wandb_cfg["project"] |
|
return download_checkpoint( |
|
f"{project}/{run_id}", |
|
Path("checkpoints"), |
|
version, |
|
) |
|
|