import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import spaces
from open_clip import create_model_from_pretrained, get_tokenizer
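# Assumed dependencies for this Space (e.g. in requirements.txt): torch,
# open_clip_torch, gradio, spaces (the Hugging Face ZeroGPU helper), and pillow.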
#---------------------------------
#++++++++ Model ++++++++++
#---------------------------------
def load_biomedclip_model():
    """Loads the BiomedCLIP model, its image preprocessing transform, and its tokenizer."""
    # BiomedCLIP is distributed as an open_clip checkpoint, so it is loaded through
    # open_clip's hf-hub integration rather than the transformers Auto classes.
    biomedclip_model_name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
    model, preprocess = create_model_from_pretrained(biomedclip_model_name)
    tokenizer = get_tokenizer(biomedclip_model_name)
    model = model.cuda().eval()
    return model, preprocess, tokenizer
def compute_similarity(image, text, model, preprocess, tokenizer):
    """Computes the cosine similarity between a medical image and a text query."""
    device = next(model.parameters()).device
    with torch.no_grad():
        image_input = preprocess(image).unsqueeze(0).to(device)
        text_input = tokenizer([text], context_length=256).to(device)
        image_embeds = model.encode_image(image_input)
        text_embeds = model.encode_text(text_input)
        # L2-normalize both embeddings so their dot product is the cosine similarity.
        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)
        similarity = (text_embeds @ image_embeds.T).squeeze()
    return similarity.item()
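# Minimal standalone usage sketch (hypothetical file name and query text):
#   model, preprocess, tokenizer = load_biomedclip_model()
#   score = compute_similarity(Image.open('chest_xray.png'),
#                              'pneumonia in the right lung', model, preprocess, tokenizer)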
#---------------------------------
#++++++++ Gradio ++++++++++
#---------------------------------
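# UI flow: upload_img registers the image and unlocks the query box, gradio_ask
# echoes the question into the chat, gradio_answer fills in the similarity score,
# and gradio_reset returns everything to its initial state.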
def gradio_reset(chat_state, img_list, similarity_output):
"""Resets the chat state and image list."""
if chat_state is not None:
chat_state.messages = []
if img_list is not None:
img_list = []
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your medical image first', interactive=False), gr.update(value="Upload & Start Analysis", interactive=True), chat_state, img_list, gr.update(value="", visible=False)
def upload_img(gr_img, text_input, chat_state, similarity_output):
"""Handles image upload."""
if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None, gr.update(visible=False)
img_list = [gr_img]
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Analysis", interactive=False), chat_state, img_list, gr.update(visible=True)
def gradio_ask(user_message, chatbot, chat_state):
"""Handles user input."""
if not user_message:
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
chatbot = chatbot + [[user_message, None]]
return '', chatbot, chat_state
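# The .then() chain below runs gradio_answer after gradio_ask, so the user's
# message appears immediately and the score is filled in once computed.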
# @spaces.GPU requests a ZeroGPU worker for the duration of the call on Hugging Face Spaces.
@spaces.GPU
def gradio_answer(chatbot, chat_state, img_list, similarity_output):
    """Computes and displays the similarity score for the latest query."""
    if not img_list:
        return chatbot, chat_state, img_list, similarity_output
    # The model, preprocess transform, and tokenizer are module-level globals;
    # they are not Gradio components, so they must not be passed via event inputs.
    similarity_score = compute_similarity(img_list[0], chatbot[-1][0], biomedclip_model, biomedclip_preprocess, biomedclip_tokenizer)
    print(f'Similarity Score is: {similarity_score}')
    similarity_text = f"Similarity Score: {similarity_score:.3f}"
    chatbot[-1][1] = similarity_text
    return chatbot, chat_state, img_list, gr.update(value=similarity_text, visible=True)
title = """<h1 align="center">Medical Image Analysis Tool</h1>"""
description = """<h3>Upload medical images, ask questions, and receive a similarity score.</h3>"""
examples_list=[
["./case1.png", "Analyze the X-ray for any abnormalities."],
["./case2.jpg", "What type of disease may be present?"],
["./case1.png","What is the anatomical structure shown here?"]
]
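# The example files (./case1.png, ./case2.jpg) are assumed to ship with the Space.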
# Load the model and related resources once at startup, outside the Gradio block.
biomedclip_model, biomedclip_preprocess, biomedclip_tokenizer = load_biomedclip_model()
with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown(description)
with gr.Row():
        with gr.Column(scale=1):  # scale must be an integer in recent Gradio versions
image = gr.Image(type="pil", label="Medical Image")
upload_button = gr.Button(value="Upload & Start Analysis", interactive=True, variant="primary")
clear = gr.Button("Restart")
        with gr.Column(scale=2):  # preserve the original 1:2 column width ratio
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(label='Medical Analysis')
text_input = gr.Textbox(label='Analysis Query', placeholder='Please upload your medical image first', interactive=False)
similarity_output = gr.Textbox(label="Similarity Score", visible=False, interactive=False)
gr.Examples(examples=examples_list, inputs=[image, text_input])
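    # Clicking an example only populates the inputs; the user still has to press
    # "Upload & Start Analysis" to register the image for scoring.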
upload_button.click(upload_img, [image, text_input, chat_state, similarity_output], [image, text_input, upload_button, chat_state, img_list, similarity_output])
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, similarity_output], [chatbot, chat_state, img_list, similarity_output]
    )
clear.click(gradio_reset, [chat_state, img_list, similarity_output], [chatbot, image, text_input, upload_button, chat_state, img_list, similarity_output], queue=False)
demo.launch()