Spaces:
Build error
Build error
File size: 4,445 Bytes
c8c12e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
"""Callbacks for Anomalib models."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import os
from importlib import import_module
from typing import List, Union
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from .cdf_normalization import CdfNormalizationCallback
from .min_max_normalization import MinMaxNormalizationCallback
from .model_loader import LoadModelCallback
from .timer import TimerCallback
from .visualizer_callback import VisualizerCallback
# Public API of the callbacks package: every callback class imported above
# plus the ``get_callbacks`` factory defined in this module.
__all__ = [
    "CdfNormalizationCallback",
    "LoadModelCallback",
    "MinMaxNormalizationCallback",
    "TimerCallback",
    "VisualizerCallback",
    "get_callbacks",
]
def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
    """Return base callbacks for all the lightning models.

    A ``ModelCheckpoint`` and a ``TimerCallback`` are always added. The
    remaining callbacks are added depending on which optional keys are present
    in ``config`` (``model.weight_file``, ``model.normalization_method``,
    ``project.log_images_to``, ``optimization.nncf``, ``optimization.openvino``).

    Args:
        config (DictConfig): Model config

    Returns:
        (List[Callback]): List of callbacks.

    Raises:
        NotImplementedError: If CDF normalization is requested for a model other
            than PADIM/STFPM, or together with NNCF.
        ValueError: If ``model.normalization_method`` is set to an unknown value.
    """
    callbacks: List[Callback] = []

    # Without an ``early_stopping`` section nothing is monitored and the mode
    # falls back to "max".
    if "early_stopping" in config.model.keys():
        monitor_metric = config.model.early_stopping.metric
        monitor_mode = config.model.early_stopping.mode
    else:
        monitor_metric = None
        monitor_mode = "max"

    # ``normalization_method`` and ``optimization`` are optional sections.
    # Resolve them once, with their guards, so that every later use agrees on
    # what "absent" means (no normalization / NNCF disabled).
    if "normalization_method" in config.model.keys():
        normalization_method = config.model.normalization_method
    else:
        normalization_method = "none"
    has_optimization = "optimization" in config.keys()
    nncf_enabled = bool(has_optimization and "nncf" in config.optimization and config.optimization.nncf.apply)

    checkpoint = ModelCheckpoint(
        dirpath=os.path.join(config.project.path, "weights"),
        filename="model",
        monitor=monitor_metric,
        mode=monitor_mode,
        auto_insert_metric_name=False,
    )
    callbacks.extend([checkpoint, TimerCallback()])

    if "weight_file" in config.model.keys():
        load_model = LoadModelCallback(os.path.join(config.project.path, config.model.weight_file))
        callbacks.append(load_model)

    if not normalization_method == "none":
        if normalization_method == "cdf":
            if config.model.name in ["padim", "stfpm"]:
                if nncf_enabled:
                    raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.")
                callbacks.append(CdfNormalizationCallback())
            else:
                raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
        elif normalization_method == "min_max":
            callbacks.append(MinMaxNormalizationCallback())
        else:
            raise ValueError(f"Normalization method not recognized: {normalization_method}")

    if not config.project.log_images_to == []:
        callbacks.append(
            VisualizerCallback(
                task=config.dataset.task,
                # Uses the resolved value so a config without a
                # ``normalization_method`` key reports un-normalized inputs.
                inputs_are_normalized=not normalization_method == "none",
            )
        )

    if has_optimization:
        if nncf_enabled:
            # NNCF wraps torch's jit which conflicts with kornia's jit calls.
            # Hence, nncf is imported only when required
            nncf_module = import_module("anomalib.utils.callbacks.nncf.callback")
            nncf_callback = getattr(nncf_module, "NNCFCallback")
            # Round-trip through YAML to hand the callback plain Python containers.
            nncf_config = yaml.safe_load(OmegaConf.to_yaml(config.optimization.nncf))
            callbacks.append(
                nncf_callback(
                    config=nncf_config,
                    export_dir=os.path.join(config.project.path, "compressed"),
                )
            )
        if "openvino" in config.optimization and config.optimization.openvino.apply:
            from .openvino import (  # pylint: disable=import-outside-toplevel
                OpenVINOCallback,
            )

            callbacks.append(
                OpenVINOCallback(
                    input_size=config.model.input_size,
                    dirpath=os.path.join(config.project.path, "openvino"),
                    filename="openvino_model",
                )
            )

    return callbacks
|