|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import json |
|
import math |
|
import os |
|
import zipfile |
|
from argparse import Namespace |
|
from datetime import timedelta |
|
from typing import Any, Sequence |
|
|
|
import numpy as np |
|
import skimage |
|
import torch |
|
import torch.distributed as dist |
|
from monai.bundle import ConfigParser |
|
from monai.config import DtypeLike, NdarrayOrTensor |
|
from monai.data import CacheDataset, DataLoader, partition_dataset |
|
from monai.transforms import Compose, EnsureTyped, Lambdad, LoadImaged, Orientationd |
|
from monai.transforms.utils_morphological_ops import dilate, erode |
|
from monai.utils import TransformBackends, convert_data_type, convert_to_dst_type, get_equivalent_dtype |
|
from scipy import stats |
|
from torch import Tensor |
|
|
|
|
|
def unzip_dataset(dataset_dir):
    """
    Ensure ``dataset_dir`` exists, extracting ``<dataset_dir>.zip`` next to it if needed.

    Only rank 0 performs the extraction; in a distributed run every rank then
    waits at a barrier so no process proceeds before the data is on disk.

    Args:
        dataset_dir (str): target directory that should contain the dataset.

    Raises:
        ValueError: if neither the directory nor its ``.zip`` archive exists.
    """
    ddp_active = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if ddp_active else 0

    if rank == 0 and not os.path.exists(dataset_dir):
        zip_file_path = dataset_dir + ".zip"
        if not os.path.isfile(zip_file_path):
            raise ValueError(f"Please download {zip_file_path}.")
        # extract into the parent so the archive's top-level folder becomes dataset_dir
        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
            zip_ref.extractall(path=os.path.dirname(dataset_dir))
        print(f"Unzipped {zip_file_path} to {dataset_dir}.")

    if ddp_active:
        dist.barrier()

    return
|
|
|
|
|
def add_data_dir2path(list_files: list, data_dir: str, fold: int = None) -> tuple[list, list]: |
|
""" |
|
Read a list of data dictionary. |
|
|
|
Args: |
|
list_files (list): input data to load and transform to generate dataset for model. |
|
data_dir (str): directory of files. |
|
fold (int, optional): fold index for cross validation. Defaults to None. |
|
|
|
Returns: |
|
tuple[list, list]: A tuple of two arrays (training, validation). |
|
""" |
|
new_list_files = copy.deepcopy(list_files) |
|
if fold is not None: |
|
new_list_files_train = [] |
|
new_list_files_val = [] |
|
for d in new_list_files: |
|
d["image"] = os.path.join(data_dir, d["image"]) |
|
|
|
if "label" in d: |
|
d["label"] = os.path.join(data_dir, d["label"]) |
|
|
|
if fold is not None: |
|
if d["fold"] == fold: |
|
new_list_files_val.append(copy.deepcopy(d)) |
|
else: |
|
new_list_files_train.append(copy.deepcopy(d)) |
|
|
|
if fold is not None: |
|
return new_list_files_train, new_list_files_val |
|
else: |
|
return new_list_files, [] |
|
|
|
|
|
def maisi_datafold_read(json_list, data_base_dir, fold=None):
    """
    Read a MAISI data-split JSON and return (train, val) file lists.

    Args:
        json_list: path to a JSON file with a ``"training"`` entry.
        data_base_dir: directory prepended to every listed file path.
        fold: fold index for cross validation; None disables the split.

    Returns:
        tuple: (training file list, validation file list).
    """
    with open(json_list, "r") as f:
        training_entries = json.load(f)["training"]

    train_files, val_files = add_data_dir2path(training_entries, data_base_dir, fold=fold)
    print(f"dataset: {data_base_dir}, num_training_files: {len(train_files)}, num_val_files: {len(val_files)}")
    return train_files, val_files
