Dr-Brain / app.py
thunder-007's picture
model examples and ui improvement
13679ee
raw
history blame
3.46 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
Activations,
AsDiscrete,
Compose,
LoadImaged,
NormalizeIntensityd,
Orientationd,
EnsureChannelFirstd,
)
import torch
# Heading passed to the Gradio interface as its title.
title = 'Detect and Segment Brain Tumors 🧠'

# Description passed to the Gradio interface; currently just a newline.
description = '''
'''

# Translates the channel picked in the UI into the index of the model output
# channel that is displayed for it (0 and 1 are swapped, 2 maps to itself).
channel_mapping = {0: 1, 1: 0, 2: 2}
# Preprocessing pipeline applied to a dictionary holding the volume under the
# "image" key before it is handed to the network.
preproc_transforms = Compose([
    # Read the file(s) referenced by the "image" entry.
    LoadImaged(keys=["image"]),
    # Put the channel axis first.
    EnsureChannelFirstd(keys="image"),
    # Reorient the volume to the "RAS" axis codes.
    Orientationd(keys=["image"], axcodes="RAS"),
    # Normalise intensities channel by channel, considering non-zero values only.
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

# Post-processing for raw network output: sigmoid activation followed by
# thresholding at 0.5.
post_trans = Compose([
    Activations(sigmoid=True),
    AsDiscrete(threshold=0.5),
])
# Segmentation network: 4 input channels, 3 output channels, kept on the CPU.
model_tumor_seg = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to('cpu')

# Restore the trained weights; map_location keeps every tensor on the CPU.
model_tumor_seg.load_state_dict(
    torch.load("models/best_metric_model_epoch_40.pth", map_location='cpu')
)

# Switch to evaluation mode. Modules are created in training mode, and
# torch.no_grad() does not turn dropout off, so without this call the
# dropout_prob=0.2 layer would randomly zero activations on every prediction.
model_tumor_seg.eval()
def inference(input):
    """Run the segmentation model over ``input`` using sliding-window inference.

    The window is 240 x 240 x 160 with 50% overlap, and windows are fed to
    ``model_tumor_seg`` one at a time (``sw_batch_size=1``).

    Args:
        input: Batched input tensor forwarded to ``sliding_window_inference``.

    Returns:
        Whatever ``sliding_window_inference`` returns for the given input.
    """
    return sliding_window_inference(
        inputs=input,
        roi_size=(240, 240, 160),
        sw_batch_size=1,
        predictor=model_tumor_seg,
        overlap=0.5,
    )
# Example rows offered in the UI, in input order:
# [volume path, slice number, channel].
examples = [
    ['examples/BRATS_225.nii.gz', 83, 2],
    ['examples/BRATS_485.nii.gz', 90, 1],
    ['examples/BRATS_485.nii.gz', 110, 0],
]
def _render_to_array(slice_2d, **imshow_kwargs):
    """Draw ``slice_2d`` with matplotlib and return the rendered PNG as an array.

    A dedicated figure is created and closed for every call so that images
    never accumulate on the shared pyplot figure, and the PNG is written to an
    in-memory buffer instead of a fixed filename so that concurrent requests
    cannot overwrite or delete each other's files.
    """
    import io

    buffer = io.BytesIO()
    fig, ax = plt.subplots()
    try:
        ax.imshow(slice_2d, **imshow_kwargs)
        ax.axis('off')
        fig.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    buffer.seek(0)
    return np.asarray(Image.open(buffer))


def detector(tumor_file, slice_number, channel, audio_question):
    """Segment the uploaded volume and render one slice of input and prediction.

    Args:
        tumor_file: Uploaded file object; its ``name`` attribute is used as the
            path of the volume to load.
        slice_number: Index along the last axis of the volume to display.
        channel: Input channel to display; also selects the model output
            channel through ``channel_mapping``.
        audio_question: Accepted for the UI's audio input; not used here.

    Returns:
        A tuple ``(channel_image, segment_image, text)`` where the two images
        are arrays of the rendered input slice and of the thresholded
        prediction slice, and ``text`` is the placeholder "Question answer".

    Raises:
        gr.Error: If no file was uploaded, no valid channel was selected, or
            ``slice_number`` lies outside the loaded volume.
    """
    if tumor_file is None:
        raise gr.Error("Please upload a tumor file.")
    if channel not in channel_mapping:
        raise gr.Error("Please select a channel (0, 1 or 2).")

    processed_data = preproc_transforms({'image': [tumor_file.name]})
    # Add the batch axis expected by the sliding-window inference.
    tensor_3d_input = processed_data['image'].unsqueeze(0).to('cpu')

    # The slider allows values beyond the depth of a given volume, so check
    # the index before spending time on inference.
    slice_index = int(slice_number)
    depth = tensor_3d_input.shape[-1]
    if not 0 <= slice_index < depth:
        raise gr.Error(f"Slice number must be between 0 and {depth - 1}.")

    with torch.no_grad():
        output = inference(tensor_3d_input)

    channel_image = _render_to_array(
        tensor_3d_input[0][channel, :, :, slice_index], cmap='gray'
    )
    segment_image = _render_to_array(
        post_trans(output[0][channel_mapping[channel], :, :, slice_index])
    )
    return (channel_image, segment_image, "Question answer")
# Assemble the UI: a file upload, slice slider, channel picker and microphone
# input feed `detector`; two images and a text box show what it returns.
interface = gr.Interface(
    fn=detector,
    inputs=[
        gr.File(label="Tumor File"),
        gr.Slider(0, 200, 50, step=1, label="Slice Number"),
        gr.Radio([0, 1, 2], label="Channel"),
        gr.Audio(source="microphone"),
    ],
    outputs=[
        gr.Image(label='channel', shape=(1, 1)),
        gr.Image(label='Segmented Tumor', shape=(1, 1)),
        gr.Textbox(label="Medical Summary"),
    ],
    title=title,
    examples=examples,
    description=description,
    outputs_layout="row",
    theme='dark',
)

# NOTE(review): this theme object is built after the interface and is not
# used in the visible code — confirm whether it was meant to replace 'dark'.
theme = gr.themes.Default().set(
    button_primary_background_fill="#FF0000",
    button_primary_background_fill_dark="#AAAAAA",
)

# Start the app with debug mode enabled.
interface.launch(debug=True)