File size: 5,214 Bytes
2f9ea03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a85b9e3
2f9ea03
 
 
 
 
a85b9e3
2f9ea03
a85b9e3
2f9ea03
 
a85b9e3
2f9ea03
a85b9e3
2f9ea03
 
 
 
 
 
 
 
 
a85b9e3
 
2f9ea03
a85b9e3
 
 
2f9ea03
a85b9e3
 
 
 
 
2f9ea03
 
a85b9e3
2f9ea03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a85b9e3
2f9ea03
 
a85b9e3
2f9ea03
 
a85b9e3
2f9ea03
a85b9e3
2f9ea03
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoProcessor, AutoModel
import torch.nn.functional as F

#---------------------------------
#++++++++     Model     ++++++++++
#---------------------------------

def load_biomedclip_model(
    model_name='microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
    device='cuda',
):
    """Load the BiomedCLIP model and its processor.

    Args:
        model_name: Hub id (or local path) handed to both
            ``AutoProcessor.from_pretrained`` and ``AutoModel.from_pretrained``.
            Defaults to the BiomedCLIP checkpoint the app was written for.
        device: Device the model is moved to. Defaults to ``'cuda'``, which is
            equivalent to the previous hard-coded ``.cuda()`` call.

    Returns:
        A ``(model, processor)`` tuple. The model is switched to eval mode.
        Note that a processor (not just a tokenizer) is returned.
    """
    # NOTE(review): confirm this checkpoint loads through AutoModel/AutoProcessor;
    # it is published for open_clip and may need a different loader.
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device).eval()
    return model, processor

def compute_similarity(image, text, biomedclip_model, biomedclip_processor):
    """Return the cosine similarity between text and image embeddings.

    Both inputs go through ``biomedclip_processor`` in a single call. The
    batch is moved to the model's device, and the model runs without
    gradient tracking. The embeddings are then L2-normalized, so the
    matrix product of text and image embeddings is cosine similarity.

    Returns:
        The similarity tensor with all size-1 dimensions squeezed away. One
        text and one image yield a 0-d tensor. This assumes the model
        returns ``(batch, dim)`` embeddings; TODO confirm for other checkpoints.
    """
    with torch.no_grad():
        batch = biomedclip_processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding=True,
        ).to(biomedclip_model.device)
        result = biomedclip_model(**batch)

    # Unit-length vectors make the dot product below a cosine similarity.
    img_vec = F.normalize(result.image_embeds, dim=-1)
    txt_vec = F.normalize(result.text_embeds, dim=-1)

    scores = torch.matmul(txt_vec, img_vec.transpose(-1, -2))
    return scores.squeeze()

#---------------------------------
#++++++++     Gradio     ++++++++++
#---------------------------------

def gradio_reset(chat_state, img_list, similarity_output):
    """Return the UI to its initial, pre-upload state.

    Clears ``chat_state.messages`` when a chat state exists. A non-None
    ``img_list`` is replaced with a fresh empty list; ``None`` stays ``None``.
    ``similarity_output`` is accepted for the event wiring but is not read.

    Returns:
        A 7-tuple of values/updates for (chatbot, image, text_input,
        upload_button, chat_state, img_list, similarity_output).
    """
    if chat_state is not None:
        chat_state.messages = []

    fresh_images = None if img_list is None else []

    cleared_chat = None
    image_update = gr.update(value=None, interactive=True)
    text_update = gr.update(placeholder='Please upload your medical image first', interactive=False)
    button_update = gr.update(value="Upload & Start Analysis", interactive=True)
    score_update = gr.update(value="", visible=False)

    return (
        cleared_chat,
        image_update,
        text_update,
        button_update,
        chat_state,
        fresh_images,
        score_update,
    )

def upload_img(gr_img, text_input, chat_state, similarity_output):
    """Lock in the uploaded image and enable the query box.

    ``text_input`` and ``similarity_output`` are accepted for the event
    wiring but are not read.

    Returns:
        A 6-tuple for (image, text_input, upload_button, chat_state,
        img_list, similarity_output).
        - With an image, the image widget and upload button are locked and
          the text box is enabled. The image is stored as a one-element
          list, and the score box is made visible.
        - Without an image, the image and text box values are set to
          ``None``, the upload button is re-enabled, ``img_list`` is
          ``None``, and the score box is hidden.
    """
    if gr_img is not None:
        return (
            gr.update(interactive=False),
            gr.update(interactive=True, placeholder='Type and press Enter'),
            gr.update(value="Start Analysis", interactive=False),
            chat_state,
            [gr_img],
            gr.update(visible=True),
        )
    return (
        None,
        None,
        gr.update(interactive=True),
        chat_state,
        None,
        gr.update(visible=False),
    )

