File size: 6,297 Bytes
02aa18d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def get_masked_data(label_data, image_data, labels):
    """
    Return the elements of `image_data` whose corresponding entry in `label_data` is one of `labels`.

    A boolean mask is built from `label_data` and used to index `image_data`, so the result is a
    flattened (1-D) array of the selected elements in row-major order.

    Args:
        label_data (np.ndarray): Array of label values, with the same shape as `image_data`.
        image_data (np.ndarray): Array of image values from which the selected elements are taken.
        labels (list of int | tuple | set | np.ndarray): Label values to select. Duplicates are ignored.
            A NumPy array of labels is accepted as well and is flattened before use.

    Returns:
        np.ndarray: 1-D array with the elements of `image_data` located where `label_data` equals any
                    of `labels`. If `labels` is empty (or `None`), an empty array is returned.

    Raises:
        ValueError: If `image_data` and `label_data` do not have the same shape.

    Example:
        label_int_dict = {"liver": [1], "kidney": [5, 14]}
        masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"])
    """

    # Boolean-mask indexing below needs the two arrays to line up element by element.
    if image_data.shape != label_data.shape:
        raise ValueError(
            f"Shape mismatch: image_data has shape {image_data.shape}, "
            f"but label_data has shape {label_data.shape}. They must be the same."
        )

    # The truth value of a multi-element ndarray is ambiguous, so `not labels` would raise for
    # array input; convert it to a plain list of Python scalars first.
    if isinstance(labels, np.ndarray):
        labels = labels.ravel().tolist()

    if not labels:
        return np.array([])  # nothing requested -> empty result

    unique_labels = list(set(labels))  # duplicates would only repeat work

    # Below this many labels a few direct comparisons are used; from this count on, a single
    # np.isin call is used instead (the original author found it faster for many labels).
    num_label_acceleration_thresh = 3
    if len(unique_labels) >= num_label_acceleration_thresh:
        mask = np.isin(label_data, unique_labels)
    else:
        mask = np.zeros_like(label_data, dtype=bool)
        for label in unique_labels:
            mask = np.logical_or(mask, label_data == label)

    # `mask` is already boolean in both branches, so it can index the image directly.
    return image_data[mask]


def is_outlier(statistics, image_data, label_data, label_int_dict):
    """
    Perform a quality check on an image by comparing per-region medians with precomputed thresholds.

    For every entry in `statistics`, the image values under the corresponding labels are collected,
    NaNs are discarded, and the median of what remains is compared against a lower and an upper bound.

    Args:
        statistics (dict): Maps a label name to a dict of precomputed statistics. Each entry must
            provide "sigma_6_low" and "sigma_6_high". If "percentile_0_5" / "percentile_99_5" are
            present, the bounds are widened to `min(sigma_6_low, percentile_0_5)` and
            `max(sigma_6_high, percentile_99_5)` respectively.
        image_data (np.ndarray): The image data to be checked.
        label_data (np.ndarray): The label data corresponding to the image, with the same shape as
            `image_data`; used for masking regions of interest.
        label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists.
            e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]}
            Names missing from this dict are treated as having no labels.

    Returns:
        dict: Maps every label name in `statistics` to a dict with the keys
              "is_outlier" (bool), "median_value" (float or None), "low_thresh" and "high_thresh".
              If no (non-NaN) data is found for a label, or its median is NaN, "median_value" is
              `None` and "is_outlier" is `False`.

    Raises:
        KeyError: If an entry of `statistics` lacks "sigma_6_low" or "sigma_6_high".
        ValueError: If `image_data` and `label_data` do not have the same shape.

    Example:
        # Example input data
        statistics = {
            "liver": {
                "sigma_6_low": -21.596463547885904,
                "sigma_6_high": 156.27881534763367,
                "percentile_0_5": 9.0,     # optional
                "percentile_99_5": 200.0   # optional
            },
            "kidney": {
                "sigma_6_low": -15.0,
                "sigma_6_high": 120.0
            }
        }
        label_int_dict = {
            "liver": [1],
            "kidney": [5, 14]
        }
        image_data = np.random.rand(100, 100, 100)  # Replace with actual image data
        label_data = np.zeros((100, 100, 100))  # Replace with actual label data
        label_data[40:60, 40:60, 40:60] = 1  # Example region for liver
        label_data[70:90, 70:90, 70:90] = 5  # Example region for kidney
        result = is_outlier(statistics, image_data, label_data, label_int_dict)
    """
    outlier_results = {}

    for label_name, stats in statistics.items():
        # The sigma bounds are mandatory; the percentile bounds only ever widen the accepted range.
        low_thresh = stats["sigma_6_low"]
        high_thresh = stats["sigma_6_high"]
        if "percentile_0_5" in stats:
            low_thresh = min(low_thresh, stats["percentile_0_5"])
        if "percentile_99_5" in stats:
            high_thresh = max(high_thresh, stats["percentile_99_5"])

        # The upper bound for "bone" is fixed, regardless of the precomputed statistics.
        if label_name == "bone":
            high_thresh = 1000.0

        # Collect the image values under this label name and drop NaNs before taking the median.
        labels = label_int_dict.get(label_name, [])
        masked_data = get_masked_data(label_data, image_data, labels)
        masked_data = masked_data[~np.isnan(masked_data)]

        # Defaults cover both "no data for this label" and "median is not a number".
        median_value = None
        outlier_flag = False

        if masked_data.size > 0:
            region_median = np.nanmedian(masked_data)
            # The median can still be NaN after NaN removal, e.g. when the two middle values
            # are -inf and +inf.
            if not np.isnan(region_median):
                # Compare on the NumPy scalar, then store builtin types so the result can be
                # tested with `is True` / `is False` and serialized (e.g. to JSON).
                outlier_flag = bool(region_median < low_thresh or region_median > high_thresh)
                median_value = float(region_median)

        outlier_results[label_name] = {
            "is_outlier": outlier_flag,
            "median_value": median_value,
            "low_thresh": low_thresh,
            "high_thresh": high_thresh,
        }

    return outlier_results