File size: 7,595 Bytes
a0ae4d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
from __future__ import annotations

import logging
import os
from collections.abc import Hashable, Mapping
from typing import Any, Callable, Sequence

import numpy as np
import torch
import torch.nn
from ignite.engine import Engine
from ignite.metrics import Metric
from monai.config import KeysCollection
from monai.engines import SupervisedTrainer
from monai.engines.utils import get_devices_spec
from monai.inferers import Inferer
from monai.transforms.transform import MapTransform, Transform
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.optim.optimizer import Optimizer

# measure, _ = optional_import("skimage.measure", "0.14.2", min_version)

logger = logging.getLogger(__name__)

# distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")


def get_device_list(n_gpu):
    """Resolve the requested GPU IDs into the list of CUDA device indices to use.

    Args:
        n_gpu: a single CUDA device index, or a list/tuple of CUDA device indices.

    Returns:
        A list of integer device indices. If every requested index exists (i.e. is
        smaller than ``torch.cuda.device_count()``), the requested indices are returned
        unchanged. Otherwise indices ``0 .. len(n_gpu) - 1`` are returned instead, so the
        number of devices matches what was asked for.

    Raises:
        ValueError: if ``n_gpu`` is an empty list/tuple.
    """
    requested = list(n_gpu) if isinstance(n_gpu, (list, tuple)) else [n_gpu]
    if not requested:
        raise ValueError("'n_gpu' must contain at least one GPU ID.")

    # CUDA indices are 0-based: with N devices the highest valid index is N - 1,
    # so an index equal to device_count() does not exist.
    if max(requested) < torch.cuda.device_count():
        return requested

    logging.getLogger(__name__).info(
        "Highest GPU ID provided in 'n_gpu' is not available on this machine, "
        "assigning GPUs starting from 0 to match n_gpu length of %d",
        len(requested),
    )
    return list(range(len(requested)))


def supervised_trainer_multi_gpu(
    max_epochs: int,
    train_data_loader,
    network: torch.nn.Module,
    optimizer: Optimizer,
    loss_function: Callable,
    device: Sequence[str | torch.device] | None = None,
    epoch_length: int | None = None,
    non_blocking: bool = False,
    iteration_update: Callable[[Engine, Any], Any] | None = None,
    inferer: Inferer | None = None,
    postprocessing: Transform | None = None,
    key_train_metric: dict[str, Metric] | None = None,
    additional_metrics: dict[str, Metric] | None = None,
    train_handlers: Sequence | None = None,
    amp: bool = False,
    distributed: bool = False,
):
    """Build a ``SupervisedTrainer`` whose network is wrapped for multi-device training.

    Args:
        device: devices to train on. If falsy (``None`` or empty), the list is obtained from
            ``get_devices_spec(device)`` (intended as "use all devices, i.e. GPUs"). Otherwise
            it is used exactly as given. Its first entry becomes the trainer's ``device``.
        distributed: wrap ``network`` in ``DistributedDataParallel`` rather than
            ``DataParallel``. Only a single device is allowed in this mode.
        All remaining arguments are forwarded unchanged to ``SupervisedTrainer``.

    Returns:
        The configured ``SupervisedTrainer``.

    Raises:
        ValueError: if ``distributed`` is set and more than one device was resolved.

    Note:
        ``network`` is not moved to any device here.
    """
    devices_ = device if device else get_devices_spec(device)

    if distributed and len(devices_) > 1:
        raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {devices_}.")

    if distributed:
        wrapped = DistributedDataParallel(network, device_ids=devices_)
    elif len(devices_) > 1:
        wrapped = DataParallel(network, device_ids=devices_)
    else:
        # a single device needs no wrapper
        wrapped = network

    return SupervisedTrainer(
        device=devices_[0],
        network=wrapped,
        optimizer=optimizer,
        loss_function=loss_function,
        max_epochs=max_epochs,
        train_data_loader=train_data_loader,
        epoch_length=epoch_length,
        non_blocking=non_blocking,
        iteration_update=iteration_update,
        inferer=inferer,
        postprocessing=postprocessing,
        key_train_metric=key_train_metric,
        additional_metrics=additional_metrics,
        train_handlers=train_handlers,
        amp=amp,
    )


