import math
import operator
import os
import zipfile
from pathlib import Path
from time import time
from tkinter import Tcl
from typing import Union

import cv2
import matplotlib.pyplot as plt
import moviepy.video.io.ImageSequenceClip
import nibabel as nib
import numpy as np
import pandas as pd
import pydicom
import wget
from totalsegmentator.libs import nostdout

from comp2comp.inference_class_base import InferenceClass
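
# This module defines three comp2comp inference stages for aortic analysis:
#   AortaSegmentation - nnU-Net aorta segmentation via the TotalSegmentator tooling
#   AortaDiameter     - per-slice ellipse fitting and diameter measurement
#   AortaMetricsSaver - CSV export of the maximum measured diameter
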
class AortaSegmentation(InferenceClass):
| """Spine segmentation.""" | |

    def __init__(self, save=True):
        super().__init__()
        self.model_name = "totalsegmentator"
        self.save_segmentations = save

    def __call__(self, inference_pipeline):
        # inference_pipeline.dicom_series_path = self.input_path
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        if not os.path.exists(self.output_dir_segmentations):
            os.makedirs(self.output_dir_segmentations)

        self.model_dir = inference_pipeline.model_dir
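
        # Run the aorta nnU-Net model on the converted DICOM volume; spine_seg
        # returns the segmentation and the medical volume as NIfTI images.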
        seg, mv = self.spine_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            self.output_dir_segmentations + "spine.nii.gz",
            inference_pipeline.model_dir,
        )

        seg = seg.get_fdata()
        medical_volume = mv.get_fdata()

        axial_masks = []
        ct_image = []

        for i in range(seg.shape[2]):
            axial_masks.append(seg[:, :, i])

        for i in range(medical_volume.shape[2]):
            ct_image.append(medical_volume[:, :, i])

        # Save input axial slices to pipeline
        inference_pipeline.ct_image = ct_image

        # Save aorta masks to pipeline
        inference_pipeline.axial_masks = axial_masks

        return {}
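
    # The helpers below set up the nnU-Net directory layout and environment
    # variables expected by TotalSegmentator, and fetch the aorta model weights.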
    def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
        """Adapted from TotalSegmentator."""
        model_dir = Path(model_dir)
        config_dir = model_dir / Path("." + self.model_name)
        (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir(
            exist_ok=True, parents=True
        )
        (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True)
        weights_dir = config_dir / "nnunet/results"
        self.weights_dir = weights_dir

        os.environ["nnUNet_raw_data_base"] = str(
            weights_dir
        )  # not needed, just needs to be an existing directory
        os.environ["nnUNet_preprocessed"] = str(
            weights_dir
        )  # not needed, just needs to be an existing directory
        os.environ["RESULTS_FOLDER"] = str(weights_dir)
    def download_spine_model(self, model_dir: Union[str, Path]):
        download_dir = Path(
            os.path.join(
                self.weights_dir,
                "nnUNet/3d_fullres/Task253_Aorta/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1",
            )
        )
        print(download_dir)
        fold_0_path = download_dir / "fold_0"
        if not os.path.exists(fold_0_path):
            download_dir.mkdir(parents=True, exist_ok=True)
            wget.download(
                "https://huggingface.co/AdritRao/aaa_test/resolve/main/fold_0.zip",
                out=os.path.join(download_dir, "fold_0.zip"),
            )
            with zipfile.ZipFile(
                os.path.join(download_dir, "fold_0.zip"), "r"
            ) as zip_ref:
                zip_ref.extractall(download_dir)
            os.remove(os.path.join(download_dir, "fold_0.zip"))
            wget.download(
                "https://huggingface.co/AdritRao/aaa_test/resolve/main/plans.pkl",
                out=os.path.join(download_dir, "plans.pkl"),
            )
| print("Spine model downloaded.") | |
| else: | |
| print("Spine model already downloaded.") | |

    def spine_seg(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
| """Run spine segmentation. | |
| Args: | |
| input_path (Union[str, Path]): Input path. | |
| output_path (Union[str, Path]): Output path. | |
| """ | |
| print("Segmenting spine...") | |
| st = time() | |
| os.environ["SCRATCH"] = self.model_dir | |
| print(self.model_dir) | |

        # Setup nnunet
        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        task_id = [253]

        self.setup_nnunet_c2c(model_dir)
        self.download_spine_model(model_dir)
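
        # Import deferred until after the nnU-Net environment variables are set
        # and the model weights have been downloaded.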
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag="None",
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        # Log total time for aorta segmentation
        print(f"Total time for aorta segmentation: {end-st:.2f}s.")

        seg_data = seg.get_fdata()
        seg = nib.Nifti1Image(seg_data, seg.affine, seg.header)

        return seg, img
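

# AortaDiameter fits an ellipse to the aorta mask on each axial slice, takes the
# ellipse's minor axis as the diameter, writes annotated slice overlays, a
# diameter-per-slice bar chart and an MP4 fly-through, and records the maximum
# diameter on the pipeline for downstream stages.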
class AortaDiameter(InferenceClass):
    def __init__(self):
        super().__init__()

    def normalize_img(self, img: np.ndarray) -> np.ndarray:
        """Normalize the image.

        Args:
            img (np.ndarray): Input image.

        Returns:
            np.ndarray: Normalized image.
        """
        return (img - img.min()) / (img.max() - img.min())

    def __call__(self, inference_pipeline):
        axial_masks = (
            inference_pipeline.axial_masks
        )  # list of 2D numpy arrays of shape (512, 512)
        ct_img = (
            inference_pipeline.ct_image
        )  # list of 2D numpy arrays of shape (512, 512), one per axial slice

        # image output directory
        output_dir = inference_pipeline.output_dir
        output_dir_slices = os.path.join(output_dir, "images/slices/")
        if not os.path.exists(output_dir_slices):
            os.makedirs(output_dir_slices)

        output_dir = inference_pipeline.output_dir
        output_dir_summary = os.path.join(output_dir, "images/summary/")
        if not os.path.exists(output_dir_summary):
            os.makedirs(output_dir_summary)

        DICOM_PATH = inference_pipeline.dicom_series_path
        dicom = pydicom.dcmread(DICOM_PATH + "/" + os.listdir(DICOM_PATH)[0])

        dicom.PhotometricInterpretation = "YBR_FULL"
        pixel_conversion = dicom.PixelSpacing
        print("Pixel conversion: " + str(pixel_conversion))
        RATIO_PIXEL_TO_MM = pixel_conversion[0]

        # Slice count is the number of axial slices passed through the pipeline.
        SLICE_COUNT = len(ct_img)
        diameterDict = {}

        for i in range(len(ct_img)):
            mask = axial_masks[i].astype("uint8")

            img = ct_img[i]

            img = np.clip(img, -300, 1800)
            img = self.normalize_img(img) * 255.0
            img = img.reshape((img.shape[0], img.shape[1], 1))
            img = np.tile(img, (1, 1, 3))

            contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
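
            # Keep only the largest contour in the mask; it is assumed to be the
            # aorta cross-section on this slice.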
            if len(contours) != 0:
                areas = [cv2.contourArea(c) for c in contours]
                sorted_areas = np.sort(areas)
                contours = contours[areas.index(sorted_areas[-1])]

                back = img.copy()
                cv2.drawContours(back, [contours], 0, (0, 255, 0), -1)

                alpha = 0.25
                img = cv2.addWeighted(img, 1 - alpha, back, alpha, 0)

                ellipse = cv2.fitEllipse(contours)
                (xc, yc), (d1, d2), angle = ellipse

                cv2.ellipse(img, ellipse, (0, 255, 0), 1)

                xc, yc = ellipse[0]
                cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1)

                rmajor = max(d1, d2) / 2
                rminor = min(d1, d2) / 2
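
                # The ellipse angle is rotated by 90 degrees before each axis is
                # drawn: the first line uses the major radius, the second the minor.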
                ### Draw major axes
                if angle > 90:
                    angle = angle - 90
                else:
                    angle = angle + 90
                print(angle)
                xtop = xc + math.cos(math.radians(angle)) * rmajor
                ytop = yc + math.sin(math.radians(angle)) * rmajor
                xbot = xc + math.cos(math.radians(angle + 180)) * rmajor
                ybot = yc + math.sin(math.radians(angle + 180)) * rmajor
                cv2.line(
                    img, (int(xtop), int(ytop)), (int(xbot), int(ybot)), (0, 0, 255), 3
                )

                ### Draw minor axes
                if angle > 90:
                    angle = angle - 90
                else:
                    angle = angle + 90
                print(angle)
                x1 = xc + math.cos(math.radians(angle)) * rminor
                y1 = yc + math.sin(math.radians(angle)) * rminor
                x2 = xc + math.cos(math.radians(angle + 180)) * rminor
                y2 = yc + math.sin(math.radians(angle + 180)) * rminor
                cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 3)

                # pixel_length = math.sqrt( (x1-x2)**2 + (y1-y2)**2 )
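                # The diameter is the full minor axis of the fitted ellipse, in
                # pixels here and converted to mm below via the DICOM PixelSpacing.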
                pixel_length = rminor * 2
                print("Pixel_length_minor: " + str(pixel_length))

                area_px = cv2.contourArea(contours)
                # Contour area in px^2 -> mm^2 via the squared in-plane pixel spacing.
                area_mm2 = round(area_px * RATIO_PIXEL_TO_MM**2)
                area_cm2 = area_mm2 / 100

                diameter_mm = round(pixel_length * RATIO_PIXEL_TO_MM)
                diameter_cm = diameter_mm / 10

                diameterDict[(SLICE_COUNT - (i))] = diameter_cm

                img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)

                h, w, c = img.shape
                lbls = [
                    "Area (mm^2): " + str(area_mm2) + " mm^2",
                    "Area (cm^2): " + str(area_cm2) + " cm^2",
                    "Diameter (mm): " + str(diameter_mm) + " mm",
                    "Diameter (cm): " + str(diameter_cm) + " cm",
                    "Slice: " + str(SLICE_COUNT - (i)),
                ]
                font = cv2.FONT_HERSHEY_SIMPLEX

                scale = 0.03
                fontScale = min(w, h) / (25 / scale)

                cv2.putText(img, lbls[0], (10, 40), font, fontScale, (0, 255, 0), 2)
                cv2.putText(img, lbls[1], (10, 70), font, fontScale, (0, 255, 0), 2)
                cv2.putText(img, lbls[2], (10, 100), font, fontScale, (0, 255, 0), 2)
                cv2.putText(img, lbls[3], (10, 130), font, fontScale, (0, 255, 0), 2)
                cv2.putText(img, lbls[4], (10, 160), font, fontScale, (0, 255, 0), 2)

                cv2.imwrite(
                    output_dir_slices + "slice" + str(SLICE_COUNT - (i)) + ".png", img
                )
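
        # Summary bar chart: measured diameter (cm) per slice number.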
        plt.bar(list(diameterDict.keys()), diameterDict.values(), color="b")
        plt.title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$")
        plt.xlabel("Slice Number")
        plt.ylabel("Diameter Measurement (cm)")
        plt.savefig(output_dir_summary + "diameter_graph.png", dpi=500)

        print(diameterDict)
        print(max(diameterDict.items(), key=operator.itemgetter(1))[0])
        print(diameterDict[max(diameterDict.items(), key=operator.itemgetter(1))[0]])
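
        # Expose the largest per-slice diameter to downstream stages (AortaMetricsSaver).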
        inference_pipeline.max_diameter = diameterDict[
            max(diameterDict.items(), key=operator.itemgetter(1))[0]
        ]

        img = ct_img[
            SLICE_COUNT - (max(diameterDict.items(), key=operator.itemgetter(1))[0])
        ]
        img = np.clip(img, -300, 1800)
        img = self.normalize_img(img) * 255.0
        img = img.reshape((img.shape[0], img.shape[1], 1))
        img2 = np.tile(img, (1, 1, 3))
        img2 = cv2.rotate(img2, cv2.ROTATE_90_COUNTERCLOCKWISE)

        img1 = cv2.imread(
            output_dir_slices
            + "slice"
            + str(max(diameterDict.items(), key=operator.itemgetter(1))[0])
            + ".png"
        )

        border_size = 3
        img1 = cv2.copyMakeBorder(
            img1,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[0, 244, 0],
        )
        img2 = cv2.copyMakeBorder(
            img2,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[244, 0, 0],
        )

        vis = np.concatenate((img2, img1), axis=1)
        cv2.imwrite(output_dir_summary + "out.png", vis)
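
        # Assemble the per-slice overlays into an MP4; Tcl's "lsort -dict" sorts
        # filenames naturally, so slice2.png comes before slice10.png.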
        image_folder = output_dir_slices
        fps = 20
        image_files = [
            os.path.join(image_folder, img)
            for img in Tcl().call("lsort", "-dict", os.listdir(image_folder))
            if img.endswith(".png")
        ]
        clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(
            image_files, fps=fps
        )
        clip.write_videofile(output_dir_summary + "aaa.mp4")

        return {}

class AortaMetricsSaver(InferenceClass):
    """Save metrics to a CSV file."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Save metrics to a CSV file."""
        self.max_diameter = inference_pipeline.max_diameter
        self.dicom_series_path = inference_pipeline.dicom_series_path
        self.output_dir = inference_pipeline.output_dir
        self.csv_output_dir = os.path.join(self.output_dir, "metrics")
        if not os.path.exists(self.csv_output_dir):
            os.makedirs(self.csv_output_dir, exist_ok=True)
        self.save_results()
        return {}

    def save_results(self):
        """Save results to a CSV file."""
        _, filename = os.path.split(self.dicom_series_path)
        data = [[filename, str(self.max_diameter)]]
        df = pd.DataFrame(data, columns=["Filename", "Max Diameter"])
        df.to_csv(os.path.join(self.csv_output_dir, "aorta_metrics.csv"), index=False)