import time
from typing import Mapping, Union

import torch
from ignite.contrib.handlers import TensorboardLogger
from ignite.contrib.handlers.base_logger import BaseHandler
from ignite.engine import Engine, EventEnum, Events
from ignite.handlers import global_step_from_engine
def add_time_handlers(engine: Engine) -> None:
    """Attach timers for full iterations and for data loading to ``engine``."""
    iteration_time_handler = TimeHandler("iter", freq=True, period=True)
    batch_time_handler = TimeHandler("get_batch", freq=False, period=True)
    engine.add_event_handler(
        Events.ITERATION_STARTED, iteration_time_handler.start_timing
    )
    engine.add_event_handler(
        Events.ITERATION_COMPLETED, iteration_time_handler.end_timing
    )
    engine.add_event_handler(Events.GET_BATCH_STARTED, batch_time_handler.start_timing)
    engine.add_event_handler(Events.GET_BATCH_COMPLETED, batch_time_handler.end_timing)
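
# Usage sketch (illustrative; `trainer` is assumed to be an existing Engine,
# it is not defined in this module):
#
#     add_time_handlers(trainer)
#
# After the first completed iteration, ``trainer.state.times`` holds
# "secs_per_iter", "num_iter_per_sec" and "secs_per_get_batch", which
# MetricLoggingHandler below picks up when ``log_timings`` is enabled.
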
class MetricLoggingHandler(BaseHandler):
    """Logs losses, metrics, and timings from the engine state to TensorBoard."""

    def __init__(
        self,
        tag,
        optimizer=None,
        log_loss=True,
        log_metrics=True,
        log_timings=True,
        global_step_transform=None,
    ):
        self.tag = tag
        self.optimizer = optimizer
        self.log_loss = log_loss
        self.log_metrics = log_metrics
        self.log_timings = log_timings
        self.gst = global_step_transform
        super().__init__()
    def __call__(
        self,
        engine: Engine,
        logger: TensorboardLogger,
        event_name: Union[str, EventEnum],
    ) -> None:
        if not isinstance(logger, TensorboardLogger):
            raise RuntimeError(
                "Handler 'MetricLoggingHandler' works only with TensorboardLogger"
            )

        if self.gst is None:
            gst = global_step_from_engine(engine)
        else:
            gst = self.gst
        global_step = gst(engine, event_name)  # type: ignore[misc]
        if not isinstance(global_step, int):
            raise TypeError(
                f"global_step must be int, got {type(global_step)}."
                " Please check the output of global_step_transform."
            )
        writer = logger.writer

        # Optimizer parameters: one learning-rate scalar per parameter group
        if self.optimizer is not None:
            for i, param_group in enumerate(self.optimizer.param_groups):
                writer.add_scalar(
                    f"lr-{self.tag}/{i}", float(param_group["lr"]), global_step
                )
        if self.log_loss:
            # Plot losses
            loss_dict = engine.state.output["loss_dict"]
            for k, v in loss_dict.items():
                # TODO: check whether non-scalar loss values need filtering here
                writer.add_scalar(f"loss-{self.tag}/{k}", v, global_step)
        if self.log_metrics:
            # Plot metrics, both ignite-managed and custom ones from the step output
            metrics_dict = engine.state.metrics
            metrics_dict_custom = engine.state.output["metrics_dict"]
            for k, v in metrics_dict.items():
                # Skip dictionaries because of ignite's awkward handling of Mapping
                # metrics. TODO: remove the hard-coded "assignment" filter.
                if isinstance(v, Mapping) or k.endswith("assignment"):
                    continue
                if isinstance(v, torch.Tensor) and v.ndim > 0:
                    writer.add_histogram(f"metrics-{self.tag}/{k}", v, global_step)
                else:
                    writer.add_scalar(f"metrics-{self.tag}/{k}", v, global_step)
            for k, v in metrics_dict_custom.items():
                if isinstance(v, Mapping):
                    continue
                if isinstance(v, torch.Tensor) and v.ndim > 0:
                    writer.add_histogram(f"metrics-{self.tag}/{k}", v, global_step)
                else:
                    writer.add_scalar(f"metrics-{self.tag}/{k}", v, global_step)
        if self.log_timings:
            # Plot timings, skipping ignite's built-in "COMPLETED" entry, which is
            # only set once the whole run finishes
            timings_dict = engine.state.times
            timings_dict_custom = engine.state.output["timings_dict"]
            for k, v in timings_dict.items():
                if k == "COMPLETED":
                    continue
                writer.add_scalar(f"timing-{self.tag}/{k}", v, global_step)
            for k, v in timings_dict_custom.items():
                writer.add_scalar(f"timing-{self.tag}/{k}", v, global_step)

        # For memory efficiency: validation results do not need to stay in the state.
        engine.state.output = None
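
# Usage sketch (illustrative; `trainer` and `optimizer` are assumed to exist,
# and the engine's step output must carry the "loss_dict", "metrics_dict" and
# "timings_dict" keys read above). The handler is attached through a
# TensorboardLogger rather than directly on the engine:
#
#     tb_logger = TensorboardLogger(log_dir="runs/example")  # placeholder path
#     tb_logger.attach(
#         trainer,
#         log_handler=MetricLoggingHandler("train", optimizer=optimizer),
#         event_name=Events.EPOCH_COMPLETED,
#     )
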
class TimeHandler:
    """Times an engine event and stores the result in ``engine.state.times`` as a
    period ("secs_per_<name>"), a frequency ("num_<name>_per_sec"), or both."""

    def __init__(self, name: str, freq: bool = False, period: bool = False) -> None:
        self.name = name
        self.freq = freq
        self.period = period
        if not self.period and not self.freq:
            print(f"Warning: No timings logged for {name}")
        self._start_time = None

    def start_timing(self, engine):
        self._start_time = time.time()
    def end_timing(self, engine):
        if self._start_time is None:
            period = 0.0
            freq = 0.0
        else:
            # Clamp to avoid division by zero for near-instant events.
            period = max(time.time() - self._start_time, 1e-6)
            freq = 1 / period
        if not hasattr(engine.state, "times"):
            engine.state.times = {}
        if self.period:
            engine.state.times[f"secs_per_{self.name}"] = period
        if self.freq:
            engine.state.times[f"num_{self.name}_per_sec"] = freq
class VisualizationHandler(BaseHandler):
    """Delegates to a user-provided ``visualizer`` callable for custom TensorBoard output."""

    def __init__(self, tag, visualizer, global_step_transform=None):
        self.tag = tag
        self.visualizer = visualizer
        self.gst = global_step_transform
        super().__init__()
    def __call__(
        self,
        engine: Engine,
        logger: TensorboardLogger,
        event_name: Union[str, EventEnum],
    ) -> None:
        if not isinstance(logger, TensorboardLogger):
            raise RuntimeError(
                "Handler 'VisualizationHandler' works only with TensorboardLogger"
            )

        if self.gst is None:
            gst = global_step_from_engine(engine)
        else:
            gst = self.gst
        global_step = gst(engine, event_name)  # type: ignore[misc]
        if not isinstance(global_step, int):
            raise TypeError(
                f"global_step must be int, got {type(global_step)}."
                " Please check the output of global_step_transform."
            )
        self.visualizer(engine, logger, global_step, self.tag)
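
# Illustrative end-to-end sketch wiring the three handlers together; the step
# function, visualizer, and log directory below are placeholders for this
# smoke test, not part of the handlers' API.
if __name__ == "__main__":

    def dummy_step(engine, batch):
        # The logging handler expects these three keys in the step output.
        return {
            "loss_dict": {"total": 0.0},
            "metrics_dict": {},
            "timings_dict": {},
        }

    def dummy_visualizer(engine, logger, global_step, tag):
        # Any callable with this signature works; here we log a random image.
        logger.writer.add_image(f"{tag}/sample", torch.rand(3, 8, 8), global_step)

    trainer = Engine(dummy_step)
    add_time_handlers(trainer)

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    tb_logger = TensorboardLogger(log_dir="/tmp/handler_example")  # placeholder path
    tb_logger.attach(
        trainer,
        log_handler=VisualizationHandler("train", dummy_visualizer),
        event_name=Events.EPOCH_COMPLETED,
    )
    tb_logger.attach(
        trainer,
        # Attach at epoch end so ignite's built-in epoch timing is already set.
        log_handler=MetricLoggingHandler("train", optimizer=optimizer),
        event_name=Events.EPOCH_COMPLETED,
    )

    trainer.run(range(4), max_epochs=2)
    tb_logger.close()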