Spaces:
Build error
Build error
File size: 5,752 Bytes
c8c12e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
"""Common helpers for both nightly and pre-merge model tests."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import os
from typing import Dict, List, Tuple, Union
import numpy as np
from omegaconf import DictConfig, ListConfig
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from anomalib.config import get_configurable_parameters, update_nncf_config
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.models.components import AnomalyModule
from anomalib.utils.callbacks import VisualizerCallback, get_callbacks
def _without_callbacks_of_type(callbacks: List, callback_type: type) -> List:
    """Return a new list holding every callback that is not an instance of ``callback_type``."""
    return [callback for callback in callbacks if not isinstance(callback, callback_type)]


def setup_model_train(
    model_name: str,
    dataset_path: str,
    project_path: str,
    nncf: bool,
    category: str,
    score_type: Union[str, None] = None,
    weight_file: str = "weights/model.ckpt",
    fast_run: bool = False,
    device: Union[List[int], int] = [0],
) -> Tuple[Union[DictConfig, ListConfig], LightningDataModule, AnomalyModule, Trainer]:
    """Train the model based on the parameters passed.

    Args:
        model_name (str): Name of the model to train.
        dataset_path (str): Location of the dataset.
        project_path (str): Path to temporary project folder.
        nncf (bool): Add nncf callback.
        category (str): Category to train on.
        score_type (str, optional): Only used for DFM. Defaults to None.
        weight_file (str, optional): Path to weight file, stored as ``config.model.weight_file``.
            An empty string removes ``weight_file`` from the model config (if present) instead of setting it.
            When non-empty and ``fast_run`` is set, ``weights/last.ckpt`` is stored instead.
            Defaults to ``weights/model.ckpt``.
        fast_run (bool, optional): If set to true, the model trains for only 1 epoch. We train for one epoch as
            this ensures that both anomalous and non-anomalous images are present in the validation step.
        device (List[int], int, optional): Select which device you want to train the model on. Defaults to first GPU.

    Returns:
        Tuple[DictConfig, LightningDataModule, AnomalyModule, Trainer]: config, datamodule, trained model, trainer
    """
    config = get_configurable_parameters(model_name=model_name)
    if score_type is not None:
        config.model.score_type = score_type
    config.project.seed = 42
    config.dataset.category = category
    config.dataset.path = dataset_path
    config.project.log_images_to = []
    # ``device`` is only assigned into the config and never mutated here, so the list default is not altered.
    config.trainer.gpus = device

    if weight_file == "":
        # Empty weight file: remove the key if present; never store an empty path in the config.
        if "weight_file" in config.model.keys():
            config.model.pop("weight_file")
    else:
        # On a fast run, point at the ``last`` checkpoint that the ModelCheckpoint configured below
        # writes into ``<project path>/weights``.
        # NOTE(review): assumes the consumer resolves this path relative to the project path.
        config.model.weight_file = "weights/last.ckpt" if fast_run else weight_file

    if nncf:
        config.optimization.nncf.apply = True
        config = update_nncf_config(config)
        config.init_weights = None

    # Reassign the project path, as the config may have been replaced by `update_nncf_config`.
    config.project.path = project_path

    datamodule = get_datamodule(config)
    model = get_model(config)
    callbacks = get_callbacks(config)

    if fast_run:
        # Force a checkpoint after the first epoch: replace any configured ModelCheckpoint with one
        # that always saves ``last`` into the project's weights folder.
        callbacks = _without_callbacks_of_type(callbacks, ModelCheckpoint)
        callbacks.append(
            ModelCheckpoint(
                dirpath=os.path.join(config.project.path, "weights"),
                filename="last",
                monitor=None,
                mode="max",
                save_last=True,
                auto_insert_metric_name=False,
            )
        )

    # Saving visualizations takes time and is not needed by the tests.
    callbacks = _without_callbacks_of_type(callbacks, VisualizerCallback)

    # Train the model.
    if fast_run:
        config.trainer.max_epochs = 1
        config.trainer.check_val_every_n_epoch = 1

    trainer = Trainer(callbacks=callbacks, **config.trainer)
    trainer.fit(model=model, datamodule=datamodule)
    return config, datamodule, model, trainer
def model_load_test(config: Union[DictConfig, ListConfig], datamodule: LightningDataModule, results: Dict) -> None:
    """Rebuild the model from ``config`` and check that it reproduces the original model's metrics.

    The new model is expected to pick up the trained weights through a LoadModel callback returned
    by ``get_callbacks``. This helper assumes that callback is present; it does not add one itself.

    Args:
        config (Union[DictConfig, ListConfig]): Model config.
        datamodule (LightningDataModule): Dataloader.
        results (Dict): Test metrics of the original model; must contain ``image_AUROC`` and, for
            segmentation tasks, ``pixel_AUROC``.

    Raises:
        AssertionError: If a metric of the reloaded model is not ``np.isclose`` to the original one.
    """
    reloaded_model = get_model(config)
    callbacks = get_callbacks(config)

    # Saving visualizations is slow, so drop the first VisualizerCallback (if any) before testing.
    visualizer_position = next(
        (position for position, callback in enumerate(callbacks) if isinstance(callback, VisualizerCallback)),
        None,
    )
    if visualizer_position is not None:
        callbacks.pop(visualizer_position)

    trainer = Trainer(callbacks=callbacks, **config.trainer)
    test_outputs = trainer.test(model=reloaded_model, datamodule=datamodule)
    reloaded_results = test_outputs[0]

    def _check_metric(metric_name: str) -> None:
        # Compare the original and reloaded value of a single metric.
        original_value = results[metric_name]
        reloaded_value = reloaded_results[metric_name]
        assert np.isclose(original_value, reloaded_value), "Loaded model does not yield close performance results"

    _check_metric("image_AUROC")
    # Pixel-level metrics are only checked for segmentation tasks.
    if config.dataset.task == "segmentation":
        _check_metric("pixel_AUROC")
|