|
|
|
|
|
def remap_labels(mask, label_dict_remap_json):
    """
    Remap labels in the mask according to the provided label dictionary.

    The JSON file maps names to ``[source_label, target_label]`` pairs, which
    are applied to the input mask via :class:`MapLabelValue`.

    Args:
        mask (Tensor): The input mask tensor to be remapped.
        label_dict_remap_json (str): Path to the JSON file containing the label mapping dictionary.

    Returns:
        Tensor: The remapped mask tensor, on the same device as the input.
    """
    with open(label_dict_remap_json, "r") as f:
        remap_pairs = json.load(f)

    source_labels = [pair[0] for pair in remap_pairs.values()]
    destination_labels = [pair[1] for pair in remap_pairs.values()]
    mapper = MapLabelValue(orig_labels=source_labels, target_labels=destination_labels, dtype=torch.uint8)

    # drop the channel dim for the mapper, then restore it and the device
    remapped = mapper(mask[0, ...])[None, ...]
    return remapped.to(mask.device)
|
|
|
|
|
def get_index_arr(img):
    """
    Generate a voxel-coordinate array for the given 3D image.

    The result has shape ``img.shape[:3] + (3,)`` and satisfies
    ``out[i, j, k] == [i, j, k]`` for every voxel.

    Args:
        img (ndarray): The input image array.

    Returns:
        ndarray: Coordinate array with the index triple stored on the last axis.
    """
    axes = (np.arange(img.shape[0]), np.arange(img.shape[1]), np.arange(img.shape[2]))
    # "ij" indexing yields matrix-order grids, so no axis shuffling is needed
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
|
|
|
|
|
def supress_non_largest_components(img, target_label, default_val=0):
    """
    Suppress all components except the largest one(s) for specified target labels.

    For each label in ``target_label``, connected components are computed and
    every component other than the two most populous ones is overwritten with
    ``default_val``. (Two are kept because the background of the binary mask
    is itself one component — NOTE(review): confirm this intent with callers.)

    Args:
        img (ndarray): The input image array.
        target_label (list): List of label values to process.
        default_val (int, optional): Value to assign to suppressed voxels. Defaults to 0.

    Returns:
        tuple: A tuple containing:
            - ndarray: Modified image with non-largest components suppressed.
            - int: Number of voxels where the value decreased (see note below).
    """
    img_mod = copy.deepcopy(img)
    new_background = np.zeros(img.shape, dtype=np.bool_)
    for label in target_label:
        label_cc = skimage.measure.label(img == label, connectivity=3)
        uv, uc = np.unique(label_cc, return_counts=True)
        # component ids sorted by size, keep the two largest
        dominant_vals = uv[np.argsort(uc)[::-1][:2]]
        if len(dominant_vals) >= 2:
            new_background = np.logical_or(
                new_background,
                np.logical_not(np.logical_or(label_cc == dominant_vals[0], label_cc == dominant_vals[1])),
            )

    # vectorized boolean-mask assignment replaces the original per-voxel
    # Python loop over get_index_arr(img)[new_background] — same result
    img_mod[new_background] = default_val
    # NOTE(review): counts voxels where img > img_mod; with unsigned dtypes a
    # default_val larger than the original label would wrap around — confirm
    # callers only pass default_val that cannot exceed the suppressed labels.
    diff = np.sum((img - img_mod) > 0)

    return img_mod, diff
|
|
|
|
|
def erode_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:
    """
    Erode a 2D/3D binary mask tensor.

    Args:
        mask_t: input 2D/3D binary mask, [M,N] or [M,N,P] torch tensor.
        filter_size: erosion filter size, has to be odd numbers, default to be 3.
        pad_value: value used to pad the input so the output keeps the input
            size; the default of 1.0 is the usual choice for erosion.

    Return:
        Tensor: eroded mask, same shape as input.
    """
    # erode() expects batch and channel dims: add them, filter, strip them again
    batched = mask_t.float()[None, None]
    eroded = erode(batched, filter_size, pad_value=pad_value)
    return eroded[0, 0]
