File size: 4,627 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import contextlib
import copy
import pathlib
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from .logging import get_logger
from .utils import Timer, TimerDevice


# Module-level logger shared by every tracker in this file; obtained from the package's own `.logging` helper.
logger = get_logger()


class BaseTracker:
    r"""
    Common base for all trackers.

    `log` and `finish` are no-ops here, so an instance of this class can stand in wherever logging should be
    disabled. Timings collected through `timed` are still accumulated in `self._timed_metrics`.
    """

    def __init__(self):
        # Maps a timer name to the elapsed time accumulated for it by `timed`.
        self._timed_metrics = {}

    @contextlib.contextmanager
    def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
        r"""
        Context manager that times the enclosed block and stores the result under `name`.

        Args:
            name: Key under which the elapsed time is stored in `self._timed_metrics`.
            device: Forwarded to `Timer`.
            device_sync: Forwarded to `Timer`.

        Yields:
            The running `Timer` instance.
        """
        stopwatch = Timer(name, device, device_sync)
        stopwatch.start()
        yield stopwatch
        # Only reached when the managed block finishes without raising.
        stopwatch.end()

        elapsed = stopwatch.elapsed_time
        try:
            # Same name timed again: add to the value that is already stored for it.
            self._timed_metrics[name] += elapsed
        except KeyError:
            # First measurement for this name.
            self._timed_metrics[name] = elapsed

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""No-op; subclasses override this to record `metrics` at `step`."""
        return None

    def finish(self) -> None:
        r"""No-op; subclasses override this to close their logging backend."""
        return None


class DummyTracker(BaseTracker):
    r"""Tracker whose `log` and `finish` do nothing; `log` accepts and ignores any arguments."""

    def __init__(self):
        super().__init__()

    def log(self, *args, **kwargs):
        r"""Accept any positional or keyword arguments and discard them."""
        return None

    def finish(self) -> None:
        r"""Do nothing."""
        return None


class WandbTracker(BaseTracker):
    r"""Tracker that sends metrics to a Weights & Biases run."""

    def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None:
        r"""
        Start a W&B run.

        Args:
            experiment_name: Passed to `wandb.init` as `project`.
            log_dir: Passed to `wandb.init` as `dir`; created first if it does not exist.
            config: Passed to `wandb.init` as `config`.
        """
        super().__init__()

        # Imported here rather than at module level, so importing this module does not itself import wandb.
        import wandb

        self.wandb = wandb

        # Make sure the directory exists before handing it to W&B: per the original author's note, W&B does not
        # create a missing directory and instead starts using the system temp directory.
        target_dir = pathlib.Path(log_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
        logger.info("WandB logging enabled")

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Log `metrics` together with any pending timed metrics at `step`, then clear the pending timings."""
        # Explicit metrics come last so that they win over a timed metric with the same key.
        payload = {**self._timed_metrics, **metrics}
        self.run.log(payload, step=step)
        self._timed_metrics = {}

    def finish(self) -> None:
        r"""Finish the underlying W&B run."""
        self.run.finish()


class SequentialTracker(BaseTracker):
    r"""Sequential tracker that logs to multiple trackers in sequence."""

    def __init__(self, trackers: List[BaseTracker]) -> None:
        r"""
        Args:
            trackers: Trackers that every `timed`, `log` and `finish` call is forwarded to, in list order.
        """
        super().__init__()
        self.trackers = trackers

    @contextlib.contextmanager
    def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
        r"""
        Context manager to track time for a specific operation.

        The measurement and accumulation are delegated to `BaseTracker.timed` so the two implementations cannot
        drift apart. Once the block has finished, the accumulated timings are copied to every wrapped tracker.

        Args:
            name: Key under which the elapsed time is stored.
            device: Forwarded to `BaseTracker.timed`.
            device_sync: Forwarded to `BaseTracker.timed`.

        Yields:
            The running `Timer` instance.
        """
        with super().timed(name, device, device_sync) as timer:
            yield timer
        # Only reached when the managed block did not raise. Each child receives its own copy, so a child that
        # clears or mutates its timings does not affect this tracker or its siblings.
        for tracker in self.trackers:
            tracker._timed_metrics = copy.deepcopy(self._timed_metrics)

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Forward `metrics` and `step` to every wrapped tracker, then clear this tracker's pending timings."""
        for tracker in self.trackers:
            tracker.log(metrics, step)
        self._timed_metrics = {}

    def finish(self) -> None:
        r"""Call `finish` on every wrapped tracker, in order."""
        for tracker in self.trackers:
            tracker.finish()


class Trackers(str, Enum):
    r"""Enum for supported trackers.

    Members also subclass `str`, so a member compares equal to its plain string value.
    """

    # Selects the no-op `BaseTracker` in `initialize_trackers`.
    NONE = "none"
    # Selects `WandbTracker` in `initialize_trackers`.
    WANDB = "wandb"


# String values of every `Trackers` member, in definition order; used to validate user-supplied tracker names.
_SUPPORTED_TRACKERS = [member.value for member in Trackers]


def initialize_trackers(
    trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
) -> Union[BaseTracker, SequentialTracker]:
    r"""
    Initialize loggers based on the provided configuration.

    Args:
        trackers: Names of the trackers to enable. Duplicates are ignored; an empty list (or `None`) disables
            logging.
        experiment_name: Passed to each tracker that takes an experiment/project name.
        config: Configuration passed to each tracker that records it.
        log_dir: Directory passed to each tracker that writes files.

    Returns:
        A plain `BaseTracker` when no tracker is requested, otherwise a `SequentialTracker` wrapping one
        instance per distinct requested name, in the order the names were first given.

    Raises:
        ValueError: If any requested name is not in `_SUPPORTED_TRACKERS`.
    """

    logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}")

    if not trackers:
        return BaseTracker()

    # De-duplicate while keeping first-seen order. Iterating a `set` of strings would make the order of the
    # created trackers depend on per-process string hashing.
    requested = list(dict.fromkeys(trackers))

    if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in requested):
        raise ValueError(f"Unsupported tracker(s) provided. Supported trackers: {_SUPPORTED_TRACKERS}")

    tracker_instances = []
    for tracker_name in requested:
        if tracker_name == Trackers.NONE:
            tracker_instances.append(BaseTracker())
        elif tracker_name == Trackers.WANDB:
            tracker_instances.append(WandbTracker(experiment_name, log_dir, config))
        else:
            # Guards against a member being added to `Trackers` without a matching branch here, which would
            # otherwise be skipped silently.
            raise ValueError(f"Tracker {tracker_name!r} is listed as supported but has no implementation.")

    return SequentialTracker(tracker_instances)


# Type alias covering the tracker classes named here; intended for annotations that accept any of them.
TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker]