|
import cv2 |
|
import glob |
|
import gradio as gr |
|
import mediapy |
|
import nibabel |
|
import numpy as np |
|
import shutil |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from omegaconf import OmegaConf |
|
from skp import builder |
|
|
|
|
|
def window(x, WL=400, WW=2500):
    """Apply an intensity window to ``x`` and quantize the result to uint8.

    Values are clipped to ``[WL - WW // 2, WL + WW // 2]``, shifted so the lower
    bound maps to 0, divided by the window width, scaled by 255 and truncated
    to ``uint8``.

    Args:
        x: numeric numpy array (any shape).
        WL: window level (centre of the window).
        WW: window width; integer-halved, so an odd width loses one unit.

    Returns:
        uint8 array with the same shape as ``x``.
    """
    half_width = WW // 2
    low = WL - half_width
    high = WL + half_width
    normalized = (np.clip(x, low, high) - low) / (high - low)
    return (normalized * 255).astype("uint8")
|
|
|
|
|
def rescale(x):
    """Map values from the uint8 range [0, 255] linearly onto [-1, 1].

    Works element-wise on anything that supports ``/``, ``-`` and ``*``
    (numpy arrays, torch tensors, Python scalars).
    """
    centered = x / 255. - 0.5
    return centered * 2.0
|
|
|
|
|
def get_cervical_spine_coordinates(x, original_shape, threshold=0.4, num_levels=7):
    """Compute a bounding box for each of the first ``num_levels`` segmentation channels.

    Args:
        x: segmentation probabilities as a CPU torch tensor shaped
            (1, C, D, H, W) with C >= num_levels.
        original_shape: (D, H, W) of the volume the boxes should refer to. Voxel
            indices in ``x`` are multiplied by ``original_shape / x spatial shape``
            and truncated to int.
        threshold: minimum probability for a voxel to count as part of a level.
        num_levels: number of leading channels to process.

    Returns:
        dict mapping channel index -> (z_min, z_max, h_min, h_max, w_min, w_max),
        inclusive bounds in ``original_shape`` coordinates.

    Raises:
        ValueError: if a channel has no voxel at or above ``threshold`` (no box
            can be formed for it).
    """
    probs = x.squeeze(0).numpy()[:num_levels]
    scale = [original_shape[axis] / probs.shape[axis + 1] for axis in range(3)]

    coords_dict = {}
    for level in range(probs.shape[0]):
        indices = np.nonzero(probs[level] >= threshold)
        if indices[0].size == 0:
            raise ValueError(
                f"No voxels with probability >= {threshold} for segmentation "
                f"channel {level}; cannot compute its bounding box."
            )
        bounds = []
        for axis in range(3):
            # Scaling and int truncation are monotonic, so scaling the min/max
            # index equals taking the min/max of the scaled indices.
            bounds.append(int(indices[axis].min() * scale[axis]))
            bounds.append(int(indices[axis].max() * scale[axis]))
        coords_dict[level] = tuple(bounds)
    return coords_dict
|
|
|
|
|
def _to_model_input(volume, size):
    """Turn a (D, H, W) uint8 volume into a (1, 1, *size) float tensor in [-1, 1]."""
    tensor = torch.from_numpy(volume).float().unsqueeze(0).unsqueeze(0)
    tensor = F.interpolate(tensor, size=size, mode="nearest")
    return rescale(tensor)


def _predict_fractures(img):
    """Run segmentation -> per-level feature extraction -> sequence model on ``img``.

    Args:
        img: windowed uint8 volume (3-D numpy array).

    Returns:
        (seg_probs, predictions): sigmoid segmentation output of ``seg_model``
        (with its batch axis), and a dict mapping "C1".."C7" and "Overall" to
        sigmoid outputs of ``seq_model``.
    """
    with torch.no_grad():
        seg_probs = torch.sigmoid(seg_model(_to_model_input(img, (192, 192, 192))))
        level_boxes = get_cervical_spine_coordinates(seg_probs, img.shape)

        # One feature vector per level, from a crop around that level's box.
        chunk_features = []
        for z1, z2, h1, h2, w1, w2 in level_boxes.values():
            crop = img[z1:z2 + 1, h1:h2 + 1, w1:w2 + 1]
            chunk_features.append(
                x3d_model.extract_features(_to_model_input(crop, (64, 288, 288)))
            )
        chunk_features = torch.stack(chunk_features, dim=1)

        # All chunks are valid, so the mask is all ones.
        mask = torch.ones((chunk_features.size(1), ))
        final_output = torch.sigmoid(seq_model((chunk_features, mask)))

    # NOTE(review): assumes seq_model emits 7 per-level outputs followed by an
    # overall output in its last column — confirm against the model config.
    predictions = {f"C{i+1}": final_output[:, i].item() for i in range(7)}
    predictions["Overall"] = final_output[:, -1].item()
    return seg_probs, predictions


