maisi_ct_generative / scripts /quality_check.py
Venn's picture
Upload maisi_ct_generative version 1.0.0
02aa18d verified
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def get_masked_data(label_data, image_data, labels):
"""
Extracts and returns the image data corresponding to specified labels within a 3D volume.
This function efficiently masks the `image_data` array based on the provided `labels` in the `label_data` array.
The function handles cases with both a large and small number of labels, optimizing performance accordingly.
Args:
label_data (np.ndarray): A NumPy array containing label data, representing different anatomical
regions or classes in a 3D medical image.
image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions
will be extracted.
labels (list of int): A list of integers representing the label values to be used for masking.
Returns:
np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified
labels in `label_data`. If no labels are provided, an empty array is returned.
Raises:
ValueError: If `image_data` and `label_data` do not have the same shape.
Example:
label_int_dict = {"liver": [1], "kidney": [5, 14]}
masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"])
"""
# Check if the shapes of image_data and label_data match
if image_data.shape != label_data.shape:
raise ValueError(
f"Shape mismatch: image_data has shape {image_data.shape}, "
f"but label_data has shape {label_data.shape}. They must be the same."
)
if not labels:
return np.array([]) # Return an empty array if no labels are provided
labels = list(set(labels)) # remove duplicate items
# Optimize performance based on the number of labels
num_label_acceleration_thresh = 3
if len(labels) >= num_label_acceleration_thresh:
# if many labels, np.isin is faster
mask = np.isin(label_data, labels)
else:
# Use logical OR to combine masks if the number of labels is small
mask = np.zeros_like(label_data, dtype=bool)
for label in labels:
mask = np.logical_or(mask, label_data == label)
# Retrieve the masked data
masked_data = image_data[mask.astype(bool)]
return masked_data
def is_outlier(statistics, image_data, label_data, label_int_dict):
"""
Perform a quality check on the generated image by comparing its statistics with precomputed thresholds.
Args:
statistics (dict): Dictionary containing precomputed statistics including mean +/- 3sigma ranges.
image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array.
label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest.
label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists.
e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]}
Returns:
dict: A dictionary with labels as keys, each containing the quality check result,
including whether it's an outlier, the median value, and the thresholds used.
If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`.
Example:
# Example input data
statistics = {
"liver": {
"sigma_6_low": -21.596463547885904,
"sigma_6_high": 156.27881534763367
},
"kidney": {
"sigma_6_low": -15.0,
"sigma_6_high": 120.0
}
}
label_int_dict = {
"liver": [1],
"kidney": [5, 14]
}
image_data = np.random.rand(100, 100, 100) # Replace with actual image data
label_data = np.zeros((100, 100, 100)) # Replace with actual label data
label_data[40:60, 40:60, 40:60] = 1 # Example region for liver
label_data[70:90, 70:90, 70:90] = 5 # Example region for kidney
result = is_outlier(statistics, image_data, label_data, label_int_dict)
"""
outlier_results = {}
for label_name, stats in statistics.items():
# Get the thresholds from the statistics
low_thresh = min(stats["sigma_6_low"], stats["percentile_0_5"]) # or "sigma_12_low" depending on your needs
high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"]) # or "sigma_12_high" depending on your needs
if label_name == "bone":
high_thresh = 1000.0
# Retrieve the corresponding label integers
labels = label_int_dict.get(label_name, [])
masked_data = get_masked_data(label_data, image_data, labels)
masked_data = masked_data[~np.isnan(masked_data)]
if len(masked_data) == 0 or masked_data.size == 0:
outlier_results[label_name] = {
"is_outlier": False,
"median_value": None,
"low_thresh": low_thresh,
"high_thresh": high_thresh,
}
continue
# Compute the median of the masked region
median_value = np.nanmedian(masked_data)
if np.isnan(median_value):
median_value = None
is_outlier = False
else:
# Determine if the median value is an outlier
is_outlier = median_value < low_thresh or median_value > high_thresh
outlier_results[label_name] = {
"is_outlier": is_outlier,
"median_value": median_value,
"low_thresh": low_thresh,
"high_thresh": high_thresh,
}
return outlier_results