|
from __future__ import annotations

import logging
import os
from collections.abc import Hashable, Mapping
from typing import Any, Callable, Sequence

import numpy as np
import torch
import torch.nn
from ignite.engine import Engine
from ignite.metrics import Metric
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.optim.optimizer import Optimizer

from monai.config import KeysCollection
from monai.engines import SupervisedTrainer
from monai.engines.utils import get_devices_spec
from monai.inferers import Inferer
from monai.transforms.transform import MapTransform, Transform

logger = logging.getLogger(__name__)
|
|
def get_device_list(n_gpu):
    """Resolve ``n_gpu`` (a GPU index or list of indices) into a device list.

    If the highest requested index is not available, devices are assigned
    starting from GPU 0 to match the length of ``n_gpu``.
    """
    if not isinstance(n_gpu, list):
        n_gpu = [n_gpu]
    device_list = get_devices_spec(n_gpu)
    # GPU indices are 0-based, so index max(n_gpu) exists only when
    # torch.cuda.device_count() is strictly greater than it.
    if torch.cuda.device_count() > max(n_gpu):
        device_list = [d for d in device_list if d in n_gpu]
    else:
        logger.info(
            "Highest GPU index provided in 'n_gpu' is not available, "
            f"assigning GPUs starting from 0 to match n_gpu length of {len(n_gpu)}"
        )
        # fall back to the first len(n_gpu) devices reported by get_devices_spec
        device_list = get_devices_spec(None)[: len(n_gpu)]
    return device_list
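
# Usage sketch (hypothetical indices): request GPUs 0 and 1; on a machine with
# fewer GPUs the fallback branch assigns devices starting from 0 instead.
#
#   devices = get_device_list([0, 1])
#   print(devices)  # e.g. [0, 1], as resolved by get_devices_spec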
|
|
def supervised_trainer_multi_gpu(
    max_epochs: int,
    train_data_loader,
    network: torch.nn.Module,
    optimizer: Optimizer,
    loss_function: Callable,
    device: Sequence[str | torch.device] | None = None,
    epoch_length: int | None = None,
    non_blocking: bool = False,
    iteration_update: Callable[[Engine, Any], Any] | None = None,
    inferer: Inferer | None = None,
    postprocessing: Transform | None = None,
    key_train_metric: dict[str, Metric] | None = None,
    additional_metrics: dict[str, Metric] | None = None,
    train_handlers: Sequence | None = None,
    amp: bool = False,
    distributed: bool = False,
) -> SupervisedTrainer:
    """Create a ``SupervisedTrainer`` whose network is wrapped for multi-GPU execution.

    With ``distributed=True`` the network is wrapped in ``DistributedDataParallel``
    (requiring exactly one device); otherwise, if more than one device is given,
    it is wrapped in ``DataParallel``.
    """
    devices_ = device
    if not device:
        devices_ = get_devices_spec(device)

    net = network
    if distributed:
        if len(devices_) > 1:
            raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {devices_}.")
        net = DistributedDataParallel(net, device_ids=devices_)
    elif len(devices_) > 1:
        net = DataParallel(net, device_ids=devices_)

    return SupervisedTrainer(
        device=devices_[0],
        network=net,
        optimizer=optimizer,
        loss_function=loss_function,
        max_epochs=max_epochs,
        train_data_loader=train_data_loader,
        epoch_length=epoch_length,
        non_blocking=non_blocking,
        iteration_update=iteration_update,
        inferer=inferer,
        postprocessing=postprocessing,
        key_train_metric=key_train_metric,
        additional_metrics=additional_metrics,
        train_handlers=train_handlers,
        amp=amp,
    )
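
# Usage sketch with hypothetical network/loss/loader choices; any MONAI-style
# combination should work the same way:
#
#   from monai.losses import DiceLoss
#   from monai.networks.nets import UNet
#
#   net = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2))
#   trainer = supervised_trainer_multi_gpu(
#       max_epochs=10,
#       train_data_loader=train_loader,  # assumed to be an existing DataLoader
#       network=net,
#       optimizer=torch.optim.Adam(net.parameters(), lr=1e-4),
#       loss_function=DiceLoss(sigmoid=True),
#       device=[torch.device("cuda:0"), torch.device("cuda:1")],  # wraps net in DataParallel
#   )
#   trainer.run()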
|
|
class SupervisedTrainerMGPU(SupervisedTrainer):
    """``SupervisedTrainer`` variant that wraps its network for multi-GPU execution.

    With ``distributed=True`` the network is wrapped in ``DistributedDataParallel``
    (requiring exactly one device); otherwise, if more than one device is given,
    it is wrapped in ``DataParallel``.
    """

    def __init__(
        self,
        max_epochs: int,
        train_data_loader,
        network: torch.nn.Module,
        optimizer: Optimizer,
        loss_function: Callable,
        device: Sequence[str | torch.device] | None = None,
        epoch_length: int | None = None,
        non_blocking: bool = False,
        iteration_update: Callable[[Engine, Any], Any] | None = None,
        inferer: Inferer | None = None,
        postprocessing: Transform | None = None,
        key_train_metric: dict[str, Metric] | None = None,
        additional_metrics: dict[str, Metric] | None = None,
        train_handlers: Sequence | None = None,
        amp: bool = False,
        distributed: bool = False,
    ):
        self.devices_ = device
        if not device:
            self.devices_ = get_devices_spec(device)

        self.net = network
        if distributed:
            if len(self.devices_) > 1:
                raise ValueError(
                    f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {self.devices_}."
                )
            self.net = DistributedDataParallel(self.net, device_ids=self.devices_)
        elif len(self.devices_) > 1:
            self.net = DataParallel(self.net, device_ids=self.devices_)

        super().__init__(
            device=self.devices_[0],
            network=self.net,
            optimizer=optimizer,
            loss_function=loss_function,
            max_epochs=max_epochs,
            train_data_loader=train_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_train_metric=key_train_metric,
            additional_metrics=additional_metrics,
            train_handlers=train_handlers,
            amp=amp,
        )
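
# The subclass form behaves like `supervised_trainer_multi_gpu` but keeps the
# wrapped network reachable as `trainer.net`, which is convenient when handlers
# need the underlying module. Hypothetical sketch:
#
#   trainer = SupervisedTrainerMGPU(max_epochs=10, train_data_loader=train_loader,
#                                   network=net, optimizer=opt, loss_function=loss)
#   isinstance(trainer.net, DataParallel)  # True when several devices are resolved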
|
|
class AddLabelNamesd(MapTransform):
    def __init__(
        self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False
    ):
        """
        Attach the label names dictionary to the data dictionary so that
        downstream transforms can look up label values by name.

        Args:
            keys: the ``keys`` parameter will be used to get and set the actual data item to transform.
            label_names: dictionary mapping all label names to their integer values.
            allow_missing_keys: don't raise an exception if a key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.label_names = label_names or {}

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        d: dict = dict(data)
        d["label_names"] = self.label_names
        return d
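
# Usage sketch with a hypothetical two-organ label dictionary:
#
#   add_names = AddLabelNamesd(keys="label", label_names={"background": 0, "spleen": 1, "liver": 2})
#   sample = add_names({"label": np.zeros((1, 64, 64, 64))})
#   sample["label_names"]  # {'background': 0, 'spleen': 1, 'liver': 2}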
|
|
class CopyFilenamesd(MapTransform):
    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        """
        Copy the label filename into the data dictionary for later use.

        Args:
            keys: the ``keys`` parameter will be used to get and set the actual data item to transform.
            allow_missing_keys: don't raise an exception if a key is missing.
        """
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        d: dict = dict(data)
        # note: this reads the "label" entry directly, regardless of `keys`
        d["filename"] = os.path.basename(d["label"])
        return d
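
# Usage sketch: applied before `LoadImaged` replaces the path with array data,
# this keeps the label filename around for bookkeeping (path is hypothetical):
#
#   copy_names = CopyFilenamesd(keys="label")
#   sample = copy_names({"label": "/data/labels/spleen_10.nii.gz"})
#   sample["filename"]  # 'spleen_10.nii.gz'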
|
|
class SplitPredsLabeld(MapTransform):
    """
    Split preds and labels into per-class entries for individual evaluation
    of each non-background label.
    """

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        d: dict = dict(data)
        for key in self.key_iterator(d):
            if key == "pred":
                # channel index `idx` corresponds to one entry of `label_names`
                for idx, (key_label, _) in enumerate(d["label_names"].items()):
                    if key_label != "background":
                        d[f"pred_{key_label}"] = d[key][idx, ...][None]
                        d[f"label_{key_label}"] = d["label"][idx, ...][None]
            else:
                logger.info(f"SplitPredsLabeld is only intended for the 'pred' key, skipping '{key}'")
        return d
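
# Usage sketch: given one-hot `pred`/`label` arrays whose channel order follows
# `label_names` (shapes and names here are hypothetical), each foreground class
# gets its own `pred_<name>`/`label_<name>` entry:
#
#   split = SplitPredsLabeld(keys="pred")
#   out = split({
#       "pred": np.random.rand(3, 64, 64),             # channels: background, spleen, liver
#       "label": np.random.randint(0, 2, (3, 64, 64)),
#       "label_names": {"background": 0, "spleen": 1, "liver": 2},
#   })
#   sorted(k for k in out if k.startswith("pred_"))    # ['pred_liver', 'pred_spleen']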
|
|