| import cv2 | |
| import glob | |
| import gradio as gr | |
| import mediapy | |
| import nibabel | |
| import numpy as np | |
| import shutil | |
| import torch | |
| import torch.nn.functional as F | |
| from omegaconf import OmegaConf | |
| from skp import builder | |
def window(x, WL=400, WW=2500):
    """Apply an intensity window and map it onto 8-bit grayscale.

    Values are clipped to ``[WL - WW // 2, WL + WW // 2]``, shifted so the
    lower bound becomes 0, scaled to ``[0, 1]`` and then to ``[0, 255]``.
    The final cast to ``uint8`` truncates rather than rounds.

    Args:
        x: numpy array of raw intensities (presumably CT values for the
            studies used in this demo -- TODO confirm).
        WL: window level, i.e. the centre of the window.
        WW: window width.

    Returns:
        ``uint8`` numpy array with the same shape as ``x``.
    """
    half_width = WW // 2
    lo = WL - half_width
    hi = WL + half_width
    normalized = (np.clip(x, lo, hi) - lo) / (hi - lo)
    return (normalized * 255).astype("uint8")
def rescale(x):
    """Map values from the 8-bit range ``[0, 255]`` onto ``[-1, 1]``.

    Only ``/``, ``-`` and ``*`` with Python floats are used, so this works
    for numpy arrays, torch tensors and plain numbers alike (the caller in
    this file passes a torch tensor).

    Args:
        x: values nominally in ``[0, 255]``.

    Returns:
        ``(x / 255 - 0.5) * 2``, keeping the container type of ``x``.
    """
    return (x / 255. - 0.5) * 2.0
def _predict_labels(img):
    """Run the module-level ``seg_model`` on a windowed volume; return labels.

    Args:
        img: 3-D ``uint8`` volume as produced by ``window``.

    Returns:
        float numpy array with ``img.shape``. A voxel is 0 where the summed
        sigmoid probability of the first 7 output channels is below 0.5;
        otherwise it is 1 + argmax over all output channels.
    """
    x = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)  # (1, 1, D0, D1, D2)
    # The model is always fed a fixed 192^3 grid.
    x = F.interpolate(x, size=(192, 192, 192), mode="nearest")
    x = rescale(x)
    with torch.no_grad():
        probs = torch.sigmoid(seg_model(x))
    # Summed probability of channels 0-6 serves as the foreground score.
    # NOTE(review): presumably these are the spine-level channels -- confirm
    # against configs/pseudoseg000.yaml.
    p_spine = probs[:, :7].sum(1)
    # NOTE(review): argmax runs over *all* channels, not only the first 7. If
    # the model emits more than 7 channels, labels > 7 can appear and will
    # wrap around when scaled to uint8 in _colorize_labels -- confirm the
    # channel count.
    labels = torch.argmax(probs, dim=1) + 1
    labels[p_spine < 0.5] = 0
    # Nearest-neighbour resize back onto the original volume grid, so label
    # values are not blended.
    labels = F.interpolate(labels.unsqueeze(0).float(), size=img.shape, mode="nearest")
    return labels.squeeze(0).squeeze(0).numpy()


def _colorize_labels(labels):
    """Colour a label volume with OpenCV's JET colormap, returning RGB.

    Args:
        labels: 3-D array of label values (0..7 expected).

    Returns:
        ``uint8`` array of shape ``labels.shape + (3,)`` in RGB channel order.
    """
    scaled = (labels * 255 / 7).astype("uint8")
    # applyColorMap works on 2-D images, so colour one slice along axis 0 at
    # a time. It produces BGR; convert to RGB because the frames are written
    # with mediapy, which expects RGB.
    return np.stack([
        cv2.cvtColor(cv2.applyColorMap(s, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        for s in scaled
    ])


def generate_segmentation_video(study, *, skip=8, fps=30, output_path="video.mp4"):
    """Segment a NIfTI study and write a side-by-side preview video.

    The volume is loaded with nibabel. Axes 1 and 2 are flipped, then the
    axis order is reversed (``transpose(2, 1, 0)``). The result is windowed
    to ``uint8``, segmented with the module-level ``seg_model`` and
    colourised. Every ``skip``-th slice along the last axis becomes one
    frame: the grayscale slice on the left and the colour-coded labels on
    the right.

    Args:
        study: path to a NIfTI file loadable by ``nibabel.load``.
        skip: stride between rendered slices along the last axis.
        fps: frame rate of the output video.
        output_path: file the video is written to.

    Returns:
        ``output_path``.

    NOTE(review): the default ``output_path`` is a fixed file in the working
    directory, so concurrent requests would overwrite each other's video.
    """
    img = nibabel.load(study).get_fdata()[:, ::-1, ::-1].transpose(2, 1, 0)
    img = window(img)
    seg_rgb = _colorize_labels(_predict_labels(img))

    frames = []
    for idx in range(0, img.shape[2], skip):
        gray = cv2.cvtColor(img[:, :, idx], cv2.COLOR_GRAY2RGB)
        frames.append(np.concatenate((gray, seg_rgb[:, :, idx]), 1))

    mediapy.write_video(output_path, frames, fps=fps)
    return output_path
# Point mediapy at the ffmpeg binary found on PATH. shutil.which returns None
# when nothing is found; in that case leave mediapy's own setting untouched
# rather than overwriting it with None.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)

# Build the segmentation model from its config. The checkpoint path is set to
# seg.ckpt, and the encoder's `pretrained` flag is switched off before
# construction (presumably because the checkpoint supplies all weights --
# confirm against skp.builder).
config = OmegaConf.load("configs/pseudoseg000.yaml")
config.model.load_pretrained = "seg.ckpt"
config.model.params.encoder_params.pretrained = False
seg_model = builder.build_model(config).eval()

# Example studies offered in the dropdown.
examples = glob.glob("examples/*.nii.gz")

# UI: pick a study, press Predict, and show the rendered video.
with gr.Blocks(theme="dark-peach") as demo:
    select_study = gr.Dropdown(choices=sorted(examples), type="value", label="Select a study")
    button_predict = gr.Button("Predict")
    video_output = gr.Video()
    button_predict.click(fn=generate_segmentation_video,
                         inputs=select_study,
                         outputs=video_output)

if __name__ == "__main__":
    demo.launch(debug=True)