"""Console logger utilities. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging """ import logging import fsspec import lightning import torch from timm.scheduler import CosineLRScheduler def fsspec_exists(filename): """Check if a file exists using fsspec.""" fs, _ = fsspec.core.url_to_fs(filename) return fs.exists(filename) def fsspec_listdir(dirname): """Listdir in manner compatible with fsspec.""" fs, _ = fsspec.core.url_to_fs(dirname) return fs.ls(dirname) def fsspec_mkdirs(dirname, exist_ok=True): """Mkdirs in manner compatible with fsspec.""" fs, _ = fsspec.core.url_to_fs(dirname) fs.makedirs(dirname, exist_ok=exist_ok) def print_nans(tensor, name): if torch.isnan(tensor).any(): print(name, tensor) class CosineDecayWarmupLRScheduler( CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): """Wrap timm.scheduler.CosineLRScheduler Enables calling scheduler.step() without passing in epoch. Supports resuming as well. Adapted from: https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._last_epoch = -1 self.step(epoch=0) def step(self, epoch=None): if epoch is None: self._last_epoch += 1 else: self._last_epoch = epoch # We call either step or step_update, depending on # whether we're using the scheduler every epoch or every # step. # Otherwise, lightning will always call step (i.e., # meant for each epoch), and if we set scheduler # interval to "step", then the learning rate update will # be wrong. if self.t_in_epochs: super().step(epoch=self._last_epoch) else: super().step_update(num_updates=self._last_epoch) def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: """Initializes multi-GPU-friendly python logger.""" logger = logging.getLogger(name) logger.setLevel(level) # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup for level in ('debug', 'info', 'warning', 'error', 'exception', 'fatal', 'critical'): setattr(logger, level, lightning.pytorch.utilities.rank_zero_only( getattr(logger, level))) return logger