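"""Dr-Brain: Gradio app for brain-tumor segmentation (MONAI SegResNet) with
spoken (Whisper) and typed question answering via OpenAI GPT-3.5."""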
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import openai
from dotenv import load_dotenv
import os
from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.transforms import (
Activations,
AsDiscrete,
Compose,
LoadImaged,
NormalizeIntensityd,
Orientationd,
EnsureChannelFirstd,
)
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa
import torch
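# Load environment variables (OPENAI_KEY is read from a local .env file).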
load_dotenv()
title = 'Dr Brain Tumors 🧠'
description = '''
'''
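# Map the selected display channel to the corresponding model output channel
# (output channels 0 and 1 are swapped relative to the input selector).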
channel_mapping = {
0: 1,
1: 0,
2: 2,
}
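# MONAI preprocessing: load the NIfTI volume, ensure channel-first layout,
# reorient to RAS, and normalize intensities channel-wise over nonzero voxels.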
preproc_transforms = Compose(
[
LoadImaged(keys=["image"]),
EnsureChannelFirstd(keys="image"),
Orientationd(keys=["image"], axcodes="RAS"),
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
]
)
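# Post-processing: sigmoid activations thresholded at 0.5 to get binary masks.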
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
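# Whisper-tiny handles multilingual speech-to-text for spoken questions.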
processor_whisper = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model_whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
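# SegResNet for BraTS-style tumor segmentation: 4 MRI modalities in,
# 3 tumor sub-region channels out.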
model_tumor_seg = SegResNet(
blocks_down=[1, 2, 2, 4],
blocks_up=[1, 1, 1],
init_filters=16,
in_channels=4,
out_channels=3,
dropout_prob=0.2,
).to('cpu')
model_tumor_seg.load_state_dict(torch.load("models/best_metric_model_epoch_40.pth", map_location='cpu'))
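# Sliding-window inference lets full 3-D volumes larger than the ROI run
# within CPU memory.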
def inference(input_tensor):
    return sliding_window_inference(
        inputs=input_tensor,
        roi_size=(240, 240, 160),
        sw_batch_size=1,
        predictor=model_tumor_seg,
        overlap=0.5,
    )
examples = [
['examples/BRATS_225.nii.gz', 83, 2, 'english', 'examples/sample1_en.mp3'],
['examples/BRATS_485.nii.gz', 90, 1, 'japanese', 'examples/sample2_jp.mp3'],
['examples/BRATS_485.nii.gz', 110, 0, 'german', 'examples/sample3_gr.mp3'],
]
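# Normalize Gradio's int16 microphone audio to a mono 16 kHz float waveform,
# capped at Whisper's 30-second context window.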
def process_audio(sampling_rate, waveform):
    waveform = waveform / 32768.0  # int16 PCM -> float in [-1, 1]
if len(waveform.shape) > 1:
waveform = librosa.to_mono(waveform.T)
if sampling_rate != 16000:
waveform = librosa.resample(waveform, orig_sr=sampling_rate, target_sr=16000)
waveform = waveform[:16000 * 30]
waveform = torch.tensor(waveform)
return waveform
openai.api_key = os.environ.get("OPENAI_KEY")
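# Query GPT-3.5 with a system prompt identifying the app and its developers.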
def make_llm_call(prompt,
                  context="You are DR-Brain, a text generation model developed by Team Brute Force, "
                          "a team of 4 AI engineers from RMKCET college: HARSHA VARDHAN V, SAWIN KUMAR Y, "
                          "CHARAN TEJA P, and KISHORE S. You specialize in medical topics."):
    messages = [{"role": "user", "content": prompt}]
    if context:
        messages.insert(0, {"role": "system", "content": context})
    response_obj = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
    return response_obj["choices"][0]["message"]["content"]
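# Main Gradio handler: answer the user's question (text takes precedence over
# audio), then segment the uploaded volume and render the requested slice.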
def detector(tumor_file, slice_number, channel, language, audio_question, text_question):
llm_answer = "Hi I'm Dr brain please enter a question to answer"
if text_question:
llm_answer = make_llm_call(text_question)
elif audio_question:
sampling_rate, waveform = audio_question
forced_decoder_ids = processor_whisper.get_decoder_prompt_ids(language=language, task="transcribe")
waveform = process_audio(sampling_rate, waveform)
audio_inputs = processor_whisper(audio=waveform, sampling_rate=16000, return_tensors="pt")
predicted_ids = model_whisper.generate(**audio_inputs, max_length=400, forced_decoder_ids=forced_decoder_ids)
transcription = processor_whisper.batch_decode(predicted_ids, skip_special_tokens=True)
        llm_question = transcription[0]
        llm_answer = make_llm_call(llm_question)
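    # Segment the uploaded NIfTI volume and extract the requested 2-D slice.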
tumor_file_path = tumor_file.name
processed_data = preproc_transforms({'image': [tumor_file_path]})
tensor_3d_input = processed_data['image'].unsqueeze(0).to('cpu')
with torch.no_grad():
output = inference(tensor_3d_input)
img_slice = tensor_3d_input[0][channel, :, :, slice_number]
plt.imshow(img_slice, cmap='gray')
input_image_path = f"input_img_channel{channel}.png"
plt.axis('off')
plt.savefig(input_image_path, bbox_inches='tight', pad_inches=0)
channel_image = np.asarray(Image.open(input_image_path))
os.remove(input_image_path)
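    # Render the binarized segmentation mask for the same slice and channel.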
output_image_path = f"ouput_img_channel{channel}.png"
plt.imshow(post_trans(output[0][channel_mapping[channel], :, :, slice_number]))
plt.axis('off')
plt.savefig(output_image_path, bbox_inches='tight', pad_inches=0)
segment_image = np.asarray(Image.open(output_image_path))
os.remove(output_image_path)
return (channel_image, segment_image, llm_answer)
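# Gradio UI: NIfTI upload, slice/channel selectors, language picker, microphone
# and text question inputs, with a custom button theme.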
theme = gr.themes.Default().set(
    button_primary_background_fill="#FF0000",
    button_primary_background_fill_dark="#AAAAAA",
)
interface = gr.Interface(fn=detector,
                         inputs=[gr.File(label="Tumor File"),
                                 gr.Slider(0, 200, 50, step=1, label="Slice Number"),
                                 gr.Radio((0, 1, 2), label="Channel"),
                                 gr.Radio(("english", "japanese", "german", "spanish"), label="Language"),
                                 gr.Audio(source="microphone", label="Audio Question"),
                                 gr.Textbox(label="Text Question")],
                         outputs=[gr.Image(label="Channel Slice"),
                                  gr.Image(label="Segmented Tumor"),
                                  gr.Textbox(label="Dr Brain Response")],
                         title=title,
                         description=description,
                         examples=examples,
                         theme=theme)
interface.launch(debug=True)