import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    EnsureChannelFirstd,
)
title = 'Detect and Segment Brain Tumors 🧠'
description = '''
Upload a multi-modal brain MRI volume (NIfTI), pick a slice and input channel,
and view the tumor segmentation predicted by a SegResNet model.
'''
# Map the displayed input MRI channel to the model output channel to show.
# (The three output channels are tumor sub-regions; this pairing is kept
# exactly as in the original app.)
channel_mapping = {
    0: 1,
    1: 0,
    2: 2,
}
# Preprocessing: load the NIfTI file, move channels first, reorient to RAS,
# and normalize the non-zero voxels of each channel independently.
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
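# After preprocessing, the sample's "image" entry is a channel-first
# (C, H, W, D) tensor in RAS orientation; for BraTS-style inputs C is the
# 4 MRI modalities (e.g. FLAIR, T1, T1ce, T2), matching in_channels=4 below.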
# Postprocessing: sigmoid activation, then threshold at 0.5 for binary masks.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
# SegResNet: 4 input channels (MRI modalities), 3 output channels (tumor sub-regions).
model_tumor_seg = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to('cpu')
model_tumor_seg.load_state_dict(torch.load("models/best_metric_model_epoch_40.pth", map_location='cpu'))
model_tumor_seg.eval()  # disable dropout for inference
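# Optional sanity check, not part of the original app: run a tiny dummy volume
# through the network to confirm the checkpoint loaded with matching shapes.
# The SMOKE_TEST environment variable is a name chosen here for illustration.
if os.environ.get("SMOKE_TEST"):
    with torch.no_grad():
        _out = model_tumor_seg(torch.zeros(1, 4, 64, 64, 64))
    assert _out.shape == (1, 3, 64, 64, 64)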
def inference(input_tensor):
    return sliding_window_inference(
        inputs=input_tensor,
        roi_size=(240, 240, 160),
        sw_batch_size=1,
        predictor=model_tumor_seg,
        overlap=0.5,
    )
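# The (240, 240, 160) ROI is sized so a single window covers a whole
# BraTS-style volume (240 x 240 x 155 voxels, padded along the last axis);
# on larger inputs the window slides with 50% overlap and the per-patch
# predictions are blended, keeping peak memory bounded on CPU.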
# Each example must supply a value for every input: file, slice, channel, audio.
examples = [
    ['examples/BRATS_225.nii.gz', 83, 2, None],
    ['examples/BRATS_485.nii.gz', 90, 1, None],
    ['examples/BRATS_485.nii.gz', 110, 0, None],
]
def detector(tumor_file, slice_number, channel, audio_question):
    # gr.File passes a filepath string in recent Gradio versions and a
    # tempfile-like object (with .name) in older ones; handle both.
    tumor_file_path = tumor_file if isinstance(tumor_file, str) else tumor_file.name
    processed_data = preproc_transforms({'image': [tumor_file_path]})
    tensor_3d_input = processed_data['image'].unsqueeze(0).to('cpu')
    with torch.no_grad():
        output = inference(tensor_3d_input)
    # Render the requested slice of the selected input channel.
    img_slice = tensor_3d_input[0][channel, :, :, slice_number]
    plt.figure()
    plt.imshow(img_slice, cmap='gray')
    plt.axis('off')
    input_image_path = f"input_img_channel{channel}.png"
    plt.savefig(input_image_path, bbox_inches='tight', pad_inches=0)
    plt.close()
    channel_image = np.asarray(Image.open(input_image_path))
    os.remove(input_image_path)
    # Render the binarized model prediction for the mapped output channel.
    output_image_path = f"output_img_channel{channel}.png"
    plt.figure()
    plt.imshow(post_trans(output[0][channel_mapping[channel], :, :, slice_number]))
    plt.axis('off')
    plt.savefig(output_image_path, bbox_inches='tight', pad_inches=0)
    plt.close()
    segment_image = np.asarray(Image.open(output_image_path))
    os.remove(output_image_path)
    # The audio question is not processed yet; return a placeholder summary.
    return (channel_image, segment_image, "Question answer")
# The theme must exist before the Interface that uses it.
theme = gr.themes.Default().set(
    button_primary_background_fill="#FF0000",
    button_primary_background_fill_dark="#AAAAAA",
)
interface = gr.Interface(
    fn=detector,
    inputs=[
        gr.File(label="Tumor File"),
        gr.Slider(0, 200, value=50, step=1, label="Slice Number"),
        gr.Radio([0, 1, 2], label="Channel"),
        gr.Audio(sources=["microphone"], type="filepath", label="Question"),
    ],
    outputs=[
        gr.Image(label="Channel"),
        gr.Image(label="Segmented Tumor"),
        gr.Textbox(label="Medical Summary"),
    ],
    title=title,
    description=description,
    examples=examples,
    theme=theme,
)
interface.launch(debug=True)
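# Usage sketch (assumptions: this file is saved as app.py, a checkpoint exists
# at models/best_metric_model_epoch_40.pth, and example volumes live under
# examples/, as referenced above):
#   pip install gradio monai torch matplotlib nibabel pillow
#   python app.py
# Then open the printed local URL in a browser.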