File size: 5,256 Bytes
2f9ea03
 
 
 
73f90ac
2f9ea03
 
 
 
 
 
 
 
 
 
73f90ac
2f9ea03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a85b9e3
2f9ea03
 
 
 
 
a85b9e3
2f9ea03
a85b9e3
2f9ea03
 
a85b9e3
2f9ea03
a85b9e3
2f9ea03
 
 
 
 
 
 
 
 
a85b9e3
 
2f9ea03
a85b9e3
 
 
2f9ea03
a85b9e3
 
 
 
 
2f9ea03
 
a85b9e3
2f9ea03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a85b9e3
2f9ea03
 
a85b9e3
2f9ea03
 
a85b9e3
2f9ea03
a85b9e3
2f9ea03
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoProcessor, AutoModelForImageTextRetrieval
import torch.nn.functional as F

#---------------------------------
#++++++++     Model     ++++++++++
#---------------------------------

def load_biomedclip_model():
    """Instantiate the BiomedCLIP checkpoint and its paired processor.

    Returns:
        tuple: ``(model, processor)`` where the model has been moved to CUDA
        and switched to eval mode for inference.
    """
    checkpoint = 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
    # NOTE(review): .cuda() executes at load time; on ZeroGPU Spaces CUDA is
    # normally only usable inside @spaces.GPU-decorated functions (this file
    # uses @spaces.GPU on gradio_answer) — confirm startup works on the target.
    model = AutoModelForImageTextRetrieval.from_pretrained(checkpoint).cuda().eval()
    processor = AutoProcessor.from_pretrained(checkpoint)
    return model, processor

def compute_similarity(image, text, biomedclip_model, biomedclip_processor):
    """Return cosine-similarity score(s) between *image* and *text* embeddings.

    The processor batches the inputs, the model embeds them, and the score is
    the inner product of the L2-normalized text and image embeddings,
    squeezed to drop singleton dimensions (a scalar tensor for one pair).
    """
    with torch.no_grad():
        batch = biomedclip_processor(
            text=text, images=image, return_tensors="pt", padding=True
        ).to(biomedclip_model.device)
        outputs = biomedclip_model(**batch)
    # Normalize so the dot product below is a cosine similarity.
    img_vec = F.normalize(outputs.image_embeds, dim=-1)
    txt_vec = F.normalize(outputs.text_embeds, dim=-1)
    return torch.matmul(txt_vec, img_vec.transpose(-1, -2)).squeeze()

#---------------------------------
#++++++++     Gradio     ++++++++++
#---------------------------------

def gradio_reset(chat_state, img_list, similarity_output):
    """Clear conversation state and return every widget to its initial setup.

    Empties the chat history and image list (when present) and emits updates
    that re-enable the image input, lock the query box, restore the upload
    button label, and hide the similarity readout.
    """
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return (
        None,                                       # chatbot: wipe transcript
        gr.update(value=None, interactive=True),    # image: clear + re-enable
        gr.update(placeholder='Please upload your medical image first', interactive=False),
        gr.update(value="Upload & Start Analysis", interactive=True),
        chat_state,
        img_list,
        gr.update(value="", visible=False),         # similarity box: hide
    )

def upload_img(gr_img, text_input, chat_state, similarity_output):
    """Register the uploaded image and unlock the query textbox."""
    if gr_img is None:
        # Nothing was uploaded: keep the UI in its pre-upload state.
        return None, None, gr.update(interactive=True), chat_state, None, gr.update(visible=False)
    return (
        gr.update(interactive=False),   # freeze the image input
        gr.update(interactive=True, placeholder='Type and press Enter'),
        gr.update(value="Start Analysis", interactive=False),
        chat_state,
        [gr_img],                       # single-element image list for later scoring
        gr.update(visible=True),
    )

def gradio_ask(user_message, chatbot, chat_state):
    """Append the user's query as a new chat turn, rejecting empty input."""
    if user_message:
        # New turn with the answer slot left empty for gradio_answer to fill.
        return '', chatbot + [[user_message, None]], chat_state
    # Empty query: re-arm the textbox with a hint instead of posting a turn.
    return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state

@spaces.GPU
def gradio_answer(chatbot, chat_state, img_list, biomedclip_model, biomedclip_processor, similarity_output):
    """Score the latest query against the uploaded image and post the result."""
    if not img_list:
        # No image uploaded yet — nothing to score.
        return chatbot, chat_state, img_list, similarity_output

    query = chatbot[-1][0]
    score = compute_similarity(img_list[0], query, biomedclip_model, biomedclip_processor)
    print(f'Similarity Score is: {score}')

    label = f"Similarity Score: {score:.3f}"
    chatbot[-1][1] = label  # fill the pending answer slot of the last turn
    return chatbot, chat_state, img_list, gr.update(value=label, visible=True)


# Static UI copy rendered at the top of the page.
title = """<h1 align="center">Medical Image Analysis Tool</h1>"""
description = """<h3>Upload medical images, ask questions, and receive a similarity score.</h3>"""
# (image path, example query) pairs surfaced by the gr.Examples widget below.
examples_list=[
                    ["./case1.png", "Analyze the X-ray for any abnormalities."],
                    ["./case2.jpg", "What type of disease may be present?"],
                    ["./case1.png","What is the anatomical structure shown here?"]
                ]

# Load models and related resources outside of the Gradio block for loading on startup
# NOTE(review): load_biomedclip_model() calls .cuda() here at import time; on
# ZeroGPU Spaces CUDA is typically only available inside @spaces.GPU functions
# — confirm startup succeeds on the deployment target.
biomedclip_model, biomedclip_processor = load_biomedclip_model()

# Declarative UI layout and event wiring.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: image upload plus the two control buttons.
        # NOTE(review): newer Gradio versions expect an integer `scale` —
        # verify 0.5 is accepted by the pinned Gradio version.
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil", label="Medical Image")
            upload_button = gr.Button(value="Upload & Start Analysis", interactive=True, variant="primary")
            clear = gr.Button("Restart")

        # Right column: chat transcript, query box, and score readout.
        with gr.Column():
            chat_state = gr.State()  # conversation state threaded through every callback
            img_list = gr.State()    # holds the single uploaded PIL image (see upload_img)
            chatbot = gr.Chatbot(label='Medical Analysis')
            text_input = gr.Textbox(label='Analysis Query', placeholder='Please upload your medical image first', interactive=False)
            similarity_output = gr.Textbox(label="Similarity Score", visible=False, interactive=False)
            gr.Examples(examples=examples_list, inputs=[image, text_input])

    # Upload flow: lock the image, unlock the query box.
    upload_button.click(upload_img, [image, text_input, chat_state, similarity_output], [image, text_input, upload_button, chat_state, img_list, similarity_output])

    # Submit flow: record the question, then score it against the image.
    # NOTE(review): biomedclip_model / biomedclip_processor are plain Python
    # objects, not Gradio components — Gradio normally rejects non-components
    # in an `inputs` list. Verify this wiring, or have gradio_answer read the
    # module-level globals instead.
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, biomedclip_model, biomedclip_processor, similarity_output], [chatbot, chat_state, img_list, similarity_output]
    )
    clear.click(gradio_reset, [chat_state, img_list, similarity_output], [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output], queue=False)

demo.launch()