Spaces:
Build error
Build error
"""Callbacks for Anomalib models.""" | |
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import os | |
from importlib import import_module | |
from typing import List, Union | |
import yaml | |
from omegaconf import DictConfig, ListConfig, OmegaConf | |
from pytorch_lightning.callbacks import Callback, ModelCheckpoint | |
from .cdf_normalization import CdfNormalizationCallback | |
from .min_max_normalization import MinMaxNormalizationCallback | |
from .model_loader import LoadModelCallback | |
from .timer import TimerCallback | |
from .visualizer_callback import VisualizerCallback | |
# Public names exported by a star-import of this package. Every callback class
# imported above is listed, together with the ``get_callbacks`` factory defined
# below, so the declared API matches what the module actually provides.
__all__ = [
    "CdfNormalizationCallback",
    "LoadModelCallback",
    "MinMaxNormalizationCallback",
    "TimerCallback",
    "VisualizerCallback",
    "get_callbacks",
]
def _is_optimization_enabled(config: Union[ListConfig, DictConfig], name: str) -> bool:
    """Tell whether ``config.optimization.<name>.apply`` is present and truthy.

    Args:
        config (Union[ListConfig, DictConfig]): Full experiment config.
        name (str): Key of the optimization backend, e.g. ``"nncf"`` or ``"openvino"``.

    Returns:
        bool: ``False`` when the ``optimization`` section or the ``name`` entry is
        absent, otherwise the truthiness of its ``apply`` field.
    """
    if "optimization" not in config.keys():
        return False
    optimization = config.optimization
    if optimization is None or name not in optimization:
        return False
    return bool(optimization[name].apply)


def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
    """Return base callbacks for all the lightning models.

    The callbacks are appended in this order: model checkpoint, timer, optional
    weight loader, optional score normalization, optional visualizer, optional
    NNCF compression and optional OpenVINO export.

    Args:
        config (Union[ListConfig, DictConfig]): Model config. The function reads
            ``config.project``, ``config.model``, ``config.dataset.task`` and,
            when present, ``config.optimization``.

    Returns:
        List[Callback]: List of callbacks.

    Raises:
        NotImplementedError: If CDF normalization is requested for a model other
            than ``padim``/``stfpm``, or together with NNCF.
        ValueError: If ``config.model.normalization_method`` is not one of
            ``"none"``, ``"cdf"`` or ``"min_max"``.
    """
    callbacks: List[Callback] = []

    # Checkpointing follows the early-stopping metric when one is configured;
    # otherwise no metric is monitored and the mode falls back to "max".
    if "early_stopping" in config.model.keys():
        monitor_metric = config.model.early_stopping.metric
        monitor_mode = config.model.early_stopping.mode
    else:
        monitor_metric = None
        monitor_mode = "max"

    checkpoint = ModelCheckpoint(
        dirpath=os.path.join(config.project.path, "weights"),
        filename="model",
        monitor=monitor_metric,
        mode=monitor_mode,
        auto_insert_metric_name=False,
    )
    callbacks.extend([checkpoint, TimerCallback()])

    if "weight_file" in config.model.keys():
        load_model = LoadModelCallback(os.path.join(config.project.path, config.model.weight_file))
        callbacks.append(load_model)

    # Resolve the normalization method once. A missing key means "no
    # normalization", so the visualizer below is never told that inputs are
    # normalized when no normalization callback was actually installed.
    if "normalization_method" in config.model.keys():
        normalization_method = config.model.normalization_method
    else:
        normalization_method = "none"

    if normalization_method != "none":
        if normalization_method == "cdf":
            if config.model.name in ["padim", "stfpm"]:
                # The optimization section is optional, so it is probed through
                # the helper instead of being dereferenced directly.
                if _is_optimization_enabled(config, "nncf"):
                    raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.")
                callbacks.append(CdfNormalizationCallback())
            else:
                raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
        elif normalization_method == "min_max":
            callbacks.append(MinMaxNormalizationCallback())
        else:
            raise ValueError(f"Normalization method not recognized: {normalization_method}")

    if config.project.log_images_to != []:
        callbacks.append(
            VisualizerCallback(
                task=config.dataset.task,
                inputs_are_normalized=normalization_method != "none",
            )
        )

    if _is_optimization_enabled(config, "nncf"):
        # NNCF wraps torch's jit which conflicts with kornia's jit calls.
        # Hence, nncf is imported only when required
        nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
        nncf_callback = getattr(nncf_module, "NNCFCallback")
        # Round-trip through YAML to hand NNCF plain Python containers.
        nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf))
        callbacks.append(
            nncf_callback(
                config=nncf_config,
                export_dir=os.path.join(config.project.path, "compressed"),
            )
        )

    if _is_optimization_enabled(config, "openvino"):
        from .openvino import (  # pylint: disable=import-outside-toplevel
            OpenVINOCallback,
        )

        callbacks.append(
            OpenVINOCallback(
                input_size=config.model.input_size,
                dirpath=os.path.join(config.project.path, "openvino"),
                filename="openvino_model",
            )
        )

    return callbacks