julien.blanchon
add app
c8c12e9
"""Common helpers for both nightly and pre-merge model tests."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import os
from typing import Dict, List, Tuple, Union
import numpy as np
from omegaconf import DictConfig, ListConfig
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from anomalib.config import get_configurable_parameters, update_nncf_config
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.models.components import AnomalyModule
from anomalib.utils.callbacks import VisualizerCallback, get_callbacks
def setup_model_train(
    model_name: str,
    dataset_path: str,
    project_path: str,
    nncf: bool,
    category: str,
    score_type: Union[str, None] = None,
    weight_file: str = "weights/model.ckpt",
    fast_run: bool = False,
    device: Union[List[int], int] = [0],  # read-only default; never mutated in this function
) -> Tuple[Union[DictConfig, ListConfig], LightningDataModule, AnomalyModule, Trainer]:
    """Train the model based on the parameters passed.

    Args:
        model_name (str): Name of the model to train.
        dataset_path (str): Location of the dataset.
        project_path (str): Path to temporary project folder.
        nncf (bool): Add nncf callback.
        category (str): Category to train on.
        score_type (str, optional): Only used for DFM. Defaults to None.
        weight_file (str, optional): Path to weight file. An empty string removes ``weight_file`` from the
            model config entirely. Ignored when ``fast_run`` is set, in which case ``weights/last.ckpt`` is used.
        fast_run (bool, optional): If set to true, the model trains for only 1 epoch. We train for one epoch as
            this ensures that both anomalous and non-anomalous images are present in the validation step.
        device (List[int], int, optional): Select which device you want to train the model on. Defaults to first GPU.

    Returns:
        Tuple[DictConfig, LightningDataModule, AnomalyModule, Trainer]: config, datamodule, trained model, trainer
    """
    config = get_configurable_parameters(model_name=model_name)
    if score_type is not None:
        config.model.score_type = score_type
    config.project.seed = 42
    config.dataset.category = category
    config.dataset.path = dataset_path
    config.project.log_images_to = []
    config.trainer.gpus = device

    if weight_file == "":
        # An empty weight file means "no weights": remove the key if present and never
        # write an empty path into the config when the key was absent.
        config.model.pop("weight_file", None)
    else:
        # A fast run writes its checkpoint to weights/last.ckpt (see the ModelCheckpoint below).
        config.model.weight_file = weight_file if not fast_run else "weights/last.ckpt"

    if nncf:
        config.optimization.nncf.apply = True
        config = update_nncf_config(config)
        config.init_weights = None

    # reassign project path as config is updated in `update_config_for_nncf`
    config.project.path = project_path

    datamodule = get_datamodule(config)
    model = get_model(config)

    callbacks = get_callbacks(config)

    if fast_run:
        # Force model checkpoint to create checkpoint after first epoch: replace any configured
        # ModelCheckpoint with one that always saves to <project>/weights/last.ckpt.
        callbacks = [callback for callback in callbacks if not isinstance(callback, ModelCheckpoint)]
        callbacks.append(
            ModelCheckpoint(
                dirpath=os.path.join(config.project.path, "weights"),
                filename="last",
                monitor=None,
                mode="max",
                save_last=True,
                auto_insert_metric_name=False,
            )
        )

    # The visualizer is not needed for these tests; saving result images takes time.
    callbacks = [callback for callback in callbacks if not isinstance(callback, VisualizerCallback)]

    # Train the model.
    if fast_run:
        config.trainer.max_epochs = 1
        config.trainer.check_val_every_n_epoch = 1

    trainer = Trainer(callbacks=callbacks, **config.trainer)
    trainer.fit(model=model, datamodule=datamodule)
    return config, datamodule, model, trainer
def model_load_test(config: Union[DictConfig, ListConfig], datamodule: LightningDataModule, results: Dict):
    """Create a new model based on the weights specified in config and compare its test results.

    A fresh model is built with ``get_model(config)`` and evaluated with ``Trainer.test``. Its
    ``image_AUROC`` (and ``pixel_AUROC`` for segmentation tasks) must be close, per ``np.isclose``,
    to the values in ``results``.

    Args:
        config (Union[DictConfig, ListConfig]): Model config.
        datamodule (LightningDataModule): Dataloader
        results (Dict): Results from original model. Must contain ``image_AUROC`` and, when
            ``config.dataset.task == "segmentation"``, ``pixel_AUROC``.

    Raises:
        AssertionError: If a metric of the reloaded model is not close to the original one.
    """
    loaded_model = get_model(config)  # get new model

    # Remove visualizer callback(s) as saving results takes time.
    callbacks = [callback for callback in get_callbacks(config) if not isinstance(callback, VisualizerCallback)]

    # create new trainer object with LoadModel callback (assumes it is present)
    trainer = Trainer(callbacks=callbacks, **config.trainer)
    # Assumes the new model has LoadModel callback and the old one had ModelCheckpoint callback
    new_results = trainer.test(model=loaded_model, datamodule=datamodule)[0]

    metric_names = ["image_AUROC"]
    if config.dataset.task == "segmentation":
        metric_names.append("pixel_AUROC")

    for metric_name in metric_names:
        original_value = results[metric_name]
        loaded_value = new_results[metric_name]
        assert np.isclose(original_value, loaded_value), (
            f"Loaded model does not yield close performance results for {metric_name}: "
            f"original={original_value}, loaded={loaded_value}"
        )