import json
import logging
import os
import random
import time
from datetime import datetime

import monai
import torch
from monai.data import MetaTensor
from monai.inferers.inferer import DiffusionInferer, SlidingWindowInferer
from monai.transforms import Compose, SaveImage
from monai.utils import set_determinism
from tqdm import tqdm

from .augmentation import augmentation
from .find_masks import find_masks
from .quality_check import is_outlier
from .utils import binarize_labels, dynamic_infer, general_mask_generation_post_process, remap_labels

modality_mapping = {
    "unknown": 0,
    "ct": 1,
    "ct_wo_contrast": 2,
    "ct_contrast": 3,
    "mri": 8,
    "mri_t1": 9,
    "mri_t2": 10,
    "mri_flair": 11,
    "mri_pd": 12,
    "mri_dwi": 13,
    "mri_adc": 14,
    "mri_ssfp": 15,
    "mri_mra": 16,
}
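# For example, modality_mapping["mri_t2"] == 10; this integer is passed to the
# diffusion networks below as the `class_labels` conditioning input.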


class ReconModel(torch.nn.Module):
    """
    A PyTorch module for reconstructing images from latent representations.

    Attributes:
        autoencoder: The autoencoder model used for decoding.
        scale_factor: Scaling factor applied to the input before decoding.
    """

    def __init__(self, autoencoder, scale_factor):
        super().__init__()
        self.autoencoder = autoencoder
        self.scale_factor = scale_factor

    def forward(self, z):
        """
        Decode the input latent representation to an image.

        Args:
            z (torch.Tensor): The input latent representation.

        Returns:
            torch.Tensor: The reconstructed image.
        """
        recon_pt_nda = self.autoencoder.decode_stage_2_outputs(z / self.scale_factor)
        return recon_pt_nda
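# Hedged usage sketch (assumes a trained autoencoder exposing
# `decode_stage_2_outputs` and a latent tensor `z` on the same device;
# `scale_factor=1.0` is a placeholder value):
#     recon_model = ReconModel(autoencoder=autoencoder, scale_factor=1.0).to(device)
#     image = recon_model(z)  # decodes z / scale_factor back to image space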


def initialize_noise_latents(latent_shape, device):
    """
    Initialize random noise latents for image generation with float16.

    Args:
        latent_shape (tuple): The shape of the latent space.
        device (torch.device): The device to create the tensor on.

    Returns:
        torch.Tensor: Initialized noise latents.
    """
    return torch.randn([1] + list(latent_shape)).half().to(device)
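# Illustrative shapes only (the real latent shape comes from the model config):
#     initialize_noise_latents((4, 48, 48, 48), torch.device("cuda"))
# returns a float16 tensor of shape (1, 4, 48, 48, 48), i.e. a batch of one.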


def ldm_conditional_sample_one_mask(
    autoencoder,
    diffusion_unet,
    noise_scheduler,
    scale_factor,
    anatomy_size,
    device,
    latent_shape,
    label_dict_remap_json,
    num_inference_steps=1000,
    autoencoder_sliding_window_infer_size=(96, 96, 96),
    autoencoder_sliding_window_infer_overlap=0.6667,
):
    """
    Generate a single synthetic mask using a latent diffusion model.

    Args:
        autoencoder (nn.Module): The autoencoder model.
        diffusion_unet (nn.Module): The diffusion U-Net model.
        noise_scheduler: The noise scheduler for the diffusion process.
        scale_factor (float): Scaling factor for the latent space.
        anatomy_size (torch.Tensor): Tensor specifying the desired anatomy sizes.
        device (torch.device): The device to run the computation on.
        latent_shape (tuple): The shape of the latent space.
        label_dict_remap_json (str): Path to the JSON file for label remapping.
        num_inference_steps (int): Number of inference steps for the diffusion process.
        autoencoder_sliding_window_infer_size (tuple, optional): Size of the sliding window for inference. Defaults to (96, 96, 96).
        autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667.

    Returns:
        torch.Tensor: The generated synthetic mask.
    """
    recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)

    with torch.no_grad(), torch.amp.autocast("cuda"):
        # Sample a latent conditioned on the requested anatomy sizes.
        latents = initialize_noise_latents(latent_shape, device)
        anatomy_size = torch.FloatTensor(anatomy_size).unsqueeze(0).unsqueeze(0).half().to(device)

        noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)
        inferer_ddpm = DiffusionInferer(noise_scheduler)
        latents = inferer_ddpm.sample(
            input_noise=latents,
            diffusion_model=diffusion_unet,
            scheduler=noise_scheduler,
            verbose=True,
            conditioning=anatomy_size.to(device),
        )

        # Decode the latent into a multi-channel mask with sliding-window inference,
        # stitching results on CPU to limit GPU memory usage.
        inferer = SlidingWindowInferer(
            roi_size=autoencoder_sliding_window_infer_size,
            sw_batch_size=1,
            progress=True,
            mode="gaussian",
            overlap=autoencoder_sliding_window_infer_overlap,
            device=torch.device("cpu"),
            sw_device=device,
        )
        synthetic_mask = dynamic_infer(inferer, recon_model, latents)
        synthetic_mask = torch.softmax(synthetic_mask, dim=1)
        synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True)

        # Map the network's output labels back to the application's label convention.
        synthetic_mask = remap_labels(synthetic_mask, label_dict_remap_json)

        data = synthetic_mask.squeeze().cpu().detach().numpy()

        # anatomy_size[0, 0, 5:10] holds the five controllable tumor sizes; any
        # entry that is not -1.0 selects the corresponding tumor label below.
        labels = [23, 24, 26, 27, 128]
        target_tumor_label = None
        for index, size in enumerate(anatomy_size[0, 0, 5:10]):
            if size.item() != -1.0:
                target_tumor_label = labels[index]

        logging.info(f"target_tumor_label for postprocess: {target_tumor_label}")
        data = general_mask_generation_post_process(data, target_tumor_label=target_tumor_label, device=device)
        synthetic_mask = torch.from_numpy(data).unsqueeze(0).unsqueeze(0).to(device)

    return synthetic_mask
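# The returned mask has shape (1, 1, *spatial_size) with integer labels, and
# lives on `device`.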


def ldm_conditional_sample_one_image(
    autoencoder,
    diffusion_unet,
    controlnet,
    noise_scheduler,
    scale_factor,
    device,
    combine_label_or,
    modality_tensor,
    spacing_tensor,
    latent_shape,
    output_size,
    noise_factor,
    num_inference_steps=1000,
    autoencoder_sliding_window_infer_size=(96, 96, 96),
    autoencoder_sliding_window_infer_overlap=0.6667,
):
    """
    Generate a single synthetic image using a latent diffusion model with ControlNet.

    Args:
        autoencoder (nn.Module): The autoencoder model.
        diffusion_unet (nn.Module): The diffusion U-Net model.
        controlnet (nn.Module): The ControlNet model.
        noise_scheduler: The noise scheduler for the diffusion process.
        scale_factor (float): Scaling factor for the latent space.
        device (torch.device): The device to run the computation on.
        combine_label_or (torch.Tensor): The combined label tensor.
        modality_tensor (torch.Tensor): Tensor specifying the imaging modality.
        spacing_tensor (torch.Tensor): Tensor specifying the spacing.
        latent_shape (tuple): The shape of the latent space.
        output_size (tuple): The desired output size of the image.
        noise_factor (float): Factor to scale the initial noise.
        num_inference_steps (int): Number of inference steps for the diffusion process.
        autoencoder_sliding_window_infer_size (tuple, optional): Size of the sliding window for inference. Defaults to (96, 96, 96).
        autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667.

    Returns:
        tuple: A tuple containing the synthetic image and its corresponding label.
    """
    # Output intensity range (CT Hounsfield units) and the normalized range
    # produced by the autoencoder.
    a_min = -1000
    a_max = 1000
    b_min = 0.0
    b_max = 1.0

    recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)

    with torch.no_grad(), torch.amp.autocast("cuda", enabled=True):
        logging.info("---- Start generating latent features... ----")
        start_time = time.time()

        combine_label = combine_label_or.to(device)
        if (
            output_size[0] != combine_label.shape[2]
            or output_size[1] != combine_label.shape[3]
            or output_size[2] != combine_label.shape[4]
        ):
            logging.info(
                "The mask shape does not match output_size; the mask will be "
                "interpolated to match it, which may noticeably degrade image quality."
            )
            combine_label = torch.nn.functional.interpolate(combine_label, size=output_size, mode="nearest")

        # ControlNet conditioning: a binarized (foreground/background) version of the mask.
        controlnet_cond_vis = binarize_labels(combine_label.as_tensor().long()).half()

        latents = initialize_noise_latents(latent_shape, device) * noise_factor

        noise_scheduler.set_timesteps(
            num_inference_steps=num_inference_steps, input_img_size=torch.prod(torch.tensor(latent_shape[-3:]))
        )

        # guidance_scale == 0 disables classifier-free guidance and runs a single
        # conditional pass per step; a non-zero scale runs the conditional and
        # unconditional passes in one batch and blends the two predictions.
        guidance_scale = 0
        all_next_timesteps = torch.cat(
            (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
        )
        for t, next_t in tqdm(
            zip(noise_scheduler.timesteps, all_next_timesteps),
            total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
        ):
            timesteps = torch.Tensor((t,)).to(device)
            if guidance_scale == 0:
                down_block_res_samples, mid_block_res_sample = controlnet(
                    x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis, class_labels=modality_tensor
                )
                predicted_velocity = diffusion_unet(
                    x=latents,
                    timesteps=timesteps,
                    spacing_tensor=spacing_tensor,
                    class_labels=modality_tensor,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                )
            else:
                # Duplicate every conditioning input so the conditional and
                # unconditional passes run in a single batch of two.
                down_block_res_samples, mid_block_res_sample = controlnet(
                    x=torch.cat([latents] * 2),
                    timesteps=torch.cat([timesteps] * 2),
                    controlnet_cond=torch.cat([controlnet_cond_vis] * 2),
                    class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]),
                )
                model_t, model_uncond = diffusion_unet(
                    x=torch.cat([latents] * 2),
                    timesteps=torch.cat([timesteps] * 2),
                    spacing_tensor=torch.cat([spacing_tensor] * 2),
                    class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]),
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).chunk(2)
                predicted_velocity = model_uncond + guidance_scale * (model_t - model_uncond)
            latents, _ = noise_scheduler.step(predicted_velocity, t, latents, next_timestep=next_t)
        end_time = time.time()
        logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----")
        del predicted_velocity
        torch.cuda.empty_cache()

        logging.info("---- Start decoding latent features into images... ----")
        inferer = SlidingWindowInferer(
            roi_size=autoencoder_sliding_window_infer_size,
            sw_batch_size=1,
            progress=True,
            mode="gaussian",
            overlap=autoencoder_sliding_window_infer_overlap,
            device=torch.device("cpu"),
            sw_device=device,
        )
        start_time = time.time()
        synthetic_images = dynamic_infer(inferer, recon_model, latents)
        synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu()
        end_time = time.time()
        logging.info(f"---- Image decoding time: {end_time - start_time} seconds ----")

        # Rescale from the normalized [b_min, b_max] range to [a_min, a_max] HU.
        synthetic_images = (synthetic_images - b_min) / (b_max - b_min)
        synthetic_images = synthetic_images * (a_max - a_min) + a_min

        # Set voxels outside the body mask to air (-1000 HU).
        synthetic_images = crop_img_body_mask(synthetic_images, combine_label)
        torch.cuda.empty_cache()

    return synthetic_images, combine_label
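# The returned images are scaled to CT-style intensities in [a_min, a_max] HU,
# with voxels outside the body mask set to -1000 (air).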


def filter_mask_with_organs(combine_label, anatomy_list):
    """
    Filter a mask to only include specified organs.

    Args:
        combine_label (torch.Tensor): The input mask.
        anatomy_list (list): List of organ labels to keep.

    Returns:
        torch.Tensor: The filtered mask.
    """
    combine_label = combine_label.long()

    # Temporarily mark each requested organ with a negative index so the
    # relabeling passes below cannot collide with the original positive labels.
    for i in range(len(anatomy_list)):
        organ = anatomy_list[i]
        combine_label[combine_label == organ] = -(i + 1)

    # Zero out everything that was not requested, then flip the markers back to
    # positive so organ anatomy_list[i] ends up with label i + 1.
    combine_label[combine_label > 0] = 0
    combine_label = -combine_label
    return combine_label
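# Worked example: with anatomy_list == [3, 5], voxels labeled 3 become 1,
# voxels labeled 5 become 2, and every other voxel becomes 0 (background).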


def crop_img_body_mask(synthetic_images, combine_label):
    """
    Crop the synthetic image using a body mask.

    Args:
        synthetic_images (torch.Tensor): The synthetic images.
        combine_label (torch.Tensor): The body mask.

    Returns:
        torch.Tensor: The cropped synthetic images.
    """
    synthetic_images[combine_label == 0] = -1000
    return synthetic_images
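# -1000 is the Hounsfield unit value of air, so masked-out voxels read as empty
# space in the resulting CT-style volume.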


def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing, controllable_anatomy_size):
    """
    Validate input parameters for image generation.

    Args:
        body_region (list): List of body regions.
        anatomy_list (list): List of anatomical structures.
        label_dict_json (str): Path to the label dictionary JSON file.
        output_size (tuple): Desired output size of the image.
        spacing (tuple): Desired voxel spacing.
        controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes.

    Raises:
        ValueError: If any input parameter is invalid.
    """
    if output_size[0] != output_size[1]:
        raise ValueError(f"The first two components of output_size need to be equal, yet got {output_size}.")
    if (output_size[0] not in [256, 384, 512]) or (output_size[2] not in [128, 256, 384, 512, 640, 768]):
        raise ValueError(
            (
                "output_size[0] has to be chosen from [256, 384, 512], and output_size[2] "
                f"has to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}."
            )
        )

    if spacing[0] != spacing[1]:
        raise ValueError(f"The first two components of spacing need to be equal, yet got {spacing}.")
    if spacing[0] < 0.5 or spacing[0] > 3.0 or spacing[2] < 0.5 or spacing[2] > 5.0:
        raise ValueError(
            f"spacing[0] has to be between 0.5 and 3.0 mm, and spacing[2] has to be between 0.5 and 5.0 mm, yet got {spacing}."
        )

    if (
        output_size[0] * spacing[0] < 256
        or output_size[2] * spacing[2] < 128
        or output_size[0] * spacing[0] > 640
        or output_size[2] * spacing[2] > 2000
    ):
        fov = [output_size[axis] * spacing[axis] for axis in range(3)]
        raise ValueError(
            (
                f"'spacing' ({spacing} mm) and 'output_size' ({output_size}) together determine the output field of view (FOV). "
                f"The FOV will be {fov} mm. We recommend an FOV in the x- and y-axes of at least 256 mm for head, at least "
                "384 mm for other body regions like abdomen, and less than 640 mm. "
                "For the z-axis, it must be at least 128 mm and less than 2000 mm."
            )
        )

    if len(controllable_anatomy_size) > 10:
        raise ValueError(
            f"controllable_anatomy_size can hold at most 10 entries, yet got {len(controllable_anatomy_size)}."
        )
    available_controllable_organ = ["liver", "gallbladder", "stomach", "pancreas", "colon"]
    available_controllable_tumor = [
        "hepatic tumor",
        "bone lesion",
        "lung tumor",
        "colon cancer primaries",
        "pancreatic tumor",
    ]
    available_controllable_anatomy = available_controllable_organ + available_controllable_tumor
    controllable_tumor = []
    controllable_organ = []
    for controllable_anatomy_size_pair in controllable_anatomy_size:
        if controllable_anatomy_size_pair[0] not in available_controllable_anatomy:
            raise ValueError(
                (
                    f"The controllable anatomy has to be chosen from {available_controllable_anatomy}, "
                    f"yet got {controllable_anatomy_size_pair[0]}."
                )
            )
        if controllable_anatomy_size_pair[0] in available_controllable_tumor:
            controllable_tumor += [controllable_anatomy_size_pair[0]]
        if controllable_anatomy_size_pair[0] in available_controllable_organ:
            controllable_organ += [controllable_anatomy_size_pair[0]]
        if controllable_anatomy_size_pair[1] == -1:
            continue
        if controllable_anatomy_size_pair[1] < 0 or controllable_anatomy_size_pair[1] > 1.0:
            raise ValueError(
                (
                    "The controllable size scale has to be between 0 and 1.0, or equal to -1, "
                    f"yet got {controllable_anatomy_size_pair[1]}."
                )
            )
    if len(controllable_tumor + controllable_organ) != len(list(set(controllable_tumor + controllable_organ))):
        raise ValueError(f"Please do not repeat controllable_anatomy. Got {controllable_tumor + controllable_organ}.")
    if len(controllable_tumor) > 1:
        raise ValueError(f"Only one controllable tumor is supported, yet got {controllable_tumor}.")

    if len(controllable_anatomy_size) > 0:
        logging.info(
            (
                "`controllable_anatomy_size` is not empty.\nWe will ignore `body_region` and `anatomy_list` "
                f"and synthesize based on `controllable_anatomy_size`: ({controllable_anatomy_size})."
            )
        )
    else:
        logging.info(
            f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `anatomy_list`: ({anatomy_list})."
        )

    available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"]
    for region in body_region:
        if region not in available_body_region:
            raise ValueError(
                f"The components in body_region have to be chosen from {available_body_region}, yet got {region}."
            )

    with open(label_dict_json) as f:
        label_dict = json.load(f)
    for anatomy in anatomy_list:
        if anatomy not in label_dict.keys():
            raise ValueError(
                f"The components in anatomy_list have to be chosen from {label_dict.keys()}, yet got {anatomy}."
            )
    logging.info(f"The generated results will have voxel spacing {spacing} mm and volume size {output_size}.")

    return
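# Hedged example of a configuration this check accepts (the JSON path is
# hypothetical and must map "liver" to its integer label):
#     check_input(
#         body_region=["abdomen"],
#         anatomy_list=["liver"],
#         label_dict_json="./configs/label_dict.json",
#         output_size=(512, 512, 512),
#         spacing=(1.0, 1.0, 1.0),
#         controllable_anatomy_size=[("liver", 0.5)],
#     )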


class LDMSampler:
    """
    A sampler class for generating synthetic medical images and masks using latent diffusion models.

    Attributes:
        Various attributes related to model configuration, input parameters, and generation settings.
    """

    def __init__(
        self,
        body_region,
        anatomy_list,
        modality,
        all_mask_files_json,
        all_anatomy_size_condtions_json,
        all_mask_files_base_dir,
        label_dict_json,
        label_dict_remap_json,
        autoencoder,
        diffusion_unet,
        controlnet,
        noise_scheduler,
        scale_factor,
        mask_generation_autoencoder,
        mask_generation_diffusion_unet,
        mask_generation_scale_factor,
        mask_generation_noise_scheduler,
        device,
        latent_shape,
        mask_generation_latent_shape,
        output_size,
        output_dir,
        controllable_anatomy_size,
        image_output_ext=".nii.gz",
        label_output_ext=".nii.gz",
        real_img_median_statistics="./configs/image_median_statistics.json",
        spacing=(1, 1, 1),
        num_inference_steps=None,
        mask_generation_num_inference_steps=None,
        random_seed=None,
        autoencoder_sliding_window_infer_size=(96, 96, 96),
        autoencoder_sliding_window_infer_overlap=0.6667,
    ) -> None:
        """
        Initialize the LDMSampler with various parameters and models.

        Args:
            Various parameters related to model configuration, input settings, and output specifications.
        """
        self.random_seed = random_seed
        if random_seed is not None:
            set_determinism(seed=random_seed)

        with open(label_dict_json, "r") as f:
            label_dict = json.load(f)
        self.all_anatomy_size_condtions_json = all_anatomy_size_condtions_json

        self.body_region = body_region
        self.anatomy_list = [label_dict[organ] for organ in anatomy_list]
        self.modality_int = modality_mapping[modality]
        self.all_mask_files_json = all_mask_files_json
        self.data_root = all_mask_files_base_dir
        self.label_dict_remap_json = label_dict_remap_json
        self.autoencoder = autoencoder
        self.diffusion_unet = diffusion_unet
        self.controlnet = controlnet
        self.noise_scheduler = noise_scheduler
        self.scale_factor = scale_factor
        self.mask_generation_autoencoder = mask_generation_autoencoder
        self.mask_generation_diffusion_unet = mask_generation_diffusion_unet
        self.mask_generation_scale_factor = mask_generation_scale_factor
        self.mask_generation_noise_scheduler = mask_generation_noise_scheduler
        self.device = device
        self.latent_shape = latent_shape
        self.mask_generation_latent_shape = mask_generation_latent_shape
        self.output_size = output_size
        self.output_dir = output_dir
        self.noise_factor = 1.0
        self.controllable_anatomy_size = controllable_anatomy_size
        if len(self.controllable_anatomy_size):
            logging.info("controllable_anatomy_size is given, mask generation is triggered!")
            self.anatomy_list = [label_dict[organ_and_size[0]] for organ_and_size in self.controllable_anatomy_size]
        self.image_output_ext = image_output_ext
        self.label_output_ext = label_output_ext

        self.num_inference_steps = num_inference_steps if num_inference_steps is not None else 1000
        self.mask_generation_num_inference_steps = (
            mask_generation_num_inference_steps if mask_generation_num_inference_steps is not None else 1000
        )

        if any(size % 16 != 0 for size in autoencoder_sliding_window_infer_size):
            raise ValueError(
                f"autoencoder_sliding_window_infer_size must be divisible by 16.\n Got {autoencoder_sliding_window_infer_size}"
            )
        if not (0 <= autoencoder_sliding_window_infer_overlap <= 1):
            raise ValueError(
                (
                    "Value of autoencoder_sliding_window_infer_overlap must be between 0 "
                    f"and 1.\n Got {autoencoder_sliding_window_infer_overlap}"
                )
            )
        self.autoencoder_sliding_window_infer_size = autoencoder_sliding_window_infer_size
        self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap

        self.max_try_time = 3
        with open(real_img_median_statistics, "r") as json_file:
            self.median_statistics = json.load(json_file)
        # Maps organ/tumor names to the integer labels used by the quality check.
        self.label_int_dict = {
            "liver": [1],
            "spleen": [3],
            "pancreas": [4],
            "kidney": [5, 14],
            "lung": [28, 29, 30, 31, 32],
            "brain": [22],
            "hepatic tumor": [26],
            "bone lesion": [128],
            "lung tumor": [23],
            "colon cancer primaries": [27],
            "pancreatic tumor": [24],
            "bone": list(range(33, 57)) + list(range(63, 98)) + [120, 122, 127],
        }

        self.autoencoder.eval()
        self.diffusion_unet.eval()
        self.controlnet.eval()
        self.mask_generation_autoencoder.eval()
        self.mask_generation_diffusion_unet.eval()

        self.spacing = spacing

        self.val_transforms = Compose(
            [
                monai.transforms.LoadImaged(keys=["pseudo_label"]),
                monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]),
                monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"),
                monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8),
                monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
                # The networks expect spacing values scaled by 100
                # (see prepare_one_mask_and_meta_info).
                monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
            ]
        )
        logging.info("LDM sampler initialized.")

    def sample_multiple_images(self, num_img):
        """
        Generate multiple synthetic images and masks.

        Args:
            num_img (int): Number of images to generate.

        Returns:
            list: Pairs of [image_filename, label_filename] for the saved results.
        """
        output_filenames = []
        if len(self.controllable_anatomy_size) > 0:
            # Masks will be synthesized on the fly, so this is just a placeholder
            # list that makes the loop below run num_img times.
            selected_mask_files = list(range(num_img))
            anatomy_size_condtion = self.prepare_anatomy_size_condtion(self.controllable_anatomy_size)
        else:
            need_resample = False
            candidate_mask_files = find_masks(
                self.anatomy_list, self.spacing, self.output_size, True, self.all_mask_files_json, self.data_root
            )
            if len(candidate_mask_files) < num_img:
                # Not enough exact matches in the database; fall back to the closest
                # masks and resample them to the requested size and spacing.
                logging.info("Resample mask file to get desired output size and spacing")
                candidate_mask_files = self.find_closest_masks(num_img)
                need_resample = True

            selected_mask_files = self.select_mask(candidate_mask_files, num_img)
            if len(selected_mask_files) < num_img:
                raise ValueError(
                    (
                        f"len(selected_mask_files) ({len(selected_mask_files)}) < num_img ({num_img}). "
                        "This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)."
                    )
                )
        num_generated_img = 0
        for index_s in range(len(selected_mask_files)):
            item = selected_mask_files[index_s]
            if num_generated_img >= num_img:
                break
            logging.info("---- Start preparing masks... ----")
            start_time = time.time()
            logging.info(f"Image will be generated based on {item}.")
            if len(self.controllable_anatomy_size) > 0:
                # Generate a new mask conditioned on the requested anatomy sizes.
                (combine_label_or, spacing_tensor) = self.prepare_one_mask_and_meta_info(anatomy_size_condtion)
            else:
                # Read in a mask from the database.
                mask_file = item["mask_file"]
                if_aug = item["if_aug"]
                (combine_label_or, spacing_tensor) = self.read_mask_information(mask_file)
                if need_resample:
                    combine_label_or = self.ensure_output_size_and_spacing(combine_label_or)
                # Mask augmentation.
                if if_aug:
                    combine_label_or = augmentation(combine_label_or, self.output_size, random_seed=self.random_seed)
            end_time = time.time()
            logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----")
            torch.cuda.empty_cache()

            modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int

            # Generate the image/label pair.
            synthetic_images, synthetic_labels = self.sample_one_pair(combine_label_or, modality_tensor, spacing_tensor)

            # Synthetic image quality check.
            pass_quality_check = self.quality_check(
                synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy()
            )
            # Save the pair if it passes, or if there are too few remaining
            # candidates to still deliver num_img pairs.
            if pass_quality_check or (num_img - num_generated_img) >= (len(selected_mask_files) - index_s):
                if not pass_quality_check:
                    logging.info(
                        "Generated image/label pair did not pass quality check, but will still save them. "
                        "Please consider changing spacing and output_size to facilitate a more realistic setting."
                    )
                num_generated_img = num_generated_img + 1

                # Save the image/label pair with a timestamped postfix.
                output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz"
                synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta)
                img_saver = SaveImage(
                    output_dir=self.output_dir,
                    output_postfix=output_postfix + "_image",
                    output_ext=self.image_output_ext,
                    separate_folder=False,
                )
                img_saver(synthetic_images[0])
                synthetic_images_filename = os.path.join(
                    self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext
                )
                # Keep only the organs in anatomy_list in the saved label.
                synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list)
                label_saver = SaveImage(
                    output_dir=self.output_dir,
                    output_postfix=output_postfix + "_label",
                    output_ext=self.label_output_ext,
                    separate_folder=False,
                )
                label_saver(synthetic_labels[0])
                synthetic_labels_filename = os.path.join(
                    self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext
                )
                output_filenames.append([synthetic_images_filename, synthetic_labels_filename])
            else:
                logging.info("Generated image/label pair did not pass quality check, will re-generate another pair.")
        return output_filenames

    def select_mask(self, candidate_mask_files, num_img):
        """
        Select mask files for image generation.

        Args:
            candidate_mask_files (list): List of candidate mask files.
            num_img (int): Number of images to generate.

        Returns:
            list: Selected mask files with augmentation flags.
        """
        selected_mask_files = []
        random.shuffle(candidate_mask_files)

        # Oversample by max_try_time so failed quality checks can fall back to
        # additional candidates; cycle through the shuffled list if needed.
        for n in range(num_img * self.max_try_time):
            mask_file = candidate_mask_files[n % len(candidate_mask_files)]
            selected_mask_files.append({"mask_file": mask_file, "if_aug": True})
        return selected_mask_files

    def sample_one_pair(self, combine_label_or_aug, modality_tensor, spacing_tensor):
        """
        Generate a single pair of synthetic image and mask.

        Args:
            combine_label_or_aug (torch.Tensor): Combined label tensor, possibly augmented.
            modality_tensor (torch.Tensor): Tensor specifying the image modality.
            spacing_tensor (torch.Tensor): Tensor specifying the spacing.

        Returns:
            tuple: A tuple containing the synthetic image and its corresponding label.
        """
        synthetic_images, synthetic_labels = ldm_conditional_sample_one_image(
            autoencoder=self.autoencoder,
            diffusion_unet=self.diffusion_unet,
            controlnet=self.controlnet,
            noise_scheduler=self.noise_scheduler,
            scale_factor=self.scale_factor,
            device=self.device,
            combine_label_or=combine_label_or_aug,
            modality_tensor=modality_tensor,
            spacing_tensor=spacing_tensor,
            latent_shape=self.latent_shape,
            output_size=self.output_size,
            noise_factor=self.noise_factor,
            num_inference_steps=self.num_inference_steps,
            autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size,
            autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap,
        )
        return synthetic_images, synthetic_labels

    def prepare_anatomy_size_condtion(self, controllable_anatomy_size):
        """
        Prepare anatomy size conditions for mask generation.

        Args:
            controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes.

        Returns:
            list: Prepared anatomy size conditions.
        """
        anatomy_size_idx = {
            "gallbladder": 0,
            "liver": 1,
            "stomach": 2,
            "pancreas": 3,
            "colon": 4,
            "lung tumor": 5,
            "pancreatic tumor": 6,
            "hepatic tumor": 7,
            "colon cancer primaries": 8,
            "bone lesion": 9,
        }
        provide_anatomy_size = [None for _ in range(10)]
        logging.info(f"controllable_anatomy_size: {controllable_anatomy_size}")
        for element in controllable_anatomy_size:
            anatomy_name, anatomy_size = element
            provide_anatomy_size[anatomy_size_idx[anatomy_name]] = anatomy_size

        with open(self.all_anatomy_size_condtions_json, "r") as f:
            all_anatomy_size_condtions = json.load(f)

        # Find the database record whose organ sizes are closest (L1 distance over
        # the user-specified entries) to the requested sizes.
        candidate_list = []
        for anatomy_size in all_anatomy_size_condtions:
            size = anatomy_size["organ_size"]
            diff = 0
            for db_size, provide_size in zip(size, provide_anatomy_size):
                if provide_size is None:
                    continue
                diff += abs(provide_size - db_size)
            candidate_list.append((size, diff))
        candidate_condition = sorted(candidate_list, key=lambda x: x[1])[0][0]

        # Overwrite the matched record with the user-provided sizes.
        for element in controllable_anatomy_size:
            anatomy_name, anatomy_size = element
            candidate_condition[anatomy_size_idx[anatomy_name]] = anatomy_size

        return candidate_condition
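    # Hedged worked example: for controllable_anatomy_size == [("liver", 0.5)],
    # the database record whose liver entry is closest to 0.5 is selected, and
    # its liver entry is then overwritten with 0.5 before conditioning.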

    def prepare_one_mask_and_meta_info(self, anatomy_size_condtion):
        """
        Prepare a single mask and its associated meta information.

        Args:
            anatomy_size_condtion (list): Anatomy size conditions.

        Returns:
            tuple: A tuple containing the prepared mask and associated tensors.
        """
        combine_label_or = self.sample_one_mask(anatomy_size=anatomy_size_condtion)
        # The generated mask is treated as having 1.5 mm isotropic spacing; record
        # that in the affine before resampling to the target spacing.
        affine = torch.zeros((4, 4))
        affine[0, 0] = 1.5
        affine[1, 1] = 1.5
        affine[2, 2] = 1.5
        affine[3, 3] = 1.0
        combine_label_or = MetaTensor(combine_label_or, affine=affine)
        combine_label_or = self.ensure_output_size_and_spacing(combine_label_or)

        # The networks expect spacing values scaled by 100.
        spacing_tensor = torch.FloatTensor(self.spacing).unsqueeze(0).half().to(self.device) * 1e2

        return combine_label_or, spacing_tensor

    def sample_one_mask(self, anatomy_size):
        """
        Generate a single synthetic mask.

        Args:
            anatomy_size (list): Anatomy size specifications.

        Returns:
            torch.Tensor: The generated synthetic mask.
        """
        synthetic_mask = ldm_conditional_sample_one_mask(
            self.mask_generation_autoencoder,
            self.mask_generation_diffusion_unet,
            self.mask_generation_noise_scheduler,
            self.mask_generation_scale_factor,
            anatomy_size,
            self.device,
            self.mask_generation_latent_shape,
            label_dict_remap_json=self.label_dict_remap_json,
            num_inference_steps=self.mask_generation_num_inference_steps,
            autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size,
            autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap,
        )
        return synthetic_mask

    def ensure_output_size_and_spacing(self, labels, check_contains_target_labels=True):
        """
        Ensure the output mask has the correct size and spacing.

        Args:
            labels (torch.Tensor): Input label tensor.
            check_contains_target_labels (bool): Whether to check if the resampled mask contains target labels.

        Returns:
            torch.Tensor: Resampled label tensor.

        Raises:
            ValueError: If the resampled mask doesn't contain required class labels.
        """
        current_spacing = [labels.affine[0, 0], labels.affine[1, 1], labels.affine[2, 2]]
        current_shape = list(labels.squeeze().shape)

        # Resample only if either the spacing or the shape differs from the target.
        need_resample = False
        for i, j in zip(current_spacing, self.spacing):
            if i != j:
                need_resample = True
        for i, j in zip(current_shape, self.output_size):
            if i != j:
                need_resample = True

        if need_resample:
            logging.info("Resampling mask to target shape and spacing")
            logging.info(f"Resize spacing: {current_spacing} -> {self.spacing}")
            logging.info(f"Output size: {current_shape} -> {self.output_size}")
            spacing = monai.transforms.Spacing(pixdim=tuple(self.spacing), mode="nearest")
            pad_crop = monai.transforms.ResizeWithPadOrCrop(spatial_size=tuple(self.output_size))
            labels = pad_crop(spacing(labels.squeeze(0))).unsqueeze(0).to(labels.dtype)

        contained_labels = torch.unique(labels)
        if check_contains_target_labels:
            # Ensure the resampled mask still contains every requested anatomy label.
            for anatomy_label in self.anatomy_list:
                if anatomy_label not in contained_labels:
                    raise ValueError(
                        (
                            f"Resampled mask does not contain required class labels {anatomy_label}. "
                            "Please consider increasing the output spacing or specifying a larger output size."
                        )
                    )
        return labels

    def read_mask_information(self, mask_file):
        """
        Read mask information from a file.

        Args:
            mask_file (str): Path to the mask file.

        Returns:
            tuple: A tuple containing the mask tensor and associated information.
        """
        val_data = self.val_transforms(mask_file)

        for key in ["pseudo_label", "spacing"]:
            val_data[key] = val_data[key].unsqueeze(0).to(self.device)

        return (val_data["pseudo_label"], val_data["spacing"])

    def find_closest_masks(self, num_img):
        """
        Find the closest matching masks from the database.

        Args:
            num_img (int): Number of images to generate.

        Returns:
            list: List of closest matching mask candidates.

        Raises:
            ValueError: If suitable candidates cannot be found.
        """
        # Find all masks with the required anatomy, ignoring size/spacing constraints.
        candidates = find_masks(
            self.anatomy_list, self.spacing, self.output_size, False, self.all_mask_files_json, self.data_root
        )

        if len(candidates) < num_img:
            raise ValueError(f"Fewer candidate masks ({len(candidates)}) than the requested {num_img}.")

        # Score each candidate by how closely its field of view, shape, and spacing
        # match the target; discard masks that are too small on any axis.
        new_candidates = []
        for c in candidates:
            diff = 0
            include_c = True
            for axis in range(3):
                if abs(c["dim"][axis]) < self.output_size[axis] - 64:
                    include_c = False
                    break
                # Difference in field of view (dim * spacing), weighted by 1/10.
                diff += abs(
                    (abs(c["dim"][axis] * c["spacing"][axis]) - self.output_size[axis] * self.spacing[axis]) / 10
                )
                # Difference in shape, weighted by 1/100.
                diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100)
                # Difference in spacing, at full weight.
                diff += abs(abs(c["spacing"][axis]) - self.spacing[axis])
            if include_c:
                new_candidates.append((c, diff))

        # Keep the best matches, with some slack for the resampling check below.
        new_candidates = sorted(new_candidates, key=lambda x: x[1])[: max(2 * num_img, 5)]
        final_candidates = []

        # Make sure the mask still contains the required anatomy after resampling.
        image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True)
        for c, _ in new_candidates:
            label = image_loader(c["pseudo_label"])
            try:
                label = self.ensure_output_size_and_spacing(label.unsqueeze(0))
            except ValueError as e:
                if "Resampled mask does not contain required class labels" in str(e):
                    continue
                else:
                    raise e
            # After resampling, the candidate effectively has the target spacing and size.
            c["spacing"] = self.spacing
            c["dim"] = self.output_size
            final_candidates.append(c)
        if len(final_candidates) == 0:
            raise ValueError("Cannot find body region with given anatomy list.")
        return final_candidates

    def quality_check(self, image_data, label_data):
        """
        Perform a quality check on the generated image.

        Args:
            image_data (np.ndarray): The generated image.
            label_data (np.ndarray): The corresponding whole-body mask.

        Returns:
            bool: True if the image passes the quality check, False otherwise.
        """
        outlier_results = is_outlier(self.median_statistics, image_data, label_data, self.label_int_dict)
        for label, result in outlier_results.items():
            if result.get("is_outlier", False):
                logging.info(
                    (
                        f"Generated image quality check for label '{label}' failed: median value {result['median_value']} "
                        f"is outside the acceptable range ({result['low_thresh']} - {result['high_thresh']})."
                    )
                )
                return False
        return True
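# Hedged end-to-end sketch (models, schedulers, and JSON paths are assumed to be
# built elsewhere, e.g. from the MAISI configs; "..." elides those arguments):
#     sampler = LDMSampler(
#         body_region=["abdomen"],
#         anatomy_list=["liver"],
#         modality="ct",
#         ...,  # model, scheduler, path, and latent-shape arguments
#         output_size=(512, 512, 512),
#         output_dir="./output",
#         controllable_anatomy_size=[],
#         spacing=(1.0, 1.0, 1.0),
#     )
#     image_and_label_files = sampler.sample_multiple_images(num_img=2)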