|
|
|
|
|
def dilate_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:
    """
    Dilate a 2D/3D binary mask tensor.

    Args:
        mask_t: input 2D/3D binary mask, [M,N] or [M,N,P] torch tensor.
        filter_size: dilation filter size, has to be odd numbers, default to be 3.
        pad_value: value used to pad the input so the output keeps the input
            size; the default of 0.0 is the usual choice for dilation.

    Return:
        Tensor: dilated mask, same shape as input.
    """
    # dilate() expects batch and channel dims: add them, filter, strip them again
    batched = mask_t.float()[None, None]
    dilated = dilate(batched, filter_size, pad_value=pad_value)
    return dilated[0, 0]
|
|
|
|
|
def binarize_labels(x: Tensor, bits: int = 8) -> Tensor:
    """
    Convert an integer label tensor to its per-bit binary representation.

    Each label value is expanded into ``bits`` binary channels, least
    significant bit first.

    Args:
        x (Tensor): Input tensor with shape (B, 1, H, W, D).
        bits (int, optional): Number of bits to use for binary representation. Defaults to 8.

    Returns:
        Tensor: Binary representation of the input tensor with shape (B, bits, H, W, D).
    """
    # bit_weights = [1, 2, 4, ...]: AND-ing with each isolates one bit
    bit_weights = torch.pow(2, torch.arange(bits)).to(x.device, x.dtype)
    expanded = x.unsqueeze(-1) & bit_weights
    binary = (expanded != 0).byte()
    # drop the singleton channel dim and move the bit axis into channel position
    return binary.squeeze(1).permute(0, 4, 1, 2, 3)
|
|
|
|
|
def setup_ddp(rank: int, world_size: int) -> torch.device:
    """
    Initialize the NCCL distributed process group and return this rank's device.

    Args:
        rank (int): rank of the current process.
        world_size (int): number of processes participating in the job.

    Returns:
        torch.device: CUDA device assigned to this rank.
    """
    # env:// reads MASTER_ADDR/MASTER_PORT; long timeout tolerates slow startup
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=timedelta(seconds=36000),
        rank=rank,
        world_size=world_size,
    )
    # synchronize all ranks before any of them proceeds
    dist.barrier()
    return torch.device(f"cuda:{rank}")
|
|
|
|
|
def define_instance(args: Namespace, instance_def_key: str) -> Any:
    """
    Instantiate the object described by ``instance_def_key`` in the parsed config.

    The namespace is handed to a monai ``ConfigParser``, which resolves
    references and instantiates the requested component.

    Args:
        args: An object containing the arguments to be parsed.
        instance_def_key (str): The key used to retrieve the instance definition from the parsed content.

    Returns:
        The instantiated object as defined by the instance_def_key in the parsed configuration.
    """
    config = ConfigParser(vars(args))
    config.parse(True)
    instance = config.get_parsed_content(instance_def_key, instantiate=True)
    return instance
