LLaVA-Med / app.py
ayyuce's picture
Create app.py
e41d0cb verified
raw
history blame
2.49 kB
import gradio as gr
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX
import torch
from PIL import Image
# Hugging Face Hub identifier of the checkpoint served by this app.
model_path = "microsoft/llava-med-v1.5-mistral-7b"

# Load everything once at import time so each request reuses the same objects.
# The loader returns four values; the fourth is not used by this app.
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name="llava-med-v1.5-mistral-7b",
    load_4bit=False,  # Disable 4-bit quantization for CPU
    device_map="cpu",  # Force CPU usage
    # NOTE(review): the LLaVA builder also takes a separate `device` argument
    # (default "cuda") that it uses when placing the vision tower/projector;
    # without it a CPU-only host is expected to fail during loading. Confirm
    # against the installed llava version.
    device="cpu",
)
# Explicitly keep the language model on the CPU as well.
model.to('cpu')
def analyze_medical_image(image, question):
    """Answer a free-text question about a single image with the loaded model.

    Args:
        image: A filesystem path (``str``), a PIL image, or an array accepted
            by ``Image.fromarray``. ``None`` (nothing supplied) is tolerated.
        question: The question text inserted into the prompt.

    Returns:
        The decoded model answer with surrounding whitespace stripped, or a
        short instruction string when no image was supplied.
    """
    if image is None:
        # Nothing to analyse; answer with a hint instead of raising deep
        # inside the image-processing code.
        return "Please upload an image before asking a question."

    # Normalise every supported input to an RGB PIL image. Medical images are
    # frequently grayscale or palette based, so the mode is made explicit.
    if isinstance(image, str):
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, Image.Image):
        image = image.convert("RGB")
    else:
        image = Image.fromarray(image).convert("RGB")

    # process_images works on a batch; take the single result and align it
    # with the model's device and dtype.
    image_tensor = process_images([image], image_processor, model.config)[0]
    image_tensor = image_tensor.to(device=model.device, dtype=model.dtype)

    prompt = f"USER: <image>\n{question}\nASSISTANT:"
    input_ids = tokenizer_image_token(
        prompt,
        tokenizer,
        IMAGE_TOKEN_INDEX,
        return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0),
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            use_cache=True
        )

    # NOTE(review): depending on the installed LLaVA version, generate() may
    # return prompt + continuation or only the continuation. Strip the prompt
    # only when the output really starts with it, so the answer is never
    # truncated by a blind slice.
    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    if generated.shape[0] >= prompt_len and torch.equal(
        generated[:prompt_len], input_ids[0]
    ):
        generated = generated[prompt_len:]

    response = tokenizer.decode(generated, skip_special_tokens=True).strip()
    return response
# Sample (image path, question) pairs offered to the user as one-click examples.
example_queries = [
    ["examples/chest_xray.jpg", "What abnormalities are present in this chest X-ray?"],
    ["examples/retina_scan.jpg", "Are there any signs of diabetic retinopathy?"],
]

with gr.Blocks() as demo:
    # Page title and short description.
    gr.Markdown("# LLaVA-Med Medical Image Analysis")
    gr.Markdown("Ask questions about medical images using Microsoft's LLaVA-Med 1.5-Mistral-7B")

    with gr.Row():
        # First column: the inputs and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Medical Image")
            question_input = gr.Textbox(
                placeholder="Ask about the medical image...",
                label="Question",
            )
            submit_btn = gr.Button("Analyze")

        # Second column: read-only model answer.
        with gr.Column():
            output_text = gr.Textbox(interactive=False, label="Analysis Result")

    # The same two components feed both the example picker and the button.
    query_inputs = [image_input, question_input]

    examples = gr.Examples(
        examples=example_queries,
        inputs=query_inputs,
        label="Example Queries",
    )

    submit_btn.click(
        fn=analyze_medical_image,
        inputs=query_inputs,
        outputs=output_text,
    )

# Enable the request queue (at most 10 waiting requests), then start serving.
demo.queue(max_size=10)
demo.launch()