Spaces:
Build error
Build error
"""Load PyTorch Lightning Loggers.""" | |
# Copyright (C) 2020 Intel Corporation | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions | |
# and limitations under the License. | |
import logging | |
import os | |
from typing import Iterable, List, Union | |
from omegaconf.dictconfig import DictConfig | |
from omegaconf.listconfig import ListConfig | |
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase | |
from .tensorboard import AnomalibTensorBoardLogger | |
from .wandb import AnomalibWandbLogger | |
# Public names re-exported by this package. ``UnknownLogger`` is included because it is the
# exception ``get_experiment_logger`` raises, so callers need to be able to import and catch it.
__all__ = [
    "AnomalibTensorBoardLogger",
    "AnomalibWandbLogger",
    "UnknownLogger",
    "configure_logger",
    "get_experiment_logger",
]

# Logger names accepted by ``get_experiment_logger`` through ``config.project.logger``.
AVAILABLE_LOGGERS = ["tensorboard", "wandb", "csv"]
class UnknownLogger(ValueError):
    """Raised when the logger option in `config.yaml` (``project.logger``) is set incorrectly.

    Subclasses :class:`ValueError` (and therefore :class:`Exception`), so callers can catch it
    either specifically or as a generic invalid-configuration-value error.
    """
def configure_logger(level: Union[int, str] = logging.INFO) -> None:
    """Configure console logging and align PyTorch Lightning's handlers with it.

    Calls ``logging.basicConfig`` with a shared format string and the requested level. It then
    applies the same formatter and level to every handler already attached to the
    ``pytorch_lightning`` logger, so both libraries produce consistently formatted output.

    Args:
        level (Union[int, str], optional): Logging level. Accepts an integer (e.g.
            ``logging.DEBUG``) or a level name such as ``"INFO"`` or ``"info"``. Names are
            matched case-insensitively. Defaults to ``logging.INFO``.

    Raises:
        ValueError: If ``level`` is a string that does not name a registered logging level.
    """
    if isinstance(level, str):
        # ``logging.getLevelName`` only maps *registered*, case-sensitive names to their
        # numeric value; anything else comes back as the string "Level <name>". Normalise the
        # case and reject names that do not resolve to an int.
        resolved = logging.getLevelName(level.strip().upper())
        if not isinstance(resolved, int):
            raise ValueError(f"Unknown logging level: {level!r}")
        level = resolved

    format_string = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=format_string, level=level)

    # Give PyTorch Lightning's existing handlers the same formatting and level as anomalib.
    formatter = logging.Formatter(format_string)
    for handler in logging.getLogger("pytorch_lightning").handlers:
        handler.setFormatter(formatter)
        handler.setLevel(level)
def get_experiment_logger(
    config: Union[DictConfig, ListConfig]
) -> Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]:
    """Return the experiment logger(s) selected by ``config.project.logger``.

    ``config.project.logger`` may be:

    * ``None`` or ``False``: logging is disabled and ``False`` is returned;
    * ``True``: the TensorBoard logger is enabled;
    * a single logger name (one of ``AVAILABLE_LOGGERS``);
    * an iterable of logger names.

    Every selected logger writes under ``<config.project.path>/logs``. The passed config is not
    modified.

    Args:
        config (Union[DictConfig, ListConfig]): config.yaml content for the corresponding anomalib
            model. Reads ``project.logger`` and ``project.path``. When the ``wandb`` logger is
            selected, it also reads ``dataset.name``, ``dataset.category`` and ``model.name``.

    Raises:
        UnknownLogger: If a requested logger name is not one of ``AVAILABLE_LOGGERS``.

    Returns:
        Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]: ``False`` when logging is
        disabled, otherwise a list holding one logger instance per requested name.
    """
    logger_option = config.project.logger
    if logger_option in [None, False]:
        return False

    # Normalise the option to a sequence of names without writing back into the caller's config.
    if logger_option is True:
        logger_names = ["tensorboard"]
    elif isinstance(logger_option, str):
        logger_names = [logger_option]
    else:
        logger_names = logger_option

    log_dir = os.path.join(config.project.path, "logs")
    logger_list: List[LightningLoggerBase] = []
    for logger in logger_names:
        if logger == "tensorboard":
            logger_list.append(
                AnomalibTensorBoardLogger(
                    name="Tensorboard Logs",
                    save_dir=log_dir,
                )
            )
        elif logger == "wandb":
            # Make sure the log directory exists before the wandb logger is constructed.
            os.makedirs(log_dir, exist_ok=True)
            logger_list.append(
                AnomalibWandbLogger(
                    project=config.dataset.name,
                    name=f"{config.dataset.category} {config.model.name}",
                    save_dir=log_dir,
                )
            )
        elif logger == "csv":
            logger_list.append(CSVLogger(save_dir=log_dir))
        else:
            # Report the offending entry, not the whole option, so list configs stay readable.
            raise UnknownLogger(
                f"Unknown logger type: {logger}. "
                f"Available loggers are: {AVAILABLE_LOGGERS}.\n"
                f"To enable the logger, set `project.logger` to `true` or use one of available loggers in config.yaml\n"
                f"To disable the logger, set `project.logger` to `false`."
            )
    return logger_list