def _render_overlay_frames(img, seg_probs, skip):
    """Build side-by-side RGB frames: grayscale slice | colour-coded level labels.

    A frame is produced for every ``skip``-th index along the last axis of ``img``.
    """
    seg = F.interpolate(seg_probs, size=img.shape, mode="nearest").squeeze(0).numpy()
    level_probs = seg[:7]
    # Voxels whose summed level probability is below 0.5 are drawn black.
    background = level_probs.sum(0) < 0.5

    # Labels 1..7 (most likely level) spread over 0..255 for the colormap.
    labels = ((np.argmax(level_probs, axis=0) + 1) * 255 / 7).astype("uint8")
    # cv2.applyColorMap returns BGR, while these frames are RGB (the grayscale
    # half uses COLOR_GRAY2RGB and mediapy writes RGB), so swap channels.
    overlay = np.stack([
        cv2.cvtColor(cv2.applyColorMap(plane, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        for plane in labels
    ])
    overlay[background] = 0

    frames = []
    for idx in range(0, img.shape[2], skip):
        gray = cv2.cvtColor(img[:, :, idx], cv2.COLOR_GRAY2RGB)
        frames.append(np.concatenate((gray, overlay[:, :, idx]), 1))
    return frames


def generate_segmentation_video(study, output_path="video.mp4", skip=8, fps=10):
    """Predict cervical fractures for one study and render a segmentation preview video.

    Args:
        study: path to a volume file loadable by ``nibabel.load``.
        output_path: where the MP4 is written. The default fixed name means
            concurrent calls overwrite each other's video.
        skip: stride between rendered slices along the last volume axis.
        fps: frame rate of the written video.

    Returns:
        (output_path, predictions) where predictions maps "C1".."C7" and
        "Overall" to model probabilities.

    Raises:
        ValueError: propagated from ``get_cervical_spine_coordinates`` when a
            level cannot be localized in the segmentation.
    """
    # Flip the 2nd and 3rd array axes and reverse axis order; presumably this
    # matches the orientation used in training — confirm.
    img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
    img = window(img)

    seg_probs, predictions = _predict_fractures(img)
    frames = _render_overlay_frames(img, seg_probs, skip)

    mediapy.write_video(output_path, frames, fps=fps)
    return output_path, predictions
|
|
|
|
|
# Point mediapy at the ffmpeg executable used to encode the preview video.
# shutil.which returns None when ffmpeg is not on PATH; only override mediapy's
# default lookup when an executable was actually found, rather than recording
# None as its location (which presumably only fails later, at encode time).
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)
|
|
|
def _build_eval_model(config, checkpoint):
    """Set ``config.model.load_pretrained`` to ``checkpoint``, build via skp.builder, return it in eval mode."""
    config.model.load_pretrained = checkpoint
    return builder.build_model(config).eval()


# Segmentation model.
config = OmegaConf.load("configs/pseudoseg000.yaml")
# NOTE(review): presumably stops the encoder from fetching its own initial
# weights since the checkpoint supplies them — confirm in skp.builder.
config.model.params.encoder_params.pretrained = False
seg_model = _build_eval_model(config, "seg.ckpt")

# Per-level feature extractor (called via extract_features).
config = OmegaConf.load("configs/chunk000.yaml")
config.model.params.pretrained = False
x3d_model = _build_eval_model(config, "x3d.ckpt")

# Sequence model combining the per-level features.
config = OmegaConf.load("configs/chunkseq003.yaml")
seq_model = _build_eval_model(config, "seq.ckpt")
|
|
|
# Bundled example studies offered in the dropdown, listed alphabetically.
examples = glob.glob("examples/*.nii.gz")
study_choices = sorted(examples)

with gr.Blocks(theme="dark-peach") as demo:
    select_study = gr.Dropdown(
        choices=study_choices,
        type="value",
        label="Select a study",
    )
    button_predict = gr.Button("Predict")
    video_output = gr.Video(label="Cervical Spine Segmentation")
    label_output = gr.Label(label="Fracture Predictions", show_label=False)

    # Clicking Predict runs the full pipeline on the selected study and fills
    # both the video and the per-level prediction label.
    button_predict.click(
        fn=generate_segmentation_video,
        inputs=select_study,
        outputs=[video_output, label_output],
    )


if __name__ == "__main__":
    demo.launch(debug=True)
|
|
|
|
|
|