Spaces:
Sleeping
Sleeping
| import math | |
| from typing import List, Sequence | |
| import keras.utils as k_utils | |
| import numpy as np | |
| import pydicom | |
| from keras.utils.data_utils import OrderedEnqueuer | |
| from tqdm import tqdm | |
def parse_windows(windows):
    """Parse windows provided by the user.

    Each entry is either the name of a preset CT window ("soft", "bone",
    "liver", "spine" or "custom") or an explicit pair of (lower, upper)
    bounds.

    Args:
        windows (Iterable): Preset names and/or 2-element numeric sequences
            of (lower, upper) bounds.

    Returns:
        tuple: One (lower, upper) entry per input window, in input order.
            Explicit bounds are passed through unchanged; presets are
            converted from (width, level) to (lower, upper).

    Raises:
        KeyError: If an entry is not a 2-element bounds sequence and is not
            a known preset name.
        AssertionError: If explicit bounds are not numeric or lower >= upper.
    """
    # Presets are stored as (window width, window level).
    windowing = {
        "soft": (400, 50),
        "bone": (1800, 400),
        "liver": (150, 30),
        "spine": (250, 50),
        "custom": (500, 50),
    }
    vals = []
    for w in windows:
        # str/bytes are sequences too; without excluding them any
        # two-character name would be mistaken for a bounds pair.
        is_bounds = (
            isinstance(w, Sequence)
            and not isinstance(w, (str, bytes))
            and len(w) == 2
        )
        if is_bounds:
            assert_msg = "Expected tuple of (lower, upper) bound"
            # Kept as assertions so callers still see AssertionError.
            assert isinstance(w[0], (float, int)), assert_msg
            assert isinstance(w[1], (float, int)), assert_msg
            assert w[0] < w[1], assert_msg
            vals.append(w)
            continue

        if w not in windowing:
            raise KeyError("Window {} not found".format(w))
        window_width, window_level = windowing[w]
        half_width = window_width / 2
        vals.append((window_level - half_width, window_level + half_width))
    return tuple(vals)
| def _window(xs, bounds): | |
| """Apply windowing to an array of CT images. | |
| Args: | |
| xs (ndarray): NxHxW | |
| bounds (tuple): (lower, upper) bounds | |
| Returns: | |
| ndarray: Windowed images. | |
| """ | |
| imgs = [] | |
| for lb, ub in bounds: | |
| imgs.append(np.clip(xs, a_min=lb, a_max=ub)) | |
| if len(imgs) == 1: | |
| return imgs[0] | |
| elif xs.shape[-1] == 1: | |
| return np.concatenate(imgs, axis=-1) | |
| else: | |
| return np.stack(imgs, axis=-1) | |
class Dataset(k_utils.Sequence):
    """Keras sequence that reads DICOM files in batches.

    Args:
        files (List[str]): Paths of the DICOM files to read.
        batch_size (int): Number of files per batch.
        windows (list, optional): Windows in the format accepted by
            ``parse_windows``. When falsy, no windowing is applied.
    """

    def __init__(self, files: List[str], batch_size: int = 16, windows=None):
        self._files = files
        self._batch_size = batch_size
        self.windows = windows

    def __len__(self):
        """Return the number of batches; the last one may be partial."""
        return math.ceil(len(self._files) / self._batch_size)

    def __getitem__(self, idx):
        """Load and preprocess batch ``idx``.

        Args:
            idx (int): Batch index.

        Returns:
            tuple: ``(xs, params)``. ``xs`` is the stacked float32 pixel
                data, windowed when ``self.windows`` is set and otherwise
                given a trailing axis of size 1. ``params`` holds one dict
                per file with keys "spacing" (the header's PixelSpacing)
                and "image" (the un-windowed float32 image).
        """
        start = idx * self._batch_size
        files = self._files[start : start + self._batch_size]

        # dcmread is the supported reader; read_file was only a deprecated
        # alias for it.
        dcms = [pydicom.dcmread(f, force=True) for f in files]

        # Only the rescale intercept (truncated to int) is added here;
        # RescaleSlope is not applied.
        xs = [(x.pixel_array + int(x.RescaleIntercept)).astype("float32") for x in dcms]
        params = [
            {"spacing": header.PixelSpacing, "image": x} for header, x in zip(dcms, xs)
        ]

        # Preprocess xs via windowing.
        xs = np.stack(xs, axis=0)
        if self.windows:
            xs = _window(xs, parse_windows(self.windows))
        else:
            xs = xs[..., np.newaxis]
        return xs, params
