from __future__ import annotations
import logging
import os
from collections.abc import Hashable, Mapping
from typing import Any, Callable, Sequence
import numpy as np
import torch
import torch.nn
from ignite.engine import Engine
from ignite.metrics import Metric
from monai.config import KeysCollection
from monai.engines import SupervisedTrainer
from monai.engines.utils import get_devices_spec
from monai.inferers import Inferer
from monai.transforms.transform import MapTransform, Transform
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.optim.optimizer import Optimizer
# measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
logger = logging.getLogger(__name__)
# distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
def get_device_list(n_gpu):
    """Resolve the GPU ids requested in ``n_gpu`` into the list of devices to train on."""
    if not isinstance(n_gpu, list):
        n_gpu = [n_gpu]
    device_list = get_devices_spec(n_gpu)
    # valid GPU ids run from 0 to torch.cuda.device_count() - 1
    if torch.cuda.device_count() > max(n_gpu):
        device_list = [d for d in device_list if d in n_gpu]
    else:
        logger.info(
            "Highest GPU id provided in 'n_gpu' is larger than the number of GPUs available; "
            "assigning GPUs starting from 0 to match the n_gpu length of %d",
            len(n_gpu),
        )
        device_list = device_list[: len(n_gpu)]
    return device_list
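
# Example (a sketch; the result depends on the GPUs visible on the host):
#
#   get_device_list([0, 1])  # -> [0, 1] when at least two GPUs are present
#   get_device_list(0)       # a single id is wrapped into a list -> [0]
#
# When the highest requested id is not available, the list is truncated to the
# requested length instead (see the log message above).
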
def supervised_trainer_multi_gpu(
max_epochs: int,
train_data_loader,
network: torch.nn.Module,
optimizer: Optimizer,
loss_function: Callable,
device: Sequence[str | torch.device] | None = None,
epoch_length: int | None = None,
non_blocking: bool = False,
iteration_update: Callable[[Engine, Any], Any] | None = None,
inferer: Inferer | None = None,
postprocessing: Transform | None = None,
key_train_metric: dict[str, Metric] | None = None,
additional_metrics: dict[str, Metric] | None = None,
train_handlers: Sequence | None = None,
amp: bool = False,
distributed: bool = False,
):
    """Create a ``SupervisedTrainer`` whose network is wrapped for multi-GPU or distributed training."""
    devices_ = device
    if not device:
        devices_ = get_devices_spec(device)  # use all available devices, i.e. GPUs
# if device:
# if next(network.parameters()).device.index != device[0]:
# network.to(devices_[0])
# else:
# if next(network.parameters()).device.index != devices_[0].index:
# network.to(devices_[0])
#
net = network
if distributed:
if len(devices_) > 1:
raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {devices_}.")
net = DistributedDataParallel(net, device_ids=devices_)
elif len(devices_) > 1:
net = DataParallel(net, device_ids=devices_) # ,output_device=devices_[0])
return SupervisedTrainer(
device=devices_[0],
network=net,
optimizer=optimizer,
loss_function=loss_function,
max_epochs=max_epochs,
train_data_loader=train_data_loader,
epoch_length=epoch_length,
non_blocking=non_blocking,
iteration_update=iteration_update,
inferer=inferer,
postprocessing=postprocessing,
key_train_metric=key_train_metric,
additional_metrics=additional_metrics,
train_handlers=train_handlers,
amp=amp,
)
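
# Example usage (a minimal sketch; the UNet configuration, the optimizer settings and
# `loader` are placeholders, not part of this module):
#
#   from monai.losses import DiceLoss
#   from monai.networks.nets import UNet
#
#   net = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2))
#   trainer = supervised_trainer_multi_gpu(
#       max_epochs=10,
#       train_data_loader=loader,
#       network=net,
#       optimizer=torch.optim.Adam(net.parameters(), lr=1e-4),
#       loss_function=DiceLoss(to_onehot_y=True, softmax=True),
#       device=get_device_list([0, 1]),  # two GPUs -> the network is wrapped in DataParallel
#   )
#   trainer.run()
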
class SupervisedTrainerMGPU(SupervisedTrainer):
    """``SupervisedTrainer`` that wraps its network in ``DataParallel`` or ``DistributedDataParallel`` before training."""

def __init__(
self,
max_epochs: int,
train_data_loader,
network: torch.nn.Module,
optimizer: Optimizer,
loss_function: Callable,
device: Sequence[str | torch.device] | None = None,
epoch_length: int | None = None,
non_blocking: bool = False,
iteration_update: Callable[[Engine, Any], Any] | None = None,
inferer: Inferer | None = None,
postprocessing: Transform | None = None,
key_train_metric: dict[str, Metric] | None = None,
additional_metrics: dict[str, Metric] | None = None,
train_handlers: Sequence | None = None,
amp: bool = False,
distributed: bool = False,
    ):
        self.devices_ = device
        if not device:
            self.devices_ = get_devices_spec(device)  # use all available devices, i.e. GPUs
# if device:
# if next(network.parameters()).device.index != device[0]:
# network.to(devices_[0])
# else:
# if next(network.parameters()).device.index != devices_[0].index:
# network.to(devices_[0])
#
self.net = network
if distributed:
if len(self.devices_) > 1:
raise ValueError(
f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {self.devices_}."
)
self.net = DistributedDataParallel(self.net, device_ids=self.devices_)
elif len(self.devices_) > 1:
self.net = DataParallel(self.net, device_ids=self.devices_) # ,output_device=devices_[0])
super().__init__(
device=self.devices_[0],
network=self.net,
optimizer=optimizer,
loss_function=loss_function,
max_epochs=max_epochs,
train_data_loader=train_data_loader,
epoch_length=epoch_length,
non_blocking=non_blocking,
iteration_update=iteration_update,
inferer=inferer,
postprocessing=postprocessing,
key_train_metric=key_train_metric,
additional_metrics=additional_metrics,
train_handlers=train_handlers,
amp=amp,
)
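
# SupervisedTrainerMGPU is the class-based counterpart of supervised_trainer_multi_gpu
# above: it accepts the same arguments and performs the same DataParallel /
# DistributedDataParallel wrapping before delegating to SupervisedTrainer.__init__,
# so the usage sketch above applies unchanged.
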
class AddLabelNamesd(MapTransform):
def __init__(
self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False
):
"""
Normalize label values according to label names dictionary
Args:
keys: The ``keys`` parameter will be used to get and set the actual data item to transform
label_names: all label names
"""
super().__init__(keys, allow_missing_keys)
self.label_names = label_names or {}
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
d: dict = dict(data)
d["label_names"] = self.label_names
return d
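
# Example (a minimal sketch; the label map is illustrative):
#
#   add_names = AddLabelNamesd(keys="label", label_names={"background": 0, "spleen": 1})
#   sample = add_names({"label": np.zeros((1, 64, 64, 64))})
#   sample["label_names"]  # -> {"background": 0, "spleen": 1}
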
class CopyFilenamesd(MapTransform):
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
"""
Copy Filenames for future use
Args:
keys: The ``keys`` parameter will be used to get and set the actual data item to transform
"""
super().__init__(keys, allow_missing_keys)
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
d: dict = dict(data)
d["filename"] = os.path.basename(d["label"])
return d
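
# Example (a minimal sketch; the path is illustrative):
#
#   copy_names = CopyFilenamesd(keys="label")
#   sample = copy_names({"label": "/data/labels/spleen_10.nii.gz"})
#   sample["filename"]  # -> "spleen_10.nii.gz"
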
class SplitPredsLabeld(MapTransform):
"""
Split preds and labels for individual evaluation
"""
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
d: dict = dict(data)
for key in self.key_iterator(d):
            if key == "pred":
                for idx, (key_label, _) in enumerate(d["label_names"].items()):
                    if key_label != "background":
                        d[f"pred_{key_label}"] = d[key][idx, ...][None]
                        d[f"label_{key_label}"] = d["label"][idx, ...][None]
            else:
                logger.info("SplitPredsLabeld is only applied to the 'pred' key; skipping '%s'", key)
return d
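
# Example (a minimal sketch; shapes and label names are illustrative):
#
#   splitter = SplitPredsLabeld(keys="pred")
#   out = splitter({
#       "pred": np.random.rand(2, 64, 64, 64),   # channel 0: background, channel 1: spleen
#       "label": np.random.rand(2, 64, 64, 64),
#       "label_names": {"background": 0, "spleen": 1},
#   })
#   out["pred_spleen"].shape  # -> (1, 64, 64, 64); the "background" channel is skipped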