"""Weights & Biases logger variants for PyTorch Lightning.

Provides a space-efficient logger that expires superseded checkpoint-artifact
versions, and an offline variant that periodically triggers wandb-osh syncs
(useful on clusters whose compute nodes have no internet access).
"""
from pathlib import Path | |
from datetime import timedelta | |
from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Union | |
from typing_extensions import override | |
from functools import wraps | |
import os | |
from wandb_osh.hooks import TriggerWandbSyncHook | |
import time | |
from lightning.pytorch.loggers.wandb import WandbLogger, _scan_checkpoints, ModelCheckpoint, Tensor | |
from lightning.pytorch.utilities.rank_zero import rank_zero_only | |
from lightning.fabric.utilities.types import _PATH | |
if TYPE_CHECKING: | |
from wandb.sdk.lib import RunDisabled | |
from wandb.wandb_run import Run | |
class SpaceEfficientWandbLogger(WandbLogger):
    """
    A wandb logger that by default overrides artifacts to save space, instead of creating new version.
    A variable expiration_days can be set to control how long older versions of artifacts are kept.
    By default, the latest version is kept indefinitely, while older versions are kept for 5 days.

    Args:
        expiration_days: Days to retain superseded artifact versions before
            wandb garbage-collects them. ``None`` keeps every version forever.
        All other arguments are forwarded unchanged to ``WandbLogger``.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        expiration_days: Optional[int] = 5,
        **kwargs: Any,
    ) -> None:
        # BUG FIX: the original called super().__init__ twice — first with
        # offline=False hard-coded, then again with the caller's value. The
        # duplicate call (a copy-paste leftover) initialized the logger twice;
        # a single call with the real `offline` flag is sufficient.
        super().__init__(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=offline,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
            **kwargs,
        )
        # Days to keep superseded artifact versions; None keeps them forever.
        self.expiration_days = expiration_days
        # Artifacts logged on the previous scan; their TTL is shortened on the
        # next scan, once newer versions have superseded them.
        self._last_artifacts: list = []

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Upload new checkpoints and set a TTL on the previously logged ones.

        Mirrors the parent class's checkpoint scan, but additionally remembers
        the artifacts from the previous call so their retention can be reduced
        to ``expiration_days`` once they are superseded.
        """
        import wandb

        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        # log iteratively all new checkpoints
        artifacts = []
        for t, p, s, tag in checkpoints:
            metadata = {
                "score": s.item() if isinstance(s, Tensor) else s,
                "original_filename": Path(p).name,
                checkpoint_callback.__class__.__name__: {
                    k: getattr(checkpoint_callback, k)
                    for k in [
                        "monitor",
                        "mode",
                        "save_last",
                        "save_top_k",
                        "save_weights_only",
                        "_every_n_train_steps",
                    ]
                    # ensure it does not break if `ModelCheckpoint` args change
                    if hasattr(checkpoint_callback, k)
                },
            }
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"
            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
            artifact.add_file(p, name="model.ckpt")
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
            self._logged_model_time[p] = t
            artifacts.append(artifact)

        # Expire the artifacts from the previous scan: they have now been
        # superseded by the versions logged above.
        for artifact in self._last_artifacts:
            if not self._offline:
                # The TTL can only be applied after the upload has completed.
                artifact.wait()
            # BUG FIX: guard against expiration_days=None — the original did
            # timedelta(days=None), which raises TypeError. None now means
            # "keep indefinitely", matching the class docstring.
            if self.expiration_days is not None:
                artifact.ttl = timedelta(days=self.expiration_days)
                artifact.save()
        self._last_artifacts = artifacts
class OfflineWandbLogger(SpaceEfficientWandbLogger):
    """
    Wraps WandbLogger to trigger offline sync hook occasionally.
    This is useful when running on slurm clusters, many of which
    only have internet on login nodes, not compute nodes.

    Args:
        communication_dir: Directory used by wandb-osh to exchange sync
            commands with the login-node daemon.
        min_sync_interval: Minimum number of seconds between two sync
            triggers; metric logging within the interval does not re-trigger.
        All other arguments are forwarded to ``SpaceEfficientWandbLogger``.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        communication_dir: _PATH = ".wandb_osh_command_dir",
        min_sync_interval: float = 60,
        **kwargs: Any,
    ) -> None:
        # The underlying logger always runs in offline mode here; actual
        # uploading is delegated to the wandb-osh daemon via trigger_sync.
        super().__init__(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=False,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
            **kwargs,
        )
        # Preserve the caller's intent; the parent class reads self._offline
        # to decide whether artifact.wait() is possible.
        self._offline = offline
        # Generalized: the command directory was hard-coded to
        # ".wandb_osh_command_dir" — it is now a parameter with that default.
        command_dir = Path(communication_dir)
        command_dir.mkdir(parents=True, exist_ok=True)
        self.trigger_sync = TriggerWandbSyncHook(command_dir)
        self.last_sync_time = 0.0
        # Generalized: the throttle interval was hard-coded to 60 seconds.
        self.min_sync_interval = min_sync_interval
        # Directory the sync daemon should upload from.
        self.wandb_dir = os.path.join(self._save_dir, "wandb/latest-run")

    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Log metrics, then trigger an offline sync at most once per interval."""
        out = super().log_metrics(metrics, step)
        if time.time() - self.last_sync_time > self.min_sync_interval:
            self.trigger_sync(self.wandb_dir)
            self.last_sync_time = time.time()
        return out