Spaces:
Running
Running
File size: 4,627 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import contextlib
import copy
import pathlib
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from .logging import get_logger
from .utils import Timer, TimerDevice
logger = get_logger()
class BaseTracker:
    r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging.

    Durations measured with :meth:`timed` are accumulated per timer name in ``self._timed_metrics``.
    ``log`` and ``finish`` are no-ops here; subclasses override them to talk to a real backend.
    """

    def __init__(self):
        # Maps timer name -> accumulated elapsed time reported by ``Timer.elapsed_time``.
        self._timed_metrics = {}

    @contextlib.contextmanager
    def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
        r"""Context manager to track time for a specific operation.

        Args:
            name: Key under which the elapsed time is stored in ``self._timed_metrics``.
            device: Forwarded unchanged to ``Timer``.
            device_sync: Forwarded unchanged to ``Timer``.

        Yields:
            The already-started ``Timer`` instance.
        """
        timer = Timer(name, device, device_sync)
        timer.start()
        try:
            yield timer
        finally:
            # Always stop the timer and record the measurement, even when the timed body raises;
            # otherwise a started timer would never be ended and the measurement would be lost.
            timer.end()
            elapsed_time = timer.elapsed_time
            if name in self._timed_metrics:
                # If the timer name already exists, add the elapsed time to the existing value since a log has
                # not been invoked yet
                self._timed_metrics[name] += elapsed_time
            else:
                self._timed_metrics[name] = elapsed_time

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Do nothing; subclasses override this to emit ``metrics`` for ``step``."""
        pass

    def finish(self) -> None:
        r"""Do nothing; subclasses override this to close their backend."""
        pass
class DummyTracker(BaseTracker):
    r"""Tracker whose ``log`` and ``finish`` discard everything they are given."""

    def __init__(self):
        super().__init__()

    def log(self, *args, **kwargs):
        r"""Accept any positional or keyword arguments and ignore them."""
        return None

    def finish(self) -> None:
        r"""Nothing to clean up."""
        return None
class WandbTracker(BaseTracker):
    r"""Tracker that forwards metrics to a Weights & Biases run."""

    def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None:
        r"""Start a W&B run.

        Args:
            experiment_name: Passed to ``wandb.init`` as ``project``.
            log_dir: Passed to ``wandb.init`` as ``dir``; created first if it does not exist.
            config: Passed to ``wandb.init`` as ``config``.
        """
        super().__init__()

        # Imported here rather than at module level, so wandb is only needed when this tracker is constructed.
        import wandb

        self.wandb = wandb

        # WandB does not create a directory if it does not exist and instead starts using the system temp directory.
        target_dir = pathlib.Path(log_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
        logger.info("WandB logging enabled")

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Log ``metrics`` together with the pending timed metrics, then clear the timed metrics.

        Keys in ``metrics`` take precedence over timed metrics of the same name.
        """
        combined = dict(self._timed_metrics)
        combined.update(metrics)
        self.run.log(combined, step=step)
        self._timed_metrics = {}

    def finish(self) -> None:
        r"""Finish the underlying W&B run."""
        self.run.finish()
class SequentialTracker(BaseTracker):
    r"""Sequential tracker that logs to multiple trackers in sequence.

    Timing is measured once on this object and the accumulated timed metrics are then copied to every
    wrapped tracker, so each of them can include the timings in its own ``log`` call.
    """

    def __init__(self, trackers: List[BaseTracker]) -> None:
        r"""Store the trackers that ``log`` and ``finish`` fan out to, in the given order."""
        super().__init__()
        self.trackers = trackers

    @contextlib.contextmanager
    def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
        r"""Context manager to track time for a specific operation.

        Args:
            name: Key under which the elapsed time is stored in ``self._timed_metrics``.
            device: Forwarded unchanged to ``Timer``.
            device_sync: Forwarded unchanged to ``Timer``.

        Yields:
            The already-started ``Timer`` instance.
        """
        timer = Timer(name, device, device_sync)
        timer.start()
        try:
            yield timer
        finally:
            # Always stop the timer and publish the measurement, even when the timed body raises;
            # otherwise a started timer would never be ended and the wrapped trackers would miss it.
            timer.end()
            elapsed_time = timer.elapsed_time
            if name in self._timed_metrics:
                # If the timer name already exists, add the elapsed time to the existing value since a log has
                # not been invoked yet
                self._timed_metrics[name] += elapsed_time
            else:
                self._timed_metrics[name] = elapsed_time
            # Hand each wrapped tracker its own deep copy so that a tracker resetting or mutating its
            # dict cannot affect this object's dict or a sibling tracker's.
            for tracker in self.trackers:
                tracker._timed_metrics = copy.deepcopy(self._timed_metrics)

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Call ``log`` on every wrapped tracker in order, then clear this object's timed metrics."""
        for tracker in self.trackers:
            tracker.log(metrics, step)
        self._timed_metrics = {}

    def finish(self) -> None:
        r"""Call ``finish`` on every wrapped tracker in order."""
        for tracker in self.trackers:
            tracker.finish()
class Trackers(str, Enum):
    r"""Names of the tracking backends that can be requested by string.

    Members also subclass ``str``, so each one compares equal to its plain string value.
    """

    # Name for requesting no tracking backend.
    NONE = "none"
    # Name for the Weights & Biases backend.
    WANDB = "wandb"
# String values of every ``Trackers`` member, in declaration order.
_SUPPORTED_TRACKERS = [member.value for member in Trackers.__members__.values()]
def initialize_trackers(
    trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
) -> Union[BaseTracker, SequentialTracker]:
    r"""Initialize loggers based on the provided configuration.

    Args:
        trackers: Names of the trackers to enable. Duplicates are ignored; the first occurrence of each
            name determines its position.
        experiment_name: Forwarded to ``WandbTracker``.
        config: Forwarded to ``WandbTracker``.
        log_dir: Forwarded to ``WandbTracker``.

    Returns:
        A plain ``BaseTracker`` when ``trackers`` is empty, otherwise a ``SequentialTracker`` wrapping one
        tracker instance per unique requested name.

    Raises:
        ValueError: If any requested name is not in ``_SUPPORTED_TRACKERS``.
    """
    logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}")

    if not trackers:
        return BaseTracker()

    # De-duplicate while keeping first-seen order. Iterating a ``set`` of strings yields an order that can
    # change between interpreter runs, which made the order of the wrapped trackers nondeterministic.
    unique_names = list(dict.fromkeys(trackers))

    if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in unique_names):
        raise ValueError(f"Unsupported tracker(s) provided. Supported trackers: {_SUPPORTED_TRACKERS}")

    tracker_instances = []
    for tracker_name in unique_names:
        # ``Trackers`` members are ``str`` subclasses, so plain strings compare equal to them.
        if tracker_name == Trackers.NONE:
            tracker_instances.append(BaseTracker())
        elif tracker_name == Trackers.WANDB:
            tracker_instances.append(WandbTracker(experiment_name, log_dir, config))

    return SequentialTracker(tracker_instances)
# Type alias for annotations that accept any of the listed tracker classes.
TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker]
|