|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import csv |
|
import gc |
|
import logging
import logging.config
|
import os |
|
import shutil |
|
import sys |
|
import time |
|
from collections import OrderedDict |
|
from datetime import datetime |
|
|
|
import monai.transforms as mt |
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
import yaml |
|
from monai.apps import get_logger |
|
from monai.auto3dseg.utils import datafold_read |
|
from monai.bundle import BundleWorkflow, ConfigParser |
|
from monai.config import print_config |
|
from monai.data import DataLoader, Dataset, decollate_batch |
|
from monai.metrics import CumulativeAverage |
|
from monai.utils import ( |
|
BundleProperty, |
|
ImageMetaKey, |
|
convert_to_dst_type, |
|
ensure_tuple, |
|
look_up_option, |
|
optional_import, |
|
set_determinism, |
|
) |
|
from torch.cuda.amp import GradScaler, autocast |
|
from torch.utils.data import WeightedRandomSampler |
|
from torch.utils.data.distributed import DistributedSampler |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
mlflow, mlflow_is_imported = optional_import("mlflow") |
|
|
|
|
|
if __package__ in (None, ""): |
|
from cell_distributed_weighted_sampler import DistributedWeightedSampler |
|
from components import LabelsToFlows, LoadTiffd, LogitsToLabels |
|
from utils import LOGGING_CONFIG, parsing_bundle_config |
|
else: |
|
from .cell_distributed_weighted_sampler import DistributedWeightedSampler |
|
from .components import LabelsToFlows, LoadTiffd, LogitsToLabels |
|
from .utils import LOGGING_CONFIG, parsing_bundle_config |
|
|
|
|
|
logger = get_logger("VistaCell") |
|
|
|
|
|
class VistaCell(BundleWorkflow): |
|
""" |
|
Primary vista model training workflow that extends |
|
monai.bundle.BundleWorkflow for cell segmentation. |
|
""" |
|
|
|
def __init__(self, config_file=None, meta_file=None, logging_file=None, workflow_type="train", **override): |
|
""" |
|
config_file can be one or a list of config files. |
|
the rest key-values in the `override` are to override config content. |
|
""" |
|
|
|
parser = parsing_bundle_config(config_file, logging_file=logging_file, meta_file=meta_file) |
|
parser.update(pairs=override) |
|
|
|
mode = parser.get("mode", None) |
|
if mode is not None: |
|
workflow_type = mode |
|
else: |
|
mode = workflow_type |
|
super().__init__(workflow_type=workflow_type) |
|
self._props = {} |
|
self._set_props = {} |
|
self.parser = parser |
|
|
|
self.rank = int(os.getenv("LOCAL_RANK", "0")) |
|
self.global_rank = int(os.getenv("RANK", "0")) |
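        # LOCAL_RANK is the GPU index on this node, RANK the global process index across
        # all nodes; both are set by the distributed launcher (e.g. torchrun).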
|
self.is_distributed = dist.is_available() and dist.is_initialized() |
|
|
|
|
|
if dist.is_torchelastic_launched() or ( |
|
os.getenv("NGC_ARRAY_SIZE") is not None and int(os.getenv("NGC_ARRAY_SIZE")) > 1 |
|
): |
|
if dist.is_available(): |
|
dist.init_process_group(backend="nccl", init_method="env://") |
|
|
|
self.is_distributed = dist.is_available() and dist.is_initialized() |
|
|
|
torch.cuda.set_device(self.config("device")) |
|
dist.barrier() |
|
|
|
else: |
|
self.is_distributed = False |
|
|
|
if self.global_rank == 0 and self.config("ckpt_path") and not os.path.exists(self.config("ckpt_path")): |
|
os.makedirs(self.config("ckpt_path"), exist_ok=True) |
|
|
|
if self.rank == 0: |
|
|
|
_log_file = self.config("log_output_file", "vista_cell.log") |
|
_log_file_dir = os.path.dirname(_log_file) |
|
if _log_file_dir and not os.path.exists(_log_file_dir): |
|
os.makedirs(_log_file_dir, exist_ok=True) |
|
|
|
print_config() |
|
|
|
if self.is_distributed: |
|
dist.barrier() |
|
|
|
seed = self.config("seed", None) |
|
if seed is not None: |
|
set_determinism(seed) |
|
logger.info(f"set determinism seed: {self.config('seed', None)}") |
|
elif torch.cuda.is_available(): |
|
torch.backends.cudnn.benchmark = True |
|
logger.info("No seed provided, using cudnn.benchmark for performance.") |
|
|
|
if os.path.exists(self.config("ckpt_path")): |
|
self.parser.export_config_file( |
|
self.parser.config, |
|
os.path.join(self.config("ckpt_path"), "working.yaml"), |
|
fmt="yaml", |
|
default_flow_style=None, |
|
) |
|
|
|
self.add_property("network", required=True) |
|
self.add_property("train_loader", required=True) |
|
self.add_property("val_dataset", required=False) |
|
self.add_property("val_loader", required=False) |
|
self.add_property("val_preprocessing", required=False) |
|
self.add_property("train_sampler", required=True) |
|
self.add_property("val_sampler", required=True) |
|
self.add_property("mode", required=False) |
|
|
|
|
|
self.evaluator = None |
|
|
|
def _set_property(self, name, property, value): |
|
|
|
self._set_props[name] = value |
|
|
|
def _get_property(self, name, property): |
|
""" |
|
The customized bundle workflow must implement required properties in: |
|
https://github.com/Project-MONAI/MONAI/blob/dev/monai/bundle/properties.py. |
|
""" |
|
if name in self._set_props: |
|
self._props[name] = self._set_props[name] |
|
return self._props[name] |
|
if name in self._props: |
|
return self._props[name] |
|
try: |
|
value = getattr(self, f"get_{name}")() |
|
except AttributeError as err: |
|
if property[BundleProperty.REQUIRED]: |
|
raise ValueError( |
|
f"Property '{name}' is required by the bundle format, " |
|
f"but the method 'get_{name}' is not implemented." |
|
) from err |
|
raise AttributeError from err |
|
self._props[name] = value |
|
return value |
|
|
|
def config(self, name, default="null", **kwargs): |
|
"""read the parsed content (evaluate the expression) from the config file.""" |
|
if default != "null": |
|
return self.parser.get_parsed_content(name, default=default, **kwargs) |
|
return self.parser.get_parsed_content(name, **kwargs) |
|
|
|
def initialize(self): |
|
_log_file = self.config("log_output_file", "vista_cell.log") |
|
if _log_file is None: |
|
LOGGING_CONFIG["loggers"]["VistaCell"]["handlers"].remove("file") |
|
LOGGING_CONFIG["handlers"].pop("file", None) |
|
else: |
|
LOGGING_CONFIG["handlers"]["file"]["filename"] = _log_file |
|
logging.config.dictConfig(LOGGING_CONFIG) |
|
|
|
def get_mode(self): |
|
mode_str = self.config("mode", self.workflow_type) |
|
return look_up_option(mode_str, ("train", "training", "infer", "inference", "eval", "evaluation")) |
|
|
|
def run(self): |
|
if str(self.mode).startswith("train"): |
|
return self.train() |
|
if str(self.mode).startswith("infer"): |
|
return self.infer() |
|
return self.validate() |
|
|
|
def finalize(self): |
|
if self.is_distributed: |
|
dist.destroy_process_group() |
|
set_determinism(None) |
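        # reset global determinism settings so later runs in this process are unaffected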
|
|
|
def get_network_def(self): |
|
return self.config("network_def") |
|
|
|
def get_network(self): |
|
pretrained_ckpt_name = self.config("pretrained_ckpt_name", None) |
|
pretrained_ckpt_path = self.config("pretrained_ckpt_path", None) |
|
if pretrained_ckpt_name is not None and pretrained_ckpt_path is None: |
|
|
|
pretrained_ckpt_path = os.path.join(self.config("ckpt_path"), pretrained_ckpt_name) |
|
|
|
if pretrained_ckpt_path is not None and not os.path.exists(pretrained_ckpt_path): |
|
logger.info(f"Pretrained checkpoint {pretrained_ckpt_path} not found.") |
|
raise ValueError(f"Pretrained checkpoint {pretrained_ckpt_path} not found.") |
|
|
|
if pretrained_ckpt_path is not None and os.path.exists(pretrained_ckpt_path): |
|
|
|
if "checkpoint" in self.parser.config["network_def"]: |
|
self.parser.config["network_def"]["checkpoint"] = None |
|
model = self.config("network") |
|
self.checkpoint_load(ckpt=pretrained_ckpt_path, model=model) |
|
else: |
|
model = self.config("network") |
|
|
|
if self.config("channels_last", False): |
|
model = model.to(memory_format=torch.channels_last) |
|
|
|
if self.is_distributed: |
|
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) |
|
|
|
if self.config("compile", False): |
|
model = torch.compile(model) |
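            # Note: compiling before the DDP wrap prefixes parameter names with
            # "_orig_mod."; checkpoint_save() strips that prefix so saved weights remain
            # loadable without compilation.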
|
|
|
if self.is_distributed: |
|
model = torch.nn.parallel.DistributedDataParallel( |
|
module=model, |
|
device_ids=[self.rank], |
|
output_device=self.rank, |
|
find_unused_parameters=self.config("find_unused_parameters", False), |
|
) |
|
|
|
pytorch_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
logger.info(f"total parameters count {pytorch_params} distributed {self.is_distributed}") |
|
return model |
|
|
|
def get_train_dataset_data(self): |
|
train_files, valid_files = [], [] |
|
dataset_data = self.config("train#dataset#data") |
|
val_key = None |
|
if isinstance(dataset_data, dict): |
|
val_key = dataset_data.get("key", None) |
|
data_list_files = dataset_data["data_list_files"] |
|
|
|
if isinstance(data_list_files, str): |
|
data_list_files = ConfigParser.load_config_file( |
|
data_list_files |
|
) |
|
else: |
|
data_list_files = ensure_tuple(data_list_files) |
|
|
|
if self.global_rank == 0: |
|
print("Using data_list_files ", data_list_files) |
|
|
|
for idx, d in enumerate(data_list_files): |
|
logger.info(f"adding datalist ({idx}): {d['datalist']}") |
|
t, v = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=self.config("fold")) |
|
|
|
if val_key is not None: |
|
v, _ = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=-1, key=val_key) |
|
|
|
for item in t: |
|
item["datalist_id"] = idx |
|
item["datalist_count"] = len(t) |
|
for item in v: |
|
item["datalist_id"] = idx |
|
item["datalist_count"] = len(v) |
|
train_files.extend(t) |
|
valid_files.extend(v) |
|
|
|
if self.config("quick", False): |
|
logger.info("quick_data") |
|
train_files = train_files[:8] |
|
valid_files = valid_files[:7] |
|
if not valid_files: |
|
logger.warning("No validation data found.") |
|
return train_files, valid_files |
|
|
|
def read_val_datalists(self, section="validate", data_list_files=None, val_key=None, merge=True): |
|
"""read the corresponding folds of the datalist for validation or inference""" |
|
dataset_data = self.config(f"{section}#dataset#data") |
|
|
|
if isinstance(dataset_data, list): |
|
return dataset_data |
|
|
|
if data_list_files is None: |
|
data_list_files = dataset_data["data_list_files"] |
|
|
|
if isinstance(data_list_files, str): |
|
data_list_files = ConfigParser.load_config_file( |
|
data_list_files |
|
) |
|
else: |
|
data_list_files = ensure_tuple(data_list_files) |
|
|
|
if val_key is None: |
|
val_key = dataset_data.get("key", None) |
|
|
|
val_files, idx = [], 0 |
|
for d in data_list_files: |
|
if val_key is not None: |
|
v_files, _ = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=-1, key=val_key) |
|
else: |
|
_, v_files = datafold_read(datalist=d["datalist"], basedir=d["basedir"], fold=self.config("fold")) |
|
logger.info(f"adding datalist ({idx} -- {val_key}): {d['datalist']} {len(v_files)}") |
|
if merge: |
|
val_files.extend(v_files) |
|
else: |
|
val_files.append(v_files) |
|
idx += 1 |
|
|
|
if self.config("quick", False): |
|
logger.info("quick_data") |
|
val_files = val_files[:8] if merge else [val_files[0][:8]] |
|
return val_files |
|
|
|
def get_train_preprocessing(self): |
|
roi_size = self.config("train#dataset#preprocessing#roi_size") |
|
|
|
train_xforms = [] |
|
train_xforms.append(LoadTiffd(keys=["image", "label"])) |
|
train_xforms.append(mt.EnsureTyped(keys=["image", "label"], data_type="tensor", dtype=torch.float)) |
|
if self.config("prescale", True): |
|
print("Prescaling images to 0..1") |
|
train_xforms.append(mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True)) |
|
|
train_xforms.append( |
|
mt.ScaleIntensityRangePercentilesd( |
|
keys="image", lower=1, upper=99, b_min=0.0, b_max=1.0, channel_wise=True, clip=True |
|
) |
|
) |
|
train_xforms.append(mt.SpatialPadd(keys=["image", "label"], spatial_size=roi_size)) |
|
train_xforms.append( |
|
mt.RandSpatialCropd(keys=["image", "label"], roi_size=roi_size) |
|
) |
|
|
|
|
|
train_xforms.extend( |
|
[ |
|
mt.RandAffined( |
|
keys=["image", "label"], |
|
prob=0.5, |
|
rotate_range=np.pi, |
|
scale_range=[-0.5, 0.5], |
|
mode=["bilinear", "nearest"], |
|
spatial_size=roi_size, |
|
cache_grid=True, |
|
padding_mode="border", |
|
), |
|
mt.RandAxisFlipd(keys=["image", "label"], prob=0.5), |
|
mt.RandGaussianNoised(keys=["image"], prob=0.25, mean=0, std=0.1), |
|
mt.RandAdjustContrastd(keys=["image"], prob=0.25, gamma=(1, 2)), |
|
mt.RandGaussianSmoothd(keys=["image"], prob=0.25, sigma_x=(1, 2)), |
|
mt.RandHistogramShiftd(keys=["image"], prob=0.25, num_control_points=3), |
|
mt.RandGaussianSharpend(keys=["image"], prob=0.25), |
|
] |
|
) |
|
|
|
train_xforms.append( |
|
LabelsToFlows(keys="label", flow_key="flow") |
|
) |
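        # LabelsToFlows converts the instance label map into the flow-field regression
        # target stored under "flow", which is what loss_function consumes in train_epoch
        # (see components.py for the exact encoding).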
|
|
|
return train_xforms |
|
|
|
def get_val_preprocessing(self): |
|
val_xforms = [] |
|
val_xforms.append(LoadTiffd(keys=["image", "label"], allow_missing_keys=True)) |
|
val_xforms.append( |
|
mt.EnsureTyped(keys=["image", "label"], data_type="tensor", dtype=torch.float, allow_missing_keys=True) |
|
) |
|
|
|
if self.config("prescale", True): |
|
print("Prescaling val images to 0..1") |
|
val_xforms.append(mt.ScaleIntensityd(keys="image", minv=0, maxv=1, channel_wise=True)) |
|
|
|
val_xforms.append( |
|
mt.ScaleIntensityRangePercentilesd( |
|
keys="image", lower=1, upper=99, b_min=0.0, b_max=1.0, channel_wise=True, clip=True |
|
) |
|
) |
|
val_xforms.append(LabelsToFlows(keys="label", flow_key="flow", allow_missing_keys=True)) |
|
|
|
return val_xforms |
|
|
|
def get_train_dataset(self): |
|
train_dataset_data = self.config("train#dataset#data") |
|
if isinstance(train_dataset_data, list): |
|
train_files = train_dataset_data |
|
else: |
|
train_files, _ = self.train_dataset_data |
|
logger.info(f"train files {len(train_files)}") |
|
return Dataset(data=train_files, transform=mt.Compose(self.train_preprocessing)) |
|
|
|
def get_val_dataset(self): |
|
"""this is to be used for validation during training""" |
|
val_dataset_data = self.config("validate#dataset#data") |
|
if isinstance(val_dataset_data, list): |
|
valid_files = val_dataset_data |
|
else: |
|
_, valid_files = self.train_dataset_data |
|
return Dataset(data=valid_files, transform=mt.Compose(self.val_preprocessing)) |
|
|
|
def set_val_datalist(self, datalist_py): |
|
self.parser["validate#dataset#data"] = datalist_py |
|
self._props.pop("val_loader", None) |
|
self._props.pop("val_dataset", None) |
|
self._props.pop("val_sampler", None) |
|
|
|
def get_train_sampler(self): |
|
if self.config("use_weighted_sampler", False): |
|
data = self.train_dataset.data |
|
logger.info(f"Using weighted sampler, first item {data[0]}") |
|
sample_weights = 1.0 / torch.as_tensor( |
|
[item.get("datalist_count", 1.0) for item in data], dtype=torch.float |
|
) |
|
|
|
|
|
num_samples_per_epoch = self.config("num_samples_per_epoch", None) |
|
if num_samples_per_epoch is None: |
|
num_samples_per_epoch = len(data) |
|
logger.warning( |
|
"We are using weighted random sampler, but num_samples_per_epoch is not provided, " |
|
f"so using {num_samples_per_epoch} full data length as a workaround!" |
|
) |
|
|
|
if self.is_distributed: |
|
return DistributedWeightedSampler( |
|
self.train_dataset, shuffle=True, weights=sample_weights, num_samples=num_samples_per_epoch |
|
) |
|
return WeightedRandomSampler(weights=sample_weights, num_samples=num_samples_per_epoch) |
|
|
|
if self.is_distributed: |
|
return DistributedSampler(self.train_dataset, shuffle=True) |
|
return None |
|
|
|
def get_val_sampler(self): |
|
if self.is_distributed: |
|
return DistributedSampler(self.val_dataset, shuffle=False) |
|
return None |
|
|
|
def get_train_loader(self): |
|
sampler = self.train_sampler |
|
return DataLoader( |
|
self.train_dataset, |
|
batch_size=self.config("train#batch_size"), |
|
shuffle=(sampler is None), |
|
sampler=sampler, |
|
pin_memory=True, |
|
num_workers=self.config("train#num_workers"), |
|
) |
|
|
|
def get_val_loader(self): |
|
sampler = self.val_sampler |
|
return DataLoader( |
|
self.val_dataset, |
|
batch_size=self.config("validate#batch_size"), |
|
shuffle=False, |
|
sampler=sampler, |
|
pin_memory=True, |
|
num_workers=self.config("validate#num_workers"), |
|
) |
|
|
|
def train(self): |
|
config = self.config |
|
distributed = self.is_distributed |
|
sliding_inferrer = config("inferer#sliding_inferer") |
|
use_amp = config("amp") |
|
|
|
amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ |
|
config("amp_dtype") |
|
] |
|
if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): |
|
amp_dtype = torch.float16 |
|
logger.warning( |
|
"bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually" |
|
) |
|
|
|
use_gradscaler = use_amp and amp_dtype == torch.float16 |
|
logger.info(f"Using grad scaler {use_gradscaler} amp_dtype {amp_dtype} use_amp {use_amp}") |
|
grad_scaler = GradScaler(enabled=use_gradscaler) |
|
|
|
loss_function = config("loss_function") |
|
acc_function = config("key_metric") |
|
|
|
ckpt_path = config("ckpt_path") |
|
channels_last = config("channels_last") |
|
|
|
num_epochs_per_saving = config("train#trainer#num_epochs_per_saving") |
|
num_epochs_per_validation = config("train#trainer#num_epochs_per_validation") |
|
num_epochs = config("train#trainer#max_epochs") |
|
val_schedule_list = self.schedule_validation_epochs( |
|
num_epochs=num_epochs, num_epochs_per_validation=num_epochs_per_validation |
|
) |
|
logger.info(f"Scheduling validation loops at epochs: {val_schedule_list}") |
|
|
|
train_loader = self.train_loader |
|
val_loader = self.val_loader |
|
optimizer = config("optimizer") |
|
model = self.network |
|
|
|
tb_writer = None |
|
csv_path = progress_path = None |
|
|
|
if self.global_rank == 0 and ckpt_path is not None: |
|
|
|
progress_path = os.path.join(ckpt_path, "progress.yaml") |
|
|
|
tb_writer = SummaryWriter(log_dir=ckpt_path) |
|
logger.info(f"Writing Tensorboard logs to {tb_writer.log_dir}") |
|
|
|
if mlflow_is_imported: |
|
if config("mlflow_tracking_uri", None) is not None: |
|
mlflow.set_tracking_uri(config("mlflow_tracking_uri")) |
|
mlflow.set_experiment("vista2d") |
|
|
|
mlflow_run_name = config("mlflow_run_name", f'vista2d train fold{config("fold")}') |
|
mlflow.start_run( |
|
run_name=mlflow_run_name, log_system_metrics=config("mlflow_log_system_metrics", False) |
|
) |
|
mlflow.log_params(self.parser.config) |
|
mlflow.log_dict(self.parser.config, "hyper_parameters.yaml") |
|
|
|
csv_path = os.path.join(ckpt_path, "accuracy_history.csv") |
|
self.save_history_csv( |
|
csv_path=csv_path, |
|
header=["epoch", "metric", "loss", "iter", "time", "train_time", "validation_time", "epoch_time"], |
|
) |
|
|
|
do_torch_save = ( |
|
(self.global_rank == 0) and ckpt_path and config("ckpt_save") and not config("train#skip", False) |
|
) |
|
best_ckpt_path = os.path.join(ckpt_path, "model.pt") |
|
intermediate_ckpt_path = os.path.join(ckpt_path, "model_final.pt") |
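        # When checkpoint saving is enabled, "model.pt" holds the best-metric weights and
        # "model_final.pt" the latest periodic save (or a copy of the best when they coincide).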
|
|
|
best_metric = float(config("best_metric", -1)) |
|
start_epoch = config("start_epoch", 0) |
|
best_metric_epoch = -1 |
|
pre_loop_time = time.time() |
|
report_num_epochs = num_epochs |
|
train_time = validation_time = 0 |
|
val_acc_history = [] |
|
|
|
if start_epoch > 0: |
|
val_schedule_list = [v for v in val_schedule_list if v >= start_epoch] |
|
if len(val_schedule_list) == 0: |
|
val_schedule_list = [start_epoch] |
|
print(f"adjusted schedule_list {val_schedule_list}") |
|
|
|
logger.info( |
|
f"Using num_epochs => {num_epochs}\n " |
|
f"Using start_epoch => {start_epoch}\n " |
|
f"batch_size => {config('train#batch_size')} \n " |
|
f"num_warmup_epochs => {config('train#trainer#num_warmup_epochs')} \n " |
|
) |
|
|
|
lr_scheduler = config("lr_scheduler") |
|
if lr_scheduler is not None and start_epoch > 0: |
|
lr_scheduler.last_epoch = start_epoch |
|
|
|
range_num_epochs = range(start_epoch, num_epochs) |
|
|
|
if distributed: |
|
dist.barrier() |
|
|
|
if self.global_rank == 0 and tb_writer is not None and mlflow_is_imported and mlflow.is_tracking_uri_set(): |
|
mlflow.log_param("len_train_set", len(train_loader.dataset)) |
|
mlflow.log_param("len_val_set", len(val_loader.dataset)) |
|
|
|
for epoch in range_num_epochs: |
|
report_epoch = epoch |
|
|
|
if distributed: |
|
if isinstance(train_loader.sampler, DistributedSampler): |
|
train_loader.sampler.set_epoch(epoch) |
|
dist.barrier() |
|
|
|
epoch_time = start_time = time.time() |
|
|
|
train_loss, train_acc = 0, 0 |
|
|
|
if not config("train#skip", False): |
|
train_loss, train_acc = self.train_epoch( |
|
model=model, |
|
train_loader=train_loader, |
|
optimizer=optimizer, |
|
loss_function=loss_function, |
|
acc_function=acc_function, |
|
grad_scaler=grad_scaler, |
|
epoch=report_epoch, |
|
rank=self.rank, |
|
global_rank=self.global_rank, |
|
num_epochs=report_num_epochs, |
|
use_amp=use_amp, |
|
amp_dtype=amp_dtype, |
|
channels_last=channels_last, |
|
device=config("device"), |
|
) |
|
|
|
train_time = time.time() - start_time |
|
logger.info( |
|
f"Latest training {report_epoch}/{report_num_epochs - 1} " |
|
f"loss: {train_loss:.4f} time {train_time:.2f}s " |
|
f"lr: {optimizer.param_groups[0]['lr']:.4e}" |
|
) |
|
|
|
if self.global_rank == 0 and tb_writer is not None: |
|
tb_writer.add_scalar("train/loss", train_loss, report_epoch) |
|
|
|
if mlflow_is_imported and mlflow.is_tracking_uri_set(): |
|
mlflow.log_metric("train/loss", train_loss, step=report_epoch) |
|
mlflow.log_metric("train/epoch_time", train_time, step=report_epoch) |
|
|
|
|
|
val_acc_mean = -1 |
|
if ( |
|
len(val_schedule_list) > 0 |
|
and epoch + 1 >= val_schedule_list[0] |
|
and val_loader is not None |
|
and len(val_loader) > 0 |
|
): |
|
val_schedule_list.pop(0) |
|
|
|
start_time = time.time() |
|
torch.cuda.empty_cache() |
|
|
|
val_loss, val_acc = self.val_epoch( |
|
model=model, |
|
val_loader=val_loader, |
|
sliding_inferrer=sliding_inferrer, |
|
loss_function=loss_function, |
|
acc_function=acc_function, |
|
epoch=report_epoch, |
|
rank=self.rank, |
|
global_rank=self.global_rank, |
|
num_epochs=report_num_epochs, |
|
use_amp=use_amp, |
|
amp_dtype=amp_dtype, |
|
channels_last=channels_last, |
|
device=config("device"), |
|
) |
|
|
|
torch.cuda.empty_cache() |
|
validation_time = time.time() - start_time |
|
|
|
val_acc_mean = float(np.mean(val_acc)) |
|
val_acc_history.append((report_epoch, val_acc_mean)) |
|
|
|
if self.global_rank == 0: |
|
logger.info( |
|
f"Latest validation {report_epoch}/{report_num_epochs - 1} " |
|
f"loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc: {val_acc} time: {validation_time:.2f}s" |
|
) |
|
|
|
if tb_writer is not None: |
|
tb_writer.add_scalar("val/acc", val_acc_mean, report_epoch) |
|
tb_writer.add_scalar("val/loss", val_loss, report_epoch) |
|
if mlflow_is_imported and mlflow.is_tracking_uri_set(): |
|
mlflow.log_metric("val/acc", val_acc_mean, step=report_epoch) |
|
mlflow.log_metric("val/epoch_time", validation_time, step=report_epoch) |
|
|
|
timing_dict = { |
|
"time": f"{(time.time() - pre_loop_time) / 3600:.2f} hr", |
|
"train_time": f"{train_time:.2f}s", |
|
"validation_time": f"{validation_time:.2f}s", |
|
"epoch_time": f"{time.time() - epoch_time:.2f}s", |
|
} |
|
|
|
if val_acc_mean > best_metric: |
|
logger.info(f"New best metric ({best_metric:.6f} --> {val_acc_mean:.6f}). ") |
|
best_metric, best_metric_epoch = val_acc_mean, report_epoch |
|
save_time = 0 |
|
if do_torch_save: |
|
save_time = self.checkpoint_save( |
|
ckpt=best_ckpt_path, model=model, epoch=best_metric_epoch, best_metric=best_metric |
|
) |
|
|
|
if progress_path is not None: |
|
self.save_progress_yaml( |
|
progress_path=progress_path, |
|
ckpt=best_ckpt_path if do_torch_save else None, |
|
best_avg_score_epoch=best_metric_epoch, |
|
best_avg_score=best_metric, |
|
save_time=save_time, |
|
**timing_dict, |
|
) |
|
if csv_path is not None: |
|
self.save_history_csv( |
|
csv_path=csv_path, |
|
epoch=report_epoch, |
|
metric=f"{val_acc_mean:.4f}", |
|
loss=f"{train_loss:.4f}", |
|
iter=report_epoch * len(train_loader.dataset), |
|
**timing_dict, |
|
) |
|
|
|
|
|
if epoch > max(20, num_epochs / 4) and 0 <= val_acc_mean < 0.01 and config("stop_on_lowacc", True): |
|
logger.info( |
|
f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. " |
|
"Most likely optimization diverged, try setting a smaller learning_rate" |
|
f" than {config('learning_rate')}" |
|
) |
|
raise ValueError( |
|
f"Accuracy seems very low at epoch {report_epoch}, acc {val_acc_mean}. " |
|
"Most likely optimization diverged, try setting a smaller learning_rate" |
|
f" than {config('learning_rate')}" |
|
) |
|
|
|
|
|
if do_torch_save and ((epoch + 1) % num_epochs_per_saving == 0 or (epoch + 1) >= num_epochs): |
|
if report_epoch != best_metric_epoch: |
|
self.checkpoint_save( |
|
ckpt=intermediate_ckpt_path, model=model, epoch=report_epoch, best_metric=val_acc_mean |
|
) |
|
else: |
|
try: |
|
shutil.copyfile(best_ckpt_path, intermediate_ckpt_path) |
|
except Exception as err: |
|
logger.warning(f"error copying {best_ckpt_path} {intermediate_ckpt_path} {err}") |
|
|
|
|
if lr_scheduler is not None: |
|
lr_scheduler.step() |
|
|
|
if self.global_rank == 0: |
|
|
|
time_remaining_estimate = train_time * (num_epochs - epoch) |
|
if val_loader is not None and len(val_loader) > 0: |
|
if validation_time == 0: |
|
validation_time = train_time |
|
time_remaining_estimate += validation_time * len(val_schedule_list) |
|
|
|
logger.info( |
|
f"Estimated remaining training time for the current model fold {config('fold')} is " |
|
f"{time_remaining_estimate/3600:.2f} hr, " |
|
f"running time {(time.time() - pre_loop_time)/3600:.2f} hr, " |
|
f"est total time {(time.time() - pre_loop_time + time_remaining_estimate)/3600:.2f} hr \n" |
|
) |
|
|
|
|
|
train_loader = val_loader = optimizer = None |
|
|
|
|
|
logger.info(f"Checking to run final testing {config('run_final_testing')}") |
|
if config("run_final_testing"): |
|
if distributed: |
|
dist.barrier() |
|
_ckpt_name = best_ckpt_path if os.path.exists(best_ckpt_path) else intermediate_ckpt_path |
|
if not os.path.exists(_ckpt_name): |
|
logger.info(f"Unable to validate final no checkpoints found {best_ckpt_path}, {intermediate_ckpt_path}") |
|
else: |
|
|
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
best_metric = self.run_final_testing( |
|
pretrained_ckpt_path=_ckpt_name, |
|
progress_path=progress_path, |
|
best_metric_epoch=best_metric_epoch, |
|
pre_loop_time=pre_loop_time, |
|
) |
|
|
|
if ( |
|
self.global_rank == 0 |
|
and tb_writer is not None |
|
and mlflow_is_imported |
|
and mlflow.is_tracking_uri_set() |
|
): |
|
mlflow.log_param("acc_testing", val_acc_mean) |
|
mlflow.log_metric("acc_testing", val_acc_mean) |
|
|
|
if tb_writer is not None: |
|
tb_writer.flush() |
|
tb_writer.close() |
|
|
|
if mlflow_is_imported and mlflow.is_tracking_uri_set(): |
|
mlflow.end_run() |
|
|
|
logger.info( |
|
f"=== DONE: best_metric: {best_metric:.4f} at epoch: {best_metric_epoch} of {report_num_epochs}." |
|
f"Training time {(time.time() - pre_loop_time)/3600:.2f} hr." |
|
) |
|
return best_metric |
|
|
|
def run_final_testing(self, pretrained_ckpt_path, progress_path, best_metric_epoch, pre_loop_time): |
|
logger.info("Running final best model testing set!") |
|
|
|
|
|
start_time = time.time() |
|
|
|
self._props.pop("network", None) |
|
self.parser["pretrained_ckpt_path"] = pretrained_ckpt_path |
|
self.parser["validate#evaluator#postprocessing"] = None |
|
|
|
val_acc_mean, val_loss, val_acc = self.validate(val_key="testing") |
|
validation_time = f"{time.time() - start_time:.2f}s" |
|
val_acc_mean = float(np.mean(val_acc)) |
|
logger.info(f"Testing: loss: {val_loss:.4f} acc_avg: {val_acc_mean:.4f} acc {val_acc} time {validation_time}") |
|
|
|
if self.global_rank == 0 and progress_path is not None: |
|
self.save_progress_yaml( |
|
progress_path=progress_path, |
|
ckpt=pretrained_ckpt_path, |
|
best_avg_score_epoch=best_metric_epoch, |
|
best_avg_score=val_acc_mean, |
|
validation_time=validation_time, |
|
run_final_testing=True, |
|
time=f"{(time.time() - pre_loop_time) / 3600:.2f} hr", |
|
) |
|
return val_acc_mean |
|
|
|
def validate(self, validation_files=None, val_key=None, datalist=None): |
|
if self.config("pretrained_ckpt_name", None) is None and self.config("pretrained_ckpt_path", None) is None: |
|
self.parser["pretrained_ckpt_name"] = "model.pt" |
|
logger.info("Using default model.pt checkpoint for validation.") |
|
|
|
grouping = self.config("validate#grouping", False) |
|
if validation_files is None: |
|
validation_files = self.read_val_datalists("validate", datalist, val_key=val_key, merge=not grouping) |
|
if len(validation_files) == 0: |
|
logger.warning(f"No validation files found {datalist} {val_key}!") |
|
return 0, 0, 0 |
|
if not grouping or not isinstance(validation_files[0], (list, tuple)): |
|
validation_files = [validation_files] |
|
logger.info(f"validation file groups {len(validation_files)} grouping {grouping}") |
|
val_acc_dict = {} |
|
|
|
amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ |
|
self.config("amp_dtype") |
|
] |
|
if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): |
|
amp_dtype = torch.float16 |
|
logger.warning( |
|
"bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually" |
|
) |
|
|
|
for datalist_id, group_files in enumerate(validation_files): |
|
self.set_val_datalist(group_files) |
|
val_loader = self.val_loader |
|
|
|
start_time = time.time() |
|
val_loss, val_acc = self.val_epoch( |
|
model=self.network, |
|
val_loader=val_loader, |
|
sliding_inferrer=self.config("inferer#sliding_inferer"), |
|
loss_function=self.config("loss_function"), |
|
acc_function=self.config("key_metric"), |
|
rank=self.rank, |
|
global_rank=self.global_rank, |
|
use_amp=self.config("amp"), |
|
amp_dtype=amp_dtype, |
|
post_transforms=self.config("validate#evaluator#postprocessing"), |
|
channels_last=self.config("channels_last"), |
|
device=self.config("device"), |
|
) |
|
val_acc_mean = float(np.mean(val_acc)) |
|
logger.info( |
|
f"Validation {datalist_id} complete, loss_avg: {val_loss:.4f} " |
|
f"acc_avg: {val_acc_mean:.4f} acc {val_acc} time {time.time() - start_time:.2f}s" |
|
) |
|
val_acc_dict[datalist_id] = val_acc_mean |
|
for k, v in val_acc_dict.items(): |
|
logger.info(f"group: {k} => {v:.4f}") |
|
        val_acc_mean = sum(val_acc_dict.values()) / len(val_acc_dict)
|
logger.info(f"Testing group score average: {val_acc_mean:.4f}") |
|
return val_acc_mean, val_loss, val_acc |
|
|
|
def infer(self, infer_files=None, infer_key=None, datalist=None): |
|
if self.config("pretrained_ckpt_name", None) is None and self.config("pretrained_ckpt_path", None) is None: |
|
self.parser["pretrained_ckpt_name"] = "model.pt" |
|
logger.info("Using default model.pt checkpoint for inference.") |
|
|
|
if infer_files is None: |
|
infer_files = self.read_val_datalists("infer", datalist, val_key=infer_key, merge=True) |
|
if len(infer_files) == 0: |
|
logger.warning(f"no file to infer {datalist} {infer_key}.") |
|
return |
|
logger.info(f"inference files {len(infer_files)}") |
|
self.set_val_datalist(infer_files) |
|
val_loader = self.val_loader |
|
|
|
amp_dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[ |
|
self.config("amp_dtype") |
|
] |
|
if amp_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): |
|
            amp_dtype = torch.float16
|
logger.warning( |
|
"bfloat16 dtype is not support on your device, changing to float16, use --amp_dtype=float16 to set manually" |
|
) |
|
|
|
start_time = time.time() |
|
self.val_epoch( |
|
model=self.network, |
|
val_loader=val_loader, |
|
sliding_inferrer=self.config("inferer#sliding_inferer"), |
|
loss_function=None, |
|
acc_function=None, |
|
rank=self.rank, |
|
global_rank=self.global_rank, |
|
use_amp=self.config("amp"), |
|
amp_dtype=amp_dtype, |
|
post_transforms=self.config("infer#evaluator#postprocessing"), |
|
channels_last=self.config("channels_last"), |
|
device=self.config("device"), |
|
) |
|
logger.info(f"Inference complete time {time.time() - start_time:.2f}s") |
|
return |
|
|
|
@torch.no_grad() |
|
def val_epoch( |
|
self, |
|
model, |
|
val_loader, |
|
sliding_inferrer, |
|
loss_function=None, |
|
acc_function=None, |
|
epoch=0, |
|
rank=0, |
|
global_rank=0, |
|
num_epochs=0, |
|
use_amp=True, |
|
amp_dtype=torch.float16, |
|
post_transforms=None, |
|
channels_last=False, |
|
device=None, |
|
): |
|
model.eval() |
|
distributed = dist.is_available() and dist.is_initialized() |
|
memory_format = torch.channels_last if channels_last else torch.preserve_format |
|
|
|
run_loss = CumulativeAverage() |
|
run_acc = CumulativeAverage() |
|
run_loss.append(torch.tensor(0, device=device), count=0) |
|
|
|
avg_loss = avg_acc = 0 |
|
start_time = time.time() |
|
|
|
|
|
|
|
|
|
|
|
nonrepeated_data_length = len(val_loader.dataset) |
|
sampler = val_loader.sampler |
|
if distributed and isinstance(sampler, DistributedSampler) and not sampler.drop_last: |
|
nonrepeated_data_length = len(range(sampler.rank, len(sampler.dataset), sampler.num_replicas)) |
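        # DistributedSampler with drop_last=False pads the dataset by repeating samples so
        # every rank gets the same number of batches; the length above counts only this
        # rank's non-repeated samples, and later batches beyond it are added to the
        # accuracy average with count=0 so padded duplicates do not bias the metric.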
|
|
|
for idx, batch_data in enumerate(val_loader): |
|
data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device) |
|
filename = batch_data["image"].meta[ImageMetaKey.FILENAME_OR_OBJ] |
|
batch_size = data.shape[0] |
|
loss = acc = None |
|
|
|
with autocast(enabled=use_amp, dtype=amp_dtype): |
|
logits = sliding_inferrer(inputs=data, network=model) |
|
data = None |
|
|
|
|
|
if loss_function is not None: |
|
target = batch_data["flow"].as_subclass(torch.Tensor).to(device=logits.device) |
|
loss = loss_function(logits, target) |
|
run_loss.append(loss.to(device=device), count=batch_size) |
|
target = None |
|
|
|
pred_mask_all = [] |
|
|
|
for b_ind in range(logits.shape[0]): |
|
pred_mask, p = LogitsToLabels()(logits=logits[b_ind], filename=filename) |
|
pred_mask_all.append(pred_mask) |
|
|
|
if acc_function is not None: |
|
label = batch_data["label"].as_subclass(torch.Tensor) |
|
|
|
for b_ind in range(label.shape[0]): |
|
acc = acc_function(pred_mask_all[b_ind], label[b_ind, 0].long()) |
|
acc = acc.detach().clone() if isinstance(acc, torch.Tensor) else torch.tensor(acc) |
|
|
|
if idx < nonrepeated_data_length: |
|
run_acc.append(acc.to(device=device), count=1) |
|
else: |
|
run_acc.append(torch.zeros_like(acc, device=device), count=0) |
|
label = None |
|
|
|
avg_loss = loss.cpu() if loss is not None else 0 |
|
avg_acc = acc.cpu().numpy() if acc is not None else 0 |
|
|
|
logger.info( |
|
f"Val {epoch}/{num_epochs} {idx}/{len(val_loader)} " |
|
f"loss: {avg_loss:.4f} acc {avg_acc} time {time.time() - start_time:.2f}s" |
|
) |
|
|
|
if post_transforms: |
|
seg = torch.from_numpy(np.stack(pred_mask_all, axis=0).astype(np.int32)).unsqueeze(1) |
|
batch_data["seg"] = convert_to_dst_type( |
|
seg, batch_data["image"], dtype=torch.int32, device=torch.device("cpu") |
|
)[0] |
|
for bd in decollate_batch(batch_data): |
|
post_transforms(bd) |
|
|
|
start_time = time.time() |
|
|
|
label = target = data = batch_data = None |
|
|
|
if distributed: |
|
dist.barrier() |
|
|
|
avg_loss = run_loss.aggregate() |
|
avg_acc = run_acc.aggregate() |
|
|
|
        if np.any(avg_acc < 0):
            if distributed:
                dist.barrier()
            logger.warning(f"Average accuracy is negative ({avg_acc}), something went wrong!")
|
|
|
return avg_loss, avg_acc |
|
|
|
def train_epoch( |
|
self, |
|
model, |
|
train_loader, |
|
optimizer, |
|
loss_function, |
|
acc_function, |
|
grad_scaler, |
|
epoch, |
|
rank, |
|
global_rank=0, |
|
num_epochs=0, |
|
use_amp=True, |
|
amp_dtype=torch.float16, |
|
channels_last=False, |
|
device=None, |
|
): |
|
model.train() |
|
memory_format = torch.channels_last if channels_last else torch.preserve_format |
|
|
|
run_loss = CumulativeAverage() |
|
|
|
start_time = time.time() |
|
avg_loss = avg_acc = 0 |
|
for idx, batch_data in enumerate(train_loader): |
|
data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device) |
|
target = batch_data["flow"].as_subclass(torch.Tensor).to(memory_format=memory_format, device=device) |
|
|
|
optimizer.zero_grad(set_to_none=True) |
|
|
|
with autocast(enabled=use_amp, dtype=amp_dtype): |
|
logits = model(data) |
|
|
|
|
|
loss = loss_function(logits.float(), target) |
|
|
|
grad_scaler.scale(loss).backward() |
|
grad_scaler.step(optimizer) |
|
grad_scaler.update() |
|
|
|
batch_size = data.shape[0] |
|
|
|
run_loss.append(loss, count=batch_size) |
|
avg_loss = run_loss.aggregate() |
|
|
|
logger.info( |
|
f"Epoch {epoch}/{num_epochs} {idx}/{len(train_loader)} " |
|
f"loss: {avg_loss:.4f} time {time.time() - start_time:.2f}s " |
|
) |
|
start_time = time.time() |
|
|
|
optimizer.zero_grad(set_to_none=True) |
|
|
|
data = None |
|
target = None |
|
batch_data = None |
|
|
|
return avg_loss, avg_acc |
|
|
|
def save_history_csv(self, csv_path=None, header=None, **kwargs): |
|
if csv_path is not None: |
|
if header is not None: |
|
with open(csv_path, "a") as myfile: |
|
wrtr = csv.writer(myfile, delimiter="\t") |
|
wrtr.writerow(header) |
|
if len(kwargs): |
|
with open(csv_path, "a") as myfile: |
|
wrtr = csv.writer(myfile, delimiter="\t") |
|
wrtr.writerow(list(kwargs.values())) |
|
|
|
def save_progress_yaml(self, progress_path=None, ckpt=None, **report): |
|
if ckpt is not None: |
|
report["model"] = ckpt |
|
|
|
report["date"] = str(datetime.now())[:19] |
|
|
|
if progress_path is not None: |
|
yaml.add_representer( |
|
float, lambda dumper, value: dumper.represent_scalar("tag:yaml.org,2002:float", f"{value:.4f}") |
|
) |
|
with open(progress_path, "a") as progress_file: |
|
yaml.dump([report], stream=progress_file, allow_unicode=True, default_flow_style=None, sort_keys=False) |
|
|
|
logger.info("Progress:" + ",".join(f" {k}: {v}" for k, v in report.items())) |
|
|
|
def checkpoint_save(self, ckpt: str, model: torch.nn.Module, **kwargs): |
|
|
|
save_time = time.time() |
|
if isinstance(model, torch.nn.parallel.DistributedDataParallel): |
|
state_dict = model.module.state_dict() |
|
else: |
|
state_dict = model.state_dict() |
|
|
|
if self.config("compile", False): |
|
|
|
state_dict = OrderedDict( |
|
(k[len("_orig_mod.") :] if k.startswith("_orig_mod.") else k, v) for k, v in state_dict.items() |
|
) |
|
|
|
torch.save({"state_dict": state_dict, "config": self.parser.config, **kwargs}, ckpt) |
|
|
|
save_time = time.time() - save_time |
|
logger.info(f"Saving checkpoint process: {ckpt}, {kwargs}, save_time {save_time:.2f}s") |
|
|
|
return save_time |
|
|
|
def checkpoint_load(self, ckpt: str, model: torch.nn.Module, **kwargs): |
|
|
|
if not os.path.isfile(ckpt): |
|
logger.warning("Invalid checkpoint file: " + str(ckpt)) |
|
return |
|
checkpoint = torch.load(ckpt, map_location="cpu") |
|
|
|
model.load_state_dict(checkpoint["state_dict"], strict=True) |
|
epoch = checkpoint.get("epoch", 0) |
|
best_metric = checkpoint.get("best_metric", 0) |
|
|
|
if self.config("continue", False): |
|
if "epoch" in checkpoint: |
|
self.parser["start_epoch"] = checkpoint["epoch"] |
|
if "best_metric" in checkpoint: |
|
self.parser["best_metric"] = checkpoint["best_metric"] |
|
|
|
logger.info( |
|
f"=> loaded checkpoint {ckpt} (epoch {epoch}) " |
|
f"(best_metric {best_metric}) setting start_epoch {self.config('start_epoch')}" |
|
) |
|
self.parser["start_epoch"] = int(self.config("start_epoch")) + 1 |
|
return |
|
|
|
def schedule_validation_epochs(self, num_epochs, num_epochs_per_validation=None, fraction=0.16) -> list: |
|
""" |
|
Schedule of epochs to validate (progressively more frequently) |
|
num_epochs - total number of epochs |
|
num_epochs_per_validation - if provided use a linear schedule with this step |
|
init_step |
|
""" |
|
|
|
if num_epochs_per_validation is None: |
|
x = (np.sin(np.linspace(0, np.pi / 2, max(10, int(fraction * num_epochs)))) * num_epochs).astype(int) |
|
x = np.cumsum(np.sort(np.diff(np.unique(x)))[::-1]) |
|
x[-1] = num_epochs |
|
x = x.tolist() |
|
else: |
|
if num_epochs_per_validation >= num_epochs: |
|
x = [num_epochs_per_validation] |
|
else: |
|
x = list(range(num_epochs_per_validation, num_epochs, num_epochs_per_validation)) |
|
|
|
if len(x) == 0: |
|
x = [0] |
|
|
|
return x |
|
|
|
|
|
def main(**kwargs) -> None: |
|
workflow = VistaCell(**kwargs) |
|
workflow.initialize() |
|
workflow.run() |
|
workflow.finalize() |
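# Example CLI usage via fire (illustrative; the config file name/path is an assumption):
#   python workflow.py --config_file=<your_config>.yaml --mode=train
#   python workflow.py --config_file=<your_config>.yaml --mode=infer --pretrained_ckpt_path=<model.pt>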
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
from pathlib import Path |
|
|
|
sys.path.append(str(Path(__file__).parent.parent)) |
|
|
|
|
|
|
|
fire, fire_is_imported = optional_import("fire") |
|
if fire_is_imported: |
|
fire.Fire(main) |
|
else: |
|
print("Missing package: fire") |
|
|