|
|
|
|
|
def prepare_maisi_controlnet_json_dataloader(
    json_data_list: list | str,
    data_base_dir: list | str,
    batch_size: int = 1,
    fold: int = 0,
    cache_rate: float = 0.0,
    rank: int = 0,
    world_size: int = 1,
) -> tuple[DataLoader, DataLoader]:
    """
    Prepare dataloaders for training and validation.

    Args:
        json_data_list (list | str): the name of JSON files listing the data.
        data_base_dir (list | str): directory of files.
        batch_size (int, optional): how many samples per batch to load. Defaults to 1.
        fold (int, optional): fold index for cross validation. Defaults to 0.
        cache_rate (float, optional): percentage of cached data in total. Defaults to 0.0.
        rank (int, optional): rank of the current process. Defaults to 0.
        world_size (int, optional): number of processes participating in the job. Defaults to 1.

    Returns:
        tuple[DataLoader, DataLoader]: A tuple of two dataloaders (training, validation).
    """
    use_ddp = world_size > 1
    # several JSON lists: each list is paired with its own data root and the
    # resulting splits are concatenated
    if isinstance(json_data_list, list):
        assert isinstance(data_base_dir, list)
        list_train = []
        list_valid = []
        for data_list, data_root in zip(json_data_list, data_base_dir):
            with open(data_list, "r") as f:
                json_data = json.load(f)["training"]
            train, val = add_data_dir2path(json_data, data_root, fold)
            list_train += train
            list_valid += val
    else:
        with open(json_data_list, "r") as f:
            json_data = json.load(f)["training"]
        list_train, list_valid = add_data_dir2path(json_data, data_base_dir, fold)

    common_transform = [
        LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
        Orientationd(keys=["label"], axcodes="RAS"),
        EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
        # scale region indices and spacing by 100 — NOTE(review): presumably
        # the conditioning convention expected downstream; confirm with trainer
        Lambdad(keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2),
    ]
    train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)

    train_loader = None

    if use_ddp:
        # shard training data so each rank sees a distinct, evenly-sized partition
        list_train = partition_dataset(data=list_train, shuffle=True, num_partitions=world_size, even_divisible=True)[
            rank
        ]
    train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    if use_ddp:
        # validation partitions may be uneven (even_divisible=False) to avoid duplicates
        list_valid = partition_dataset(data=list_valid, shuffle=True, num_partitions=world_size, even_divisible=False)[
            rank
        ]
    val_ds = CacheDataset(data=list_valid, transform=val_transforms, cache_rate=cache_rate, num_workers=8)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False)
    return train_loader, val_loader
|
|
|
|
|
def organ_fill_by_closing(data, target_label, device, close_times=2, filter_size=3, pad_value=0.0):
    """
    Fill holes in an organ mask using morphological closing operations.

    A closing is one dilation followed by one erosion; repeating it
    ``close_times`` times fills holes in the binary mask of ``target_label``.

    Args:
        data (ndarray): The input data containing organ labels.
        target_label (int): The label of the organ to be processed.
        device (str): The device to perform the operations on (e.g. 'cuda:0').
        close_times (int, optional): Number of times to perform the closing operation. Defaults to 2.
        filter_size (int, optional): Size of the filter for dilation and erosion. Defaults to 3.
        pad_value (float, optional): Value used for padding in dilation and erosion. Defaults to 0.0.

    Returns:
        ndarray: Boolean mask of the filled organ.
    """
    organ_mask = torch.from_numpy((data == target_label).astype(np.uint8)).to(device)
    for _ in range(close_times):
        # closing = dilation followed by erosion
        organ_mask = dilate_one_img(organ_mask, filter_size=filter_size, pad_value=pad_value)
        organ_mask = erode_one_img(organ_mask, filter_size=filter_size, pad_value=pad_value)
    return organ_mask.cpu().numpy().astype(np.bool_)
|
|
|
|
|
def organ_fill_by_removed_mask(data, target_label, remove_mask, device):
    """
    Re-grow an organ mask into regions where it was previously removed.

    The organ's binary mask is dilated three times and then intersected with
    ``remove_mask`` so the organ only reclaims previously-removed voxels.

    Args:
        data (ndarray): The input data containing organ labels.
        target_label (int): The label of the organ to be processed.
        remove_mask (ndarray): Boolean mask indicating regions where the organ was removed.
        device (str): The device to perform the operations on (e.g. 'cuda:0').

    Returns:
        ndarray: Boolean mask of the filled organ in previously removed regions.
    """
    organ_mask = torch.from_numpy((data == target_label).astype(np.uint8)).to(device)
    # three successive dilations expand the organ into its neighborhood
    for _ in range(3):
        organ_mask = dilate_one_img(organ_mask, filter_size=3, pad_value=0.0)
    roi_organ_mask = organ_mask.cpu().numpy()
    return (roi_organ_mask * remove_mask).astype(np.bool_)