| def _swap_muscle_imap(xs, ys, muscle_idx: int, imat_idx: int, threshold=-30.0): | |
| """ | |
| If pixel labeled as muscle but has HU < threshold, change label to imat. | |
| Args: | |
| xs (ndarray): NxHxWxC | |
| ys (ndarray): NxHxWxC | |
| muscle_idx (int): Index of the muscle label. | |
| imat_idx (int): Index of the imat label. | |
| threshold (float): Threshold for HU value. | |
| Returns: | |
| ndarray: Segmentation mask with swapped labels. | |
| """ | |
| labels = ys.copy() | |
| muscle_mask = (labels[..., muscle_idx] > 0.5).astype(int) | |
| imat_mask = labels[..., imat_idx] | |
| imat_mask[muscle_mask.astype(np.bool) & (xs < threshold)] = 1 | |
| muscle_mask[xs < threshold] = 0 | |
| labels[..., muscle_idx] = muscle_mask | |
| labels[..., imat_idx] = imat_mask | |
| return labels | |
def postprocess(xs: np.ndarray, ys: np.ndarray):
    """Built-in post-processing.

    Currently this only appends one all-zero channel to the labels.

    TODO: Make this configurable.

    Args:
        xs (ndarray): NxHxW images. Currently unused.
        ys (ndarray): NxHxWxC labels.

    Returns:
        ndarray: ``ys`` with one extra zero-filled channel on the last axis.
    """
    # Extra channel matching ys in dtype and in every axis but the last.
    blank_channel = np.zeros_like(ys[..., :1])
    padded = np.concatenate((ys, blank_channel), axis=-1)

    # Disabled step (``categories`` is not available in this function):
    # if muscle HU is < -30, assume it is imat.
    #
    # if "muscle" in categories and "imat" in categories:
    #     ys = _swap_muscle_imap(
    #         xs,
    #         ys,
    #         muscle_idx=categories["muscle"],
    #         imat_idx=categories["imat"],
    #     )
    return padded
def predict(
    model,
    dataset: Dataset,
    batch_size: int = 16,
    num_workers: int = 1,
    max_queue_size: int = 10,
    use_multiprocessing: bool = False,
):
    """Predict segmentation masks for a dataset.

    Args:
        model (keras.Model): Model to use for prediction.
        dataset (Dataset): Dataset to predict on.
        batch_size (int): Batch size passed to ``model.predict``.
        num_workers (int): Number of loader workers. When 0, batches are
            read by iterating the dataset directly.
        max_queue_size (int): Maximum queue size of the enqueuer.
        use_multiprocessing (bool): Use multiprocessing in the enqueuer.

    Returns:
        tuple: ``(xs, ys, params)``, three lists with one entry per image:
            the preprocessed model inputs, the post-processed predictions,
            and the per-image parameter dicts produced by the dataset.
    """
    enqueuer = None
    if num_workers > 0:
        enqueuer = OrderedEnqueuer(
            dataset, use_multiprocessing=use_multiprocessing, shuffle=False
        )
        enqueuer.start(workers=num_workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()
    else:
        output_generator = iter(dataset)

    num_batches = len(dataset)
    xs = []
    ys = []
    params = []
    try:
        # Pull exactly one pass worth of batches from the generator.
        for _ in tqdm(range(num_batches)):
            x, p_dicts = next(output_generator)
            y = model.predict(x, batch_size=batch_size)

            # Post-processing receives the un-windowed images.
            image = np.stack([out["image"] for out in p_dicts], axis=0)
            y = postprocess(image, y)

            params.extend(p_dicts)
            # Iterating an array yields its slices along the first axis.
            xs.extend(x)
            ys.extend(y)
    finally:
        # Shut the workers down even when prediction fails; otherwise the
        # enqueuer's threads/processes outlive this call.
        if enqueuer is not None:
            enqueuer.stop()

    return xs, ys, params