def gradio_ask(user_message, chatbot, chat_state):
    """Append the user's query to the chat history as an unanswered turn.

    Args:
        user_message: Text from the query box.
        chatbot: Current chat history as a list of ``[user, bot]`` pairs;
            ``None`` is treated as an empty history.
        chat_state: Passed through unchanged.

    Returns:
        ``('', new_history, chat_state)`` on success. ``new_history`` is a
        new list with ``[user_message, None]`` appended, so the input list is
        not mutated. For empty or whitespace-only input, returns an update
        that keeps the box editable with a warning placeholder, plus the
        unchanged history.
    """
    # Whitespace-only text is not a meaningful query; reject it like "".
    if not user_message or not user_message.strip():
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    history = (chatbot or []) + [[user_message, None]]
    return '', history, chat_state

@spaces.GPU
def gradio_answer(chatbot, chat_state, img_list, biomedclip_model, biomedclip_processor, similarity_output):
    """Score the latest query against the uploaded image and show the result.

    Args:
        chatbot: Chat history as a list of ``[user, bot]`` pairs; the last
            pair is expected to be the unanswered query.
        chat_state: Passed through unchanged.
        img_list: List whose first element is the uploaded image, or falsy
            when nothing has been uploaded.
        biomedclip_model: Model passed to ``compute_similarity``.
        biomedclip_processor: Processor passed to ``compute_similarity``.
        similarity_output: Current score-box value; returned unchanged when
            there is nothing to score.

    Returns:
        ``(chatbot, chat_state, img_list, similarity_update)``. When a score
        is computed, the last chat turn is filled in place and the score
        box is made visible with the same text.
    """
    # Skip when there is no image, no turn at all, or the last turn is
    # already answered. The last two cases arise because the submit chain
    # runs this step even after gradio_ask rejected empty input and left the
    # history unchanged. The history may then be empty (previously an
    # IndexError) or end with an answered turn (previously recomputed).
    if not img_list or not chatbot or chatbot[-1][1] is not None:
        return chatbot, chat_state, img_list, similarity_output

    query = chatbot[-1][0]
    similarity_score = compute_similarity(img_list[0], query, biomedclip_model, biomedclip_processor)
    print(f'Similarity Score is: {similarity_score}')

    # float() accepts any one-element tensor. Formatting a tensor with
    # ":.3f" directly only works for 0-d tensors.
    similarity_text = f"Similarity Score: {float(similarity_score):.3f}"
    chatbot[-1][1] = similarity_text
    return chatbot, chat_state, img_list, gr.update(value=similarity_text, visible=True)


title = """<h1 align="center">Medical Image Analysis Tool</h1>"""
description = """<h3>Upload medical images, ask questions, and receive a similarity score.</h3>"""
examples_list = [
    ["./case1.png", "Analyze the X-ray for any abnormalities."],
    ["./case2.jpg", "What type of disease may be present?"],
    ["./case1.png", "What is the anatomical structure shown here?"],
]

# Load models and related resources outside of the Gradio block for loading on startup
biomedclip_model, biomedclip_processor = load_biomedclip_model()


def answer_with_biomedclip(chatbot, chat_state, img_list, similarity_output):
    """Call ``gradio_answer`` with the module-level BiomedCLIP model and processor.

    Gradio event ``inputs`` must be UI components, so the model and processor
    cannot be listed there. This adapter takes only the component values and
    binds the loaded model and processor itself.
    """
    return gradio_answer(
        chatbot,
        chat_state,
        img_list,
        biomedclip_model,
        biomedclip_processor,
        similarity_output,
    )


with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Integer scales 1:2 keep the original 0.5:1 width ratio; Gradio
        # expects an integer ``scale``.
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Medical Image")
            upload_button = gr.Button(value="Upload & Start Analysis", interactive=True, variant="primary")
            clear = gr.Button("Restart")

        with gr.Column(scale=2):
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='Medical Analysis')
            text_input = gr.Textbox(label='Analysis Query', placeholder='Please upload your medical image first', interactive=False)
            similarity_output = gr.Textbox(label="Similarity Score", visible=False, interactive=False)
            gr.Examples(examples=examples_list, inputs=[image, text_input])

    upload_button.click(
        upload_img,
        [image, text_input, chat_state, similarity_output],
        [image, text_input, upload_button, chat_state, img_list, similarity_output],
    )

    text_input.submit(
        gradio_ask,
        [text_input, chatbot, chat_state],
        [text_input, chatbot, chat_state],
    ).then(
        answer_with_biomedclip,
        [chatbot, chat_state, img_list, similarity_output],
        [chatbot, chat_state, img_list, similarity_output],
    )
    clear.click(
        gradio_reset,
        [chat_state, img_list, similarity_output],
        [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output],
        queue=False,
    )

demo.launch()