|
|
|
|
|
def get_body_region_index_from_mask(input_mask):
    """
    Determine the top and bottom body region indices from an input mask.

    Four body regions are each identified by a set of label values; the first
    and last regions whose labels appear in the mask are returned as one-hot
    vectors.

    Args:
        input_mask (Tensor): Input mask tensor containing body region labels.

    Returns:
        tuple: Two length-4 lists of ints (one-hot top region, one-hot bottom region).
    """
    # label values that identify each body region, ordered top to bottom
    region_indices = {
        "region_0": [22, 120],
        "region_1": [28, 29, 30, 31, 32],
        "region_2": [1, 2, 3, 4, 5, 14],
        "region_3": [93, 94],
    }

    nda = input_mask.cpu().numpy().squeeze()
    # np.lib.arraysetops.unique was removed in NumPy 2.0; np.unique is the public API
    unique_elements = list(np.unique(nda))

    num_regions = len(region_indices)
    overlap_array = np.zeros(num_regions, dtype=np.uint8)
    for _j in range(num_regions):
        # region j is present if any of its labels appears in the mask
        overlap = any(element in region_indices[f"region_{_j}"] for element in unique_elements)
        overlap_array[_j] = np.uint8(overlap)
    overlap_array_indices = np.nonzero(overlap_array)[0]

    # one-hot encode the first and last present regions
    one_hot = np.eye(num_regions, dtype=np.uint8)
    top_region_index = [int(_k) for _k in one_hot[np.amin(overlap_array_indices), ...]]
    bottom_region_index = [int(_k) for _k in one_hot[np.amax(overlap_array_indices), ...]]

    return top_region_index, bottom_region_index