class SupervisedTrainerMGPU(SupervisedTrainer):
    """
    ``SupervisedTrainer`` variant that wraps ``network`` for multi-device training
    before handing it to the base class.

    Attributes set here (in addition to those of ``SupervisedTrainer``):
        devices_: the device sequence in use; its first entry becomes the trainer's ``device``.
        net: ``network`` itself, or ``network`` wrapped in ``DataParallel`` (several devices)
            or ``DistributedDataParallel`` (``distributed=True``).
    """

    def __init__(
        self,
        max_epochs: int,
        train_data_loader,
        network: torch.nn.Module,
        optimizer: Optimizer,
        loss_function: Callable,
        device: Sequence[str | torch.device] | None = None,
        epoch_length: int | None = None,
        non_blocking: bool = False,
        iteration_update: Callable[[Engine, Any], Any] | None = None,
        inferer: Inferer | None = None,
        postprocessing: Transform | None = None,
        key_train_metric: dict[str, Metric] | None = None,
        additional_metrics: dict[str, Metric] | None = None,
        train_handlers: Sequence | None = None,
        amp: bool = False,
        distributed: bool = False,
    ):
        """
        Args:
            device: devices to train on. If falsy (``None`` or empty), the list is obtained
                from ``get_devices_spec(device)`` (intended as "use all devices, i.e. GPUs").
                Otherwise it is used exactly as given.
            distributed: wrap ``network`` in ``DistributedDataParallel``; only one device
                is permitted in that mode.
            All other arguments are forwarded unchanged to ``SupervisedTrainer``.

        Raises:
            ValueError: if ``distributed`` is set and more than one device was resolved.

        ``network`` is not moved to any device here.
        """
        chosen = device if device else get_devices_spec(device)

        if not distributed:
            model = DataParallel(network, device_ids=chosen) if len(chosen) > 1 else network
        elif len(chosen) > 1:
            raise ValueError(
                f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {chosen}."
            )
        else:
            model = DistributedDataParallel(network, device_ids=chosen)

        self.devices_ = chosen
        self.net = model

        super().__init__(
            device=self.devices_[0],
            network=self.net,
            optimizer=optimizer,
            loss_function=loss_function,
            max_epochs=max_epochs,
            train_data_loader=train_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_train_metric=key_train_metric,
            additional_metrics=additional_metrics,
            train_handlers=train_handlers,
            amp=amp,
        )


class AddLabelNamesd(MapTransform):
    """Attach a fixed label-name mapping to every data dictionary under the ``"label_names"`` key."""

    def __init__(
        self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False
    ):
        """
        Store the label-name mapping that will be added to each sample.

        Args:
            keys: forwarded to ``MapTransform``; ``__call__`` does not read or write these keys.
            label_names: mapping from label name to an integer (presumably the label value —
                TODO confirm). ``None`` is treated as an empty mapping.
            allow_missing_keys: forwarded to ``MapTransform``.
        """
        super().__init__(keys, allow_missing_keys)

        self.label_names = label_names or {}

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        """Return a shallow copy of ``data`` with ``"label_names"`` set to a copy of the stored mapping."""
        d: dict = dict(data)
        # Give each sample its own dict so an in-place edit made downstream on one sample
        # cannot leak into other samples or back into this transform's stored mapping.
        d["label_names"] = dict(self.label_names)
        return d


class CopyFilenamesd(MapTransform):
    """Record the base name of ``data["label"]`` under ``data["filename"]`` for later use."""

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        """
        Args:
            keys: forwarded to ``MapTransform``. ``__call__`` always reads the ``"label"``
                entry and does not consult these keys.
            allow_missing_keys: forwarded to ``MapTransform``.
        """
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        """Return a shallow copy of ``data`` plus a ``"filename"`` entry.

        ``data["label"]`` must be a path-like value accepted by ``os.path.basename``
        (presumably this runs before the label is loaded into an array — TODO confirm
        pipeline order). A missing ``"label"`` entry raises ``KeyError``.
        """
        out: dict = dict(data)
        label_path = out["label"]
        out.update(filename=os.path.basename(label_path))
        return out


class SplitPredsLabeld(MapTransform):
    """
    Split the multi-channel ``"pred"`` and ``"label"`` entries into one single-channel
    entry per non-background label, for per-label evaluation.

    For the i-th name in ``data["label_names"]`` other than ``"background"``, channel i of
    ``data["pred"]`` and of ``data["label"]`` is stored as ``pred_<name>`` / ``label_<name>``,
    with a new leading axis of size 1. The channel index is the name's position in
    ``label_names`` iteration order, not its mapped value. This assumes the channels are laid
    out in that order (TODO confirm). Any configured key other than ``"pred"`` is left untouched
    and only produces an info log message.
    """

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        out: dict = dict(data)
        for key in self.key_iterator(out):
            if key != "pred":
                logger.info("This is only for pred key")
                continue

            label_names = out["label_names"]
            for channel, name in enumerate(label_names.keys()):
                if name == "background":
                    continue
                out[f"pred_{name}"] = out[key][channel, ...][None]
                out[f"label_{name}"] = out["label"][channel, ...][None]
        return out