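"""Gradio Space: detect and segment brain tumors in BraTS-style multi-modal
MRI volumes with a MONAI SegResNet, served through a simple web interface."""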
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
)
import torch
title = 'Detect and Segment Brain Tumors 🧠'
description = '''
'''
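# Maps the UI channel choice (an input modality index) to the model output
# channel to display; the output channels are presumably the BraTS tumor
# sub-regions (TC / WT / ET), as in the MONAI BraTS setup.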
channel_mapping = {
    0: 1,
    1: 0,
    2: 2,
}
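# Preprocessing: load the NIfTI volume, move channels first, reorient to RAS,
# and normalize intensities per channel over non-zero voxels.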
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
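# Post-processing: sigmoid activation, then binarize at a 0.5 threshold.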
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
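# SegResNet configured for 4 input MRI modalities and 3 output tumor
# sub-region channels, matching the MONAI BraTS segmentation tutorial.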
model_tumor_seg = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to('cpu')
model_tumor_seg.load_state_dict(torch.load("models/best_metric_model_epoch_40.pth", map_location='cpu'))
model_tumor_seg.eval()  # disable dropout for deterministic inference
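# Whole-volume inference via sliding windows, so the full 3D scan does not
# have to pass through the network in one shot.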
def inference(input):
    return sliding_window_inference(
        inputs=input,
        roi_size=(240, 240, 160),
        sw_batch_size=1,
        predictor=model_tumor_seg,
        overlap=0.5,
    )
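# Preset example inputs; one entry per interface input (file, slice, channel,
# audio), with no recorded audio for the examples.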
examples = [
    ['examples/BRATS_225.nii.gz', 83, 2, None],
    ['examples/BRATS_485.nii.gz', 90, 1, None],
    ['examples/BRATS_485.nii.gz', 110, 0, None],
]
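# Main handler: preprocess the uploaded volume, run inference, and render the
# chosen slice of both the input modality and the binarized segmentation.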
def detector(tumor_file, slice_number, channel, audio_question):
    tumor_file_path = tumor_file.name
    processed_data = preproc_transforms({'image': [tumor_file_path]})
    tensor_3d_input = processed_data['image'].unsqueeze(0).to('cpu')
    with torch.no_grad():
        output = inference(tensor_3d_input)
    # Render the selected slice of the chosen input modality.
    plt.clf()
    img_slice = tensor_3d_input[0][channel, :, :, slice_number]
    plt.imshow(img_slice, cmap='gray')
    plt.axis('off')
    input_image_path = f"input_img_channel{channel}.png"
    plt.savefig(input_image_path, bbox_inches='tight', pad_inches=0)
    channel_image = np.asarray(Image.open(input_image_path))
    os.remove(input_image_path)
    # Render the binarized segmentation for the mapped output channel.
    plt.clf()
    output_image_path = f"output_img_channel{channel}.png"
    plt.imshow(post_trans(output[0][channel_mapping[channel], :, :, slice_number]))
    plt.axis('off')
    plt.savefig(output_image_path, bbox_inches='tight', pad_inches=0)
    segment_image = np.asarray(Image.open(output_image_path))
    os.remove(output_image_path)
    # The audio question is not processed yet; a placeholder answer is returned.
    return (channel_image, segment_image, "Question answer")
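# Gradio UI: file upload, slice/channel selectors, and a microphone input for
# a spoken question; outputs are the input slice, the segmentation, and text.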
theme = gr.themes.Default().set(
    button_primary_background_fill="#FF0000",
    button_primary_background_fill_dark="#AAAAAA",
)
interface = gr.Interface(
    fn=detector,
    inputs=[
        gr.File(label="Tumor File"),
        gr.Slider(0, 200, 50, step=1, label="Slice Number"),
        gr.Radio([0, 1, 2], label="Channel"),
        gr.Audio(source="microphone", label="Audio Question"),
    ],
    outputs=[
        gr.Image(label='Channel'),
        gr.Image(label='Segmented Tumor'),
        gr.Textbox(label="Medical Summary"),
    ],
    title=title,
    description=description,
    examples=examples,
    theme=theme,
)
interface.launch(debug=True)