|
|
|
|
|
def general_mask_generation_post_process(volume_t, target_tumor_label=None, device="cuda:0"):
    """
    Perform post-processing on a generated mask volume.

    This function applies various refinement steps to improve the quality of the generated mask,
    including body mask refinement, tumor prediction refinement, and organ-specific processing.

    Args:
        volume_t (ndarray): Input volume containing organ and tumor labels.
        target_tumor_label (int, optional): Label of the target tumor. Defaults to None.
        device (str, optional): Device to perform operations on. Defaults to "cuda:0".

    Returns:
        ndarray: Post-processed volume with refined organ and tumor labels.
    """
    # remember hepatic vessel (25) and airway (132) voxels so they can be
    # restored verbatim at the end, after the organ-level processing below
    hepatic_vessel = volume_t == 25
    airway = volume_t == 132

    # refine the body mask: erode, keep the largest foreground component,
    # then dilate back and zero everything outside it
    body_region_mask = (
        erode_one_img(torch.from_numpy((volume_t > 0)).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
    )
    body_region_mask, _ = supress_non_largest_components(body_region_mask, [1])
    body_region_mask = (
        dilate_one_img(torch.from_numpy(body_region_mask).to(device), filter_size=3, pad_value=0.0)
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    volume_t = volume_t * body_region_mask

    # tumor labels other than the target are folded into a replacement label
    # (presumably their host organ — NOTE(review): confirm label semantics)
    tumor_organ_dict = {23: 28, 24: 4, 26: 1, 27: 62, 128: 200}
    for t in [23, 24, 26, 27, 128]:
        if t != target_tumor_label:
            volume_t[volume_t == t] = tumor_organ_dict[t]
        else:
            # close the target tumor mask twice to fill holes
            volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t
            volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t

    # keep only the largest target-tumor component, except for labels 26 and
    # 128 which are allowed to remain multi-component
    if target_tumor_label != 26 and target_tumor_label != 128:
        volume_t, _ = supress_non_largest_components(volume_t, [target_tumor_label], default_val=200)
    target_tumor = volume_t == target_tumor_label

    # organs expected to be single connected components
    oran_list = [1, 4, 10, 12, 3, 28, 29, 30, 31, 32, 5, 14, 13, 6, 7, 8, 9, 10]
    if target_tumor_label != 128:
        oran_list += list(range(33, 60))
    data, _ = supress_non_largest_components(volume_t, oran_list, default_val=200)
    # voxels removed by the component suppression above
    organ_remove_mask = (volume_t - data).astype(np.bool_)

    # treat the intestinal labels (12, 13, 19, 62) as a single structure and
    # keep only its largest connected component
    intestinal_mask_ = (
        (data == 12).astype(np.uint8)
        + (data == 13).astype(np.uint8)
        + (data == 19).astype(np.uint8)
        + (data == 62).astype(np.uint8)
    )
    intestinal_mask, _ = supress_non_largest_components(intestinal_mask_, [1], default_val=0)

    # label-19 / label-62 voxels that fell outside the kept intestinal component
    small_bowel_remove_mask = (data == 19).astype(np.uint8) - (data == 19).astype(np.uint8) * intestinal_mask

    colon_remove_mask = (data == 62).astype(np.uint8) - (data == 62).astype(np.uint8) * intestinal_mask
    intestinal_remove_mask = (small_bowel_remove_mask + colon_remove_mask).astype(np.bool_)
    data[intestinal_remove_mask] = 200

    # fill holes in every organ of interest
    for organ_label in oran_list:
        data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label

    if target_tumor_label == 23 and np.sum(target_tumor) > 0:
        # lung tumor (23): identify the surrounding organ and re-grow it
        dia_lung_tumor_mask = (
            dilate_one_img(torch.from_numpy((data == 23)).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
        )
        # labels in the dilation ring around the tumor (tumor itself excluded)
        tmp = (
            (data * (dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8))).astype(np.float32).flatten()
        )
        tmp[tmp == 0] = float("nan")
        # most frequent neighbouring label is taken as the host-organ candidate
        mode = int(stats.mode(tmp.flatten(), nan_policy="omit")[0])
        if mode in [28, 29, 30, 31, 32]:
            dia_lung_tumor_mask = (
                dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0)
                .cpu()
                .numpy()
            )
            lung_remove_mask = dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8).astype(np.uint8)
            data[organ_fill_by_removed_mask(data, target_label=mode, remove_mask=lung_remove_mask, device=device)] = (
                mode
            )
        dia_lung_tumor_mask = (
            dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
        )
        # re-grow the tumor itself into removed-organ regions near it
        data[
            organ_fill_by_removed_mask(
                data, target_label=23, remove_mask=dia_lung_tumor_mask * organ_remove_mask, device=device
            )
        ] = 23
        # close labels 28-32 three times each
        for organ_label in [28, 29, 30, 31, 32]:
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label

    if target_tumor_label == 26 and np.sum(target_tumor) > 0:
        # tumor 26: restore label 1 in removed intestinal regions (twice)
        data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1
        data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1

        data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3
        data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3
        dia_tumor_mask = (
            dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0)
            .cpu()
            .numpy()
        )
        dia_tumor_mask = (
            dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
        )
        data[
            organ_fill_by_removed_mask(
                data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device
            )
        ] = target_tumor_label

        # keep only the largest joint component of labels 26, 25 and 1
        hepatic_tumor_vessel_liver_mask_ = (
            (data == 26).astype(np.uint8) + (data == 25).astype(np.uint8) + (data == 1).astype(np.uint8)
        )
        hepatic_tumor_vessel_liver_mask_ = (hepatic_tumor_vessel_liver_mask_ > 1).astype(np.uint8)
        hepatic_tumor_vessel_liver_mask, _ = supress_non_largest_components(
            hepatic_tumor_vessel_liver_mask_, [1], default_val=0
        )
        removed_region = (hepatic_tumor_vessel_liver_mask_ - hepatic_tumor_vessel_liver_mask).astype(np.bool_)
        data[removed_region] = 200
        # restrict the tumor to the kept component
        target_tumor = (target_tumor * hepatic_tumor_vessel_liver_mask).astype(np.bool_)

        # close label 1 three times
        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1

    if target_tumor_label == 27 and np.sum(target_tumor) > 0:
        # tumor 27: re-grow the tumor into removed-organ regions near it
        dia_tumor_mask = (
            dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0)
            .cpu()
            .numpy()
        )
        dia_tumor_mask = (
            dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
        )
        data[
            organ_fill_by_removed_mask(
                data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device
            )
        ] = target_tumor_label

    if target_tumor_label == 129 and np.sum(target_tumor) > 0:
        # tumor 129: close labels 5 and 14 three times each
        for organ_label in [5, 14]:
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label

    print(
        "Current model does not support hepatic vessel by size control, "
        "so we treat generated hepatic vessel as part of liver for better visiaulization."
    )
    # restore the saved structures and finally re-stamp the target tumor
    data[hepatic_vessel] = 1
    data[airway] = 132
    if target_tumor_label is not None:
        data[target_tumor] = target_tumor_label

    return data
