import gradio as gr
import torch
from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from PIL import Image
# Hugging Face model id of the LLaVA-Med checkpoint served by this app.
model_path = "microsoft/llava-med-v1.5-mistral-7b"

# Load the tokenizer, model and image processor once at import time so that
# every request reuses them. The loader's fourth return value is unused here.
# NOTE(review): `device="cpu"` is passed in addition to `device_map` because
# the LLaVA loader versions I am aware of place the vision tower on their
# `device` argument, which defaults to CUDA -- confirm against the installed
# `llava.model.builder.load_pretrained_model` signature.
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    # Same value as the previous hard-coded name, now derived from the id so
    # the two strings cannot drift apart.
    model_name=model_path.rsplit("/", 1)[-1],
    load_4bit=False,
    device_map="cpu",
    device="cpu",
)

# Keep the whole model on the CPU and in inference mode.
model.to("cpu")
model.eval()
def analyze_medical_image(image, question: str) -> str:
    """Answer a free-text question about a medical image with the loaded model.

    Args:
        image: The image to analyse. May be a ``PIL.Image.Image``, a
            filesystem path (``str``), or an array-like accepted by
            ``Image.fromarray``. ``None`` means no image was supplied.
        question: The user's question about the image.

    Returns:
        The decoded model answer with surrounding whitespace stripped, or a
        short instruction for the user when the image or question is missing.
    """
    # Validate first: an empty form should produce a readable message rather
    # than an exception from deep inside PIL or the tokenizer.
    if image is None:
        return "Please upload an image before submitting."
    if not question or not question.strip():
        return "Please enter a question about the image."

    # Normalise every accepted input to an RGB PIL image.
    if isinstance(image, str):
        # `convert` loads the pixel data, so the file can be closed right away.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    else:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # NOTE(review): grayscale/RGBA inputs are common for medical scans;
        # presumably the image processor expects 3-channel input -- confirm.
        image = image.convert("RGB")

    image_tensor = process_images([image], image_processor, model.config)[0]
    # Match the model's device and dtype so the vision tower does not receive
    # mismatched inputs.
    # NOTE(review): assumes `model` exposes `.device` / `.dtype` -- confirm.
    image_tensor = image_tensor.to(device=model.device, dtype=model.dtype)

    prompt = f"USER: <image>\n{question.strip()}\nASSISTANT:"
    input_ids = tokenizer_image_token(
        prompt,
        tokenizer,
        IMAGE_TOKEN_INDEX,
        return_tensors="pt",
    ).unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0),
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            use_cache=True,
        )

    # Strip the prompt only when the output really starts with it.
    # NOTE(review): some LLaVA versions appear to return only the newly
    # generated tokens; slicing unconditionally would then cut off the start
    # of the answer -- confirm against the installed `generate`.
    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    if generated.shape[0] >= prompt_len and torch.equal(
        generated[:prompt_len], input_ids[0]
    ):
        generated = generated[prompt_len:]

    return tokenizer.decode(generated, skip_special_tokens=True).strip()
# Build the single-page UI: inputs on the left, the model's answer on the right.
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("# LLaVA-Med Medical Image Analysis")
    gr.Markdown("Ask questions about medical images using Microsoft's LLaVA-Med 1.5-Mistral-7B")

    with gr.Row():
        # Input column: image upload, question box and the trigger button.
        with gr.Column():
            image_input = gr.Image(label="Upload Medical Image", type="pil")
            question_input = gr.Textbox(
                label="Question",
                placeholder="Ask about the medical image...",
            )
            submit_btn = gr.Button("Analyze")

        # Output column: read-only text box that receives the answer.
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", interactive=False)

    # The same two components feed both the examples and the click handler.
    form_inputs = [image_input, question_input]

    # Clickable sample rows: (image path, question).
    sample_queries = [
        ["examples/chest_xray.jpg", "What abnormalities are present in this chest X-ray?"],
        ["examples/retina_scan.jpg", "Are there any signs of diabetic retinopathy?"],
    ]
    examples = gr.Examples(
        examples=sample_queries,
        inputs=form_inputs,
        label="Example Queries",
    )

    # Run the analysis callback when the button is pressed.
    submit_btn.click(
        fn=analyze_medical_image,
        inputs=form_inputs,
        outputs=output_text,
    )

# Enable the request queue with max_size=10, then start the server.
demo.queue(max_size=10)
demo.launch()