"""Callbacks for Anomalib models.""" # Copyright (C) 2020 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. import os from importlib import import_module from typing import List, Union import yaml from omegaconf import DictConfig, ListConfig, OmegaConf from pytorch_lightning.callbacks import Callback, ModelCheckpoint from .cdf_normalization import CdfNormalizationCallback from .min_max_normalization import MinMaxNormalizationCallback from .model_loader import LoadModelCallback from .timer import TimerCallback from .visualizer_callback import VisualizerCallback __all__ = [ "LoadModelCallback", "TimerCallback", "VisualizerCallback", ] def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]: """Return base callbacks for all the lightning models. Args: config (DictConfig): Model config Return: (List[Callback]): List of callbacks. """ callbacks: List[Callback] = [] monitor_metric = None if "early_stopping" not in config.model.keys() else config.model.early_stopping.metric monitor_mode = "max" if "early_stopping" not in config.model.keys() else config.model.early_stopping.mode checkpoint = ModelCheckpoint( dirpath=os.path.join(config.project.path, "weights"), filename="model", monitor=monitor_metric, mode=monitor_mode, auto_insert_metric_name=False, ) callbacks.extend([checkpoint, TimerCallback()]) if "weight_file" in config.model.keys(): load_model = LoadModelCallback(os.path.join(config.project.path, config.model.weight_file)) callbacks.append(load_model) if "normalization_method" in config.model.keys() and not config.model.normalization_method == "none": if config.model.normalization_method == "cdf": if config.model.name in ["padim", "stfpm"]: if "nncf" in config.optimization and config.optimization.nncf.apply: raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.") callbacks.append(CdfNormalizationCallback()) else: raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.") elif config.model.normalization_method == "min_max": callbacks.append(MinMaxNormalizationCallback()) else: raise ValueError(f"Normalization method not recognized: {config.model.normalization_method}") if not config.project.log_images_to == []: callbacks.append( VisualizerCallback( task=config.dataset.task, inputs_are_normalized=not config.model.normalization_method == "none" ) ) if "optimization" in config.keys(): if "nncf" in config.optimization and config.optimization.nncf.apply: # NNCF wraps torch's jit which conflicts with kornia's jit calls. # Hence, nncf is imported only when required nncf_module = import_module("anomalib.utils.callbacks.nncf.callback") nncf_callback = getattr(nncf_module, "NNCFCallback") nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf)) callbacks.append( nncf_callback( config=nncf_config, export_dir=os.path.join(config.project.path, "compressed"), ) ) if "openvino" in config.optimization and config.optimization.openvino.apply: from .openvino import ( # pylint: disable=import-outside-toplevel OpenVINOCallback, ) callbacks.append( OpenVINOCallback( input_size=config.model.input_size, dirpath=os.path.join(config.project.path, "openvino"), filename="openvino_model", ) ) return callbacks