File size: 3,459 Bytes
85a1b63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13679ee
 
85a1b63
 
 
13679ee
85a1b63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13679ee
 
 
 
85a1b63
 
 
13679ee
85a1b63
 
13679ee
 
 
 
 
85a1b63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
)
from PIL import Image

title = 'Detect and Segment Brain Tumors 🧠'
description = '''
'''

# Maps the input channel selected in the UI to the index of the model output
# channel whose segmentation is displayed for it (used by `detector`).
channel_mapping = {
    0: 1,
    1: 0,
    2: 2,
}

# Per-volume preprocessing: load the file, move channels first, reorient to RAS,
# and normalize intensity per channel using only non-zero voxels.
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
# Converts raw model outputs into a binary mask: sigmoid, then threshold at 0.5.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# SegResNet with 4 input channels and 3 output channels, kept on the CPU.
model_tumor_seg = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to('cpu')

model_tumor_seg.load_state_dict(torch.load("models/best_metric_model_epoch_40.pth", map_location='cpu'))
# The network was built with dropout_prob=0.2. Switch to evaluation mode so dropout
# is disabled during inference; torch.no_grad() alone does not do that, and
# predictions would otherwise be randomly perturbed on every call.
model_tumor_seg.eval()


def inference(input, *, roi_size=(240, 240, 160), sw_batch_size=1, overlap=0.5):
    """Run sliding-window inference of ``model_tumor_seg`` over a batched volume.

    The keyword-only arguments are forwarded unchanged to MONAI's
    ``sliding_window_inference``. Their defaults reproduce the window settings
    this app has always used.

    Args:
        input: batched input tensor. ``detector`` builds it with ``unsqueeze(0)``,
            presumably as (batch, channels, H, W, D); TODO confirm.
            The name shadows the builtin but is kept for caller compatibility.
        roi_size: spatial size of each inference window.
        sw_batch_size: number of windows evaluated per forward pass.
        overlap: overlap fraction between neighbouring windows.

    Returns:
        The stitched model output returned by ``sliding_window_inference``.
    """
    return sliding_window_inference(
        inputs=input,
        roi_size=roi_size,
        sw_batch_size=sw_batch_size,
        predictor=model_tumor_seg,
        overlap=overlap,
    )


# Example rows for the Interface: [volume path, slice number, channel].
# Only the first three inputs receive a value; none is given for the audio input.
examples = [
    [f"examples/{case_id}.nii.gz", slice_idx, chan]
    for case_id, slice_idx, chan in (
        ("BRATS_225", 83, 2),
        ("BRATS_485", 90, 1),
        ("BRATS_485", 110, 0),
    )
]


def detector(tumor_file, slice_number, channel, audio_question):
    tumor_file_path = tumor_file.name
    processed_data = preproc_transforms({'image': [tumor_file_path]})
    tensor_3d_input = processed_data['image'].unsqueeze(0).to('cpu')
    with torch.no_grad():
        output = inference(tensor_3d_input)
    img_slice = tensor_3d_input[0][channel, :, :, slice_number]
    plt.imshow(img_slice, cmap='gray')
    input_image_path = f"input_img_channel{channel}.png"
    plt.axis('off')
    plt.savefig(input_image_path, bbox_inches='tight', pad_inches=0)
    channel_image = np.asarray(Image.open(input_image_path))
    os.remove(input_image_path)
    output_image_path = f"ouput_img_channel{channel}.png"
    plt.imshow(post_trans(output[0][channel_mapping[channel], :, :, slice_number]))
    plt.axis('off')
    plt.savefig(output_image_path, bbox_inches='tight', pad_inches=0)
    segment_image = np.asarray(Image.open(output_image_path))
    os.remove(output_image_path)
    return (channel_image, segment_image, "Question answer")


# Custom theme: red primary buttons, grey in dark mode. It is built before the
# Interface so it can be passed in via `theme=`.
theme = gr.themes.Default().set(
    button_primary_background_fill="#FF0000",
    button_primary_background_fill_dark="#AAAAAA",
)

interface = gr.Interface(
    fn=detector,
    inputs=[
        gr.File(label="Tumor File"),
        gr.Slider(0, 200, 50, step=1, label="Slice Number"),
        gr.Radio([0, 1, 2], label="Channel"),
        gr.Audio(source="microphone"),
    ],
    outputs=[
        gr.Image(label='channel', shape=(1, 1)),
        gr.Image(label='Segmented Tumor', shape=(1, 1)),
        gr.Textbox(label="Medical Summary"),
    ],
    title=title,
    examples=examples,
    description=description,
    theme=theme,
)

interface.launch(debug=True)