"""Common helpers for both nightly and pre-merge model tests."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from omegaconf import DictConfig, ListConfig
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from anomalib.config import get_configurable_parameters, update_nncf_config
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.models.components import AnomalyModule
from anomalib.utils.callbacks import VisualizerCallback, get_callbacks


def setup_model_train(
    model_name: str,
    dataset_path: str,
    project_path: str,
    nncf: bool,
    category: str,
    score_type: Optional[str] = None,
    weight_file: str = "weights/model.ckpt",
    fast_run: bool = False,
    device: Union[List[int], int] = [0],
) -> Tuple[Union[DictConfig, ListConfig], LightningDataModule, AnomalyModule, Trainer]:
    """Train the model based on the parameters passed.

    Args:
        model_name (str): Name of the model to train.
        dataset_path (str): Location of the dataset.
        project_path (str): Path to temporary project folder.
        nncf (bool): Add nncf callback.
        category (str): Category to train on.
        score_type (Optional[str], optional): Only used for DFM. Defaults to None.
        weight_file (str, optional): Path to the weight file. Defaults to "weights/model.ckpt".
        fast_run (bool, optional): If set to True, the model trains for only one epoch. A single epoch is still
            enough to ensure that both anomalous and non-anomalous images are present in the validation step.
            Defaults to False.
        device (Union[List[int], int], optional): Device(s) to train the model on. Defaults to the first GPU.

    Returns:
        Tuple[DictConfig, LightningDataModule, AnomalyModule, Trainer]: config, datamodule, trained model, trainer
    """
    config = get_configurable_parameters(model_name=model_name)
    if score_type is not None:
        config.model.score_type = score_type
    config.project.seed = 42
    config.dataset.category = category
    config.dataset.path = dataset_path
    config.project.log_images_to = []
    config.trainer.gpus = device

    # If weight file is empty, remove the key from config
    if "weight_file" in config.model.keys() and weight_file == "":
        config.model.pop("weight_file")
    else:
        config.model.weight_file = weight_file if not fast_run else "weights/last.ckpt"

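    # Enable the NNCF optimization if requested and update the config for compression.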
    if nncf:
        config.optimization.nncf.apply = True
        config = update_nncf_config(config)
        config.init_weights = None

    # reassign project path as config is updated in `update_nncf_config`
    config.project.path = project_path

    datamodule = get_datamodule(config)
    model = get_model(config)

    callbacks = get_callbacks(config)

    # Force model checkpoint to create checkpoint after first epoch
    if fast_run:
        for index, callback in enumerate(callbacks):
            if isinstance(callback, ModelCheckpoint):
                callbacks.pop(index)
                break
        model_checkpoint = ModelCheckpoint(
            dirpath=os.path.join(config.project.path, "weights"),
            filename="last",
            monitor=None,
            mode="max",
            save_last=True,
            auto_insert_metric_name=False,
        )
        callbacks.append(model_checkpoint)

    for index, callback in enumerate(callbacks):
        if isinstance(callback, VisualizerCallback):
            callbacks.pop(index)
            break

    # Train the model.
    if fast_run:
        config.trainer.max_epochs = 1
        config.trainer.check_val_every_n_epoch = 1

    trainer = Trainer(callbacks=callbacks, **config.trainer)
    trainer.fit(model=model, datamodule=datamodule)
    return config, datamodule, model, trainer


def model_load_test(config: Union[DictConfig, ListConfig], datamodule: LightningDataModule, results: Dict):
    """Create a new model based on the weights specified in config.

    Args:
        config (Union[DictConfig, ListConfig]): Model config.
        datamodule (LightningDataModule): Datamodule used for testing.
        results (Dict): Results from the original model.

    """
    loaded_model = get_model(config)  # get new model

    callbacks = get_callbacks(config)

    for index, callback in enumerate(callbacks):
        # Remove visualizer callback as saving results takes time
        if isinstance(callback, VisualizerCallback):
            callbacks.pop(index)
            break

    # Create a new trainer. The callbacks returned by `get_callbacks` are assumed to include a LoadModel
    # callback, which restores the checkpoint written by the original run's ModelCheckpoint callback.
    trainer = Trainer(callbacks=callbacks, **config.trainer)
    new_results = trainer.test(model=loaded_model, datamodule=datamodule)[0]
    assert np.isclose(
        results["image_AUROC"], new_results["image_AUROC"]
    ), "Loaded model does not yield close performance results"
    if config.dataset.task == "segmentation":
        assert np.isclose(
            results["pixel_AUROC"], new_results["pixel_AUROC"]
        ), "Loaded model does not yield close performance results"