|
|
|
|
|
class MapLabelValue:
    """
    Utility to map label values to another set of values.
    For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2],
    [3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc.
    The label data must be numpy array or array-like data and the output data will be numpy array.

    """

    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

    def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None:
        """
        Args:
            orig_labels: original labels that map to others.
            target_labels: expected label values, 1: 1 map to the `orig_labels`.
            dtype: convert the output data to dtype, default to float32.
                if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend.

        """
        if len(orig_labels) != len(target_labels):
            raise ValueError("orig_labels and target_labels must have the same length.")

        self.orig_labels = orig_labels
        self.target_labels = target_labels
        # only pairs that actually change a value need to be applied
        self.pair = tuple((src, dst) for src, dst in zip(self.orig_labels, self.target_labels) if src != dst)
        # a torch dtype selects the torch backend, anything else the numpy backend
        self.use_numpy = getattr(type(dtype), "__module__", "") != "torch"
        if self.use_numpy:
            self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray)
        else:
            self.dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor)

    def __call__(self, img: NdarrayOrTensor):
        """
        Apply the label mapping to the input image.

        Args:
            img (NdarrayOrTensor): Input image to be remapped.

        Returns:
            NdarrayOrTensor: Remapped image.
        """
        if self.use_numpy:
            img_np, *_ = convert_data_type(img, np.ndarray)
            original_shape = img_np.shape
            flat_in = img_np.flatten()
            try:
                flat_out = flat_in.astype(self.dtype)
            except ValueError:
                # the input cannot be cast directly (e.g. string labels into a
                # numeric dtype): start from zeros and rely on the mapping pairs
                flat_out = np.zeros(shape=flat_in.shape, dtype=self.dtype)
            for src, dst in self.pair:
                flat_out[flat_in == src] = dst
            out_t = flat_out.reshape(original_shape)
        else:
            img_t, *_ = convert_data_type(img, torch.Tensor)
            # clone so the source values used for matching are never overwritten
            out_t = img_t.detach().clone().to(self.dtype)
            for src, dst in self.pair:
                out_t[img_t == src] = dst
        out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype)
        return out
|
|
|
|
|
def dynamic_infer(inferer, model, images):
    """
    Perform dynamic inference using a model and an inferer, typically a monai SlidingWindowInferer.

    If the spatial volume of a single-channel sample is smaller than the
    inferer's ROI, the model is called directly; otherwise the inferer is used
    with its ROI temporarily clamped to the image's spatial size.

    Args:
        inferer: An inference object, typically a monai SlidingWindowInferer, which handles patch-based inference.
        model (torch.nn.Module): The model used for inference.
        images (torch.Tensor): The input images for inference, shape [N,C,H,W,D] or [N,C,H,W].

    Returns:
        torch.Tensor: The output from the model or the inferer, depending on the input size.

    Raises:
        ValueError: if the ROI rank does not match the image's spatial rank.
    """
    roi_volume = math.prod(inferer.roi_size)
    if torch.numel(images[0:1, 0:1, ...]) < roi_volume:
        # the whole image fits inside a single ROI: no sliding window needed
        return model(images)

    spatial_dims = images.shape[2:]
    orig_roi = inferer.roi_size
    if len(orig_roi) != len(spatial_dims):
        raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).")

    # clamp each ROI dimension so it never exceeds the image, run, then restore
    inferer.roi_size = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)]
    output = inferer(network=model, inputs=images)
    inferer.roi_size = orig_roi
    return output
|
|