|
import streamlit as st |
|
from PIL import Image |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoProcessor |
|
import numpy as np |
|
import supervision as sv |
|
import albumentations as A |
|
import cv2 |
|
from transformers import AutoConfig |
|
import yaml |
|
|
|
|
|
# Page-wide layout; must be configured before any other Streamlit element is drawn.
st.set_page_config(layout="wide")

# Custom stylesheet for the demo page. It is injected as raw HTML, which is why
# the markdown call below has to opt in with unsafe_allow_html.
_PAGE_CSS = """
<style>
.main {
    max-width: 1200px; /* Max width for content */
    margin: 0 auto;
}
.block-container {
    padding-top: 2rem;
    padding-bottom: 2rem;
    padding-left: 3rem;
    padding-right: 3rem;
}
.title {
    font-size: 2.5rem;
    text-align: center;
    color: #FF6347;
}
.subheader {
    font-size: 1.5rem;
    margin-bottom: 20px;
}
.btn {
    font-size: 1.1rem;
    padding: 10px 20px;
    background-color: #FF6347;
    color: white;
    border-radius: 5px;
    border: none;
    cursor: pointer;
}
.btn:hover {
    background-color: #FF4500;
}
.column-spacing {
    display: flex;
    justify-content: space-between;
}
.col-half {
    width: 48%;
}
.col-full {
    width: 100%;
}
.instructions {
    padding: 20px;
    background-color: #f9f9f9;
    border-radius: 8px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
}
</style>
"""

st.markdown(_PAGE_CSS, unsafe_allow_html=True)
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the AG-KD grounding model together with its Florence-2 processor.

    The function is decorated with ``st.cache_resource`` so Streamlit can reuse
    the returned objects instead of reloading the weights.

    Returns:
        tuple: ``(model, processor, device)``. ``model`` has already been moved
        to ``device``, which is CUDA when available and CPU otherwise.
    """
    # NOTE(review): a 'refs/pr/6' revision constant used to be declared here but
    # was never passed to any from_pretrained call, so no hub revision is pinned.
    # Confirm whether pinning a revision is wanted before adding it.
    model_name = "RioJune/AG-KD"
    base_checkpoint = "microsoft/Florence-2-base-ft"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The config is taken from the base Florence-2 checkpoint and its vision
    # model_type is overridden to "davit" before the fine-tuned weights load.
    config_model = AutoConfig.from_pretrained(base_checkpoint, trust_remote_code=True)
    config_model.vision_config.model_type = "davit"

    model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, config=config_model
    ).to(device)

    # The processor also comes from the base checkpoint; its image size and
    # crop size are both forced to 512.
    processor = AutoProcessor.from_pretrained(base_checkpoint, trust_remote_code=True)
    processor.image_processor.size = 512
    processor.image_processor.crop_size = 512

    return model, processor, device
|
|
|
# Module-level handles used by the inference handlers further down the script.
model, processor, DEVICE = load_model()
|
|
|
|
|
@st.cache_resource
def load_definitions():
    """Read the disease-definition and example-prompt YAML files.

    Returns:
        tuple: ``(vindr_definitions, padchest_definitions, prompt_definitions)``
        as parsed by ``yaml.safe_load``. A file that parses to an empty value
        yields ``{}`` so callers can safely use ``.keys()`` / ``.get()``.

    Errors raised by ``open`` or ``yaml.safe_load`` are not caught here and
    propagate to the caller.
    """
    definition_paths = (
        'configs/vindr_definition.yaml',
        'configs/padchest_definition.yaml',
        'examples/prompt.yaml',
    )

    loaded = []
    for path in definition_paths:
        # Explicit UTF-8 so parsing does not depend on the platform's default
        # text encoding.
        with open(path, 'r', encoding='utf-8') as file:
            loaded.append(yaml.safe_load(file) or {})

    vindr_definitions, padchest_definitions, prompt_definitions = loaded
    return vindr_definitions, padchest_definitions, prompt_definitions
|
|
|
# Parsed YAML content shared by the UI code below.
vindr_definitions, padchest_definitions, prompt_definitions = load_definitions()

# Maps the label shown in the "Select Dataset" dropdown to that dataset's definitions.
dataset_options = {"Vindr": vindr_definitions, "PadChest": padchest_definitions}
|
|
|
def load_example_images():
    """Return the keys of ``prompt_definitions`` as a list.

    The keys are later handed to ``Image.open``, so they are expected to be
    paths of the example images.
    """
    return [*prompt_definitions.keys()]
|
|
|
# Options offered in the "Choose an example" dropdown.
example_images = load_example_images()
|
|
|
def apply_transform(image, size_mode=512):
    """Letterbox ``image`` onto a black square canvas and return it as an array.

    Steps, in order: scale the longest side to ``size_mode``, pad with black
    to at least ``size_mode`` x ``size_mode``, then resize to 512 x 512.

    Args:
        image: Anything ``np.array`` accepts (the caller passes a PIL image).
        size_mode: Target side length for the scale and pad steps.

    Returns:
        The ``"image"`` entry of the Albumentations result.
    """
    pixels = np.array(image)

    steps = [
        A.LongestMaxSize(max_size=size_mode, interpolation=cv2.INTER_AREA),
        A.PadIfNeeded(
            min_height=size_mode,
            min_width=size_mode,
            border_mode=cv2.BORDER_CONSTANT,
            value=(0, 0, 0),
        ),
        # NOTE(review): the final size is fixed at 512 and ignores size_mode;
        # presumably this matches the processor's 512 input size. Confirm.
        A.Resize(height=512, width=512, interpolation=cv2.INTER_AREA),
    ]
    pipeline = A.Compose(steps)

    return pipeline(image=pixels)["image"]
|
|
|
|
|
# ---- Page header -----------------------------------------------------------
st.markdown("<h1 class='title'>π©Ί Enhancing Abnormality Grounding for Vision Language Models with Knowledge Descriptions π</h1>", unsafe_allow_html=True)
st.markdown(
    "<p style='text-align: center; font-size: 18px;'>Welcome to a simple demo of our work! π Choose an example or upload your own image to get started! π</p>",
    unsafe_allow_html=True,
)

# ---- Example picker --------------------------------------------------------
st.subheader("π Example Images")
selected_example = st.selectbox("Choose an example", example_images)

if selected_example is None:
    # No example is available. Keep both names defined so the later
    # `image is None` checks and disease lookups do not raise.
    image = None
    example_diseases = []
else:
    # Image.open is lazy and keeps the file open; convert() returns an
    # independent in-memory copy, so the handle can be closed right away.
    with Image.open(selected_example) as example_file:
        image = example_file.convert("RGB")
    # `or []` also covers a YAML entry that exists but has no value.
    example_diseases = prompt_definitions.get(selected_example) or []

st.write("**Associated Diseases:**", ", ".join(map(str, example_diseases)))

# ---- Preview (left) and usage notes (right) --------------------------------
col1, col2 = st.columns([1, 2])

with col1:
    if image is not None:
        st.image(image, caption=f"Original Example Image: {selected_example}", width=400)

with col2:
    st.subheader("βοΈ Instructions to Get Started:")
    # The markdown text is kept flush-left so its content does not depend on
    # how the surrounding code is indented.
    st.write("""
- **Run Inference**: Click the "Run Inference on Example" button to process the image and display the results.
- **Choose an Example**: π Select an example image from the dataset to view its associated diseases.
- **Upload Your Own Image**: π€ Upload an image of your choice to analyze it for diseases.
- **Select Dataset**: π Choose between available datasets (Vindr or PadChest) for disease information.
- **Select Disease**: π¦ Pick the disease to be analyzed from the list of diseases in the selected dataset.
""")

    st.subheader("β οΈ Warning:")
    st.write("""
- **π« Please avoid uploading non-frontal chest X-ray images.** Our model has been specifically trained on **frontal chest X-ray images** only.
- This demo is intended for **π¬ research purposes only** and should **β not be used for medical diagnoses**.
- The modelβs responses may contain **<span style='color:#dc3545; font-weight:bold;'>π€ hallucinations or incorrect information</span>**.
- Always consult a **<span style='color:#dc3545; font-weight:bold;'>π¨ββοΈ medical professional</span>** for accurate diagnosis and advice.
""", unsafe_allow_html=True)
|
|
|
|
|
# ---- Inference on the selected example image -------------------------------
if st.button("Run Inference on Example", key="example"):
    if image is None:
        st.error("β Please select an example image first.")
    else:
        # Ground the first disease listed for this example; its definition is
        # looked up in the Vindr definitions first, then in PadChest.
        disease_choice = example_diseases[0] if example_diseases else ""
        definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))

        det_obj = f"{disease_choice} means {definition}."
        st.write(f"**Definition:** {definition}")
        prompt = f"Locate the phrases in the caption: {det_obj}."
        prompt = f"<CAPTION_TO_PHRASE_GROUNDING>{prompt}"

        np_image = np.array(image)
        # An image array's shape starts with (height, width), but both the
        # Florence-2 post-processing and supervision take (width, height).
        # Passing shape[:2] directly mis-scales boxes on non-square images.
        image_height, image_width = np_image.shape[:2]
        resolution_wh = (image_width, image_height)

        inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)

        with st.spinner("Processing... β³"):
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                output_scores=True,
                return_dict_in_generate=True,
            )
            transition_scores = model.compute_transition_scores(
                outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
            )

        generated_text = processor.batch_decode(outputs.sequences, skip_special_tokens=False)[0]

        # Sequence score: sum of the per-step scores, normalised by the number
        # of negative-score steps raised to the configured length penalty,
        # then exponentiated for display.
        output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
        length_penalty = model.generation_config.length_penalty
        reconstructed_scores = transition_scores.cpu().sum(axis=1) / (output_length**length_penalty)
        probabilities = np.exp(reconstructed_scores.cpu().numpy())

        st.markdown(f"**π― Probability of the Results:** <span style='color:#28a745; font-size:24px; font-weight:bold;'>{probabilities[0] * 100:.2f}%</span>", unsafe_allow_html=True)

        predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=resolution_wh)
        detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=resolution_wh)

        # Draw boxes and labels on a copy so the original pixels stay untouched.
        bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
        label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
        image_with_predictions = bounding_box_annotator.annotate(np_image.copy(), detection)
        image_with_predictions = label_annotator.annotate(image_with_predictions, detection)
        annotated_image = Image.fromarray(image_with_predictions.astype(np.uint8))

        # Original and annotated image side by side.
        col1, col2 = st.columns([1, 1])
        with col1:
            st.image(image, caption=f"Original Image: {selected_example}", width=400)
        with col2:
            st.image(annotated_image, caption="Inference Results πΌοΈ", width=400)

        st.write("**Generated Text:**", generated_text)
|
|
|
|
|
# ---- Upload section: dataset / disease selection and file upload -----------
st.subheader("π€ Upload Your Own Image")

col1, col2 = st.columns([1, 1])
with col1:
    dataset_choice = st.selectbox("Select Dataset π", options=list(dataset_options.keys()))
    disease_options = list(dataset_options[dataset_choice].keys())
with col2:
    disease_choice = st.selectbox("Select Disease π¦ ", options=disease_options)

uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])

col1, col2 = st.columns([1, 2])

with col1:
    if uploaded_file:
        # An upload replaces the example image; when nothing is uploaded,
        # `image` keeps whatever the example picker assigned.
        # convert() returns an independent copy, so the PIL handle is closed.
        with Image.open(uploaded_file) as uploaded_image:
            image = uploaded_image.convert("RGB")
        image = apply_transform(image)
        st.image(image, caption="Uploaded Image", width=400)

    # Fall back to the example's first disease when the dropdown gave none;
    # use an empty name instead of raising if the example lists no diseases.
    if not disease_choice:
        disease_choice = example_diseases[0] if example_diseases else ""

    # Prefer the definition from the dataset the user actually selected, so a
    # disease present in both datasets is not always described by Vindr.
    # Otherwise fall back to the Vindr-then-PadChest lookup.
    definition = dataset_options[dataset_choice].get(disease_choice) or vindr_definitions.get(
        disease_choice, padchest_definitions.get(disease_choice, "")
    )
    if not definition:
        definition = st.text_input("Enter Definition Manually π", value="")

with col2:
    st.subheader("βοΈ Instructions to Get Started:")
    # The markdown text is kept flush-left so its content does not depend on
    # how the surrounding code is indented.
    st.write("""
- **Run Inference**: Click the "Run Inference on Example" button to process the image and display the results.
- **Choose an Example**: π Select an example image from the dataset to view its associated diseases.
- **Upload Your Own Image**: π€ Upload an image of your choice to analyze it for diseases.
- **Select Dataset**: π Choose between available datasets (Vindr or PadChest) for disease information.
- **Select Disease**: π¦ Pick the disease to be analyzed from the list of diseases in the selected dataset.
""")

    st.subheader("β οΈ Warning:")
    st.write("""
- **π« Please avoid uploading non-frontal chest X-ray images.** Our model has been specifically trained on **frontal chest X-ray images** only.
- This demo is intended for **π¬ research purposes only** and should **β not be used for medical diagnoses**.
- The modelβs responses may contain **<span style='color:#dc3545; font-weight:bold;'>π€ hallucinations or incorrect information</span>**.
- Always consult a **<span style='color:#dc3545; font-weight:bold;'>π¨ββοΈ medical professional</span>** for accurate diagnosis and advice.
""", unsafe_allow_html=True)
|
|
|
|
|
# ---- Inference on the uploaded (or currently selected) image ---------------
if st.button("Run Inference πββοΈ"):
    if image is None:
        st.error("β Please upload an image or select an example.")
    else:
        det_obj = f"{disease_choice} means {definition}."
        st.write(f"**Definition:** {definition}")

        prompt = f"Locate the phrases in the caption: {det_obj}."
        prompt = f"<CAPTION_TO_PHRASE_GROUNDING>{prompt}"

        # `image` is either the array returned by apply_transform or the PIL
        # example image; np.array() accepts both.
        np_image = np.array(image)
        # An image array's shape starts with (height, width), but both the
        # Florence-2 post-processing and supervision take (width, height).
        # Passing shape[:2] directly mis-scales boxes on non-square images.
        image_height, image_width = np_image.shape[:2]
        resolution_wh = (image_width, image_height)

        inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)

        with st.spinner("Processing... β³"):
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                output_scores=True,
                return_dict_in_generate=True,
            )
            transition_scores = model.compute_transition_scores(
                outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
            )

        generated_text = processor.batch_decode(outputs.sequences, skip_special_tokens=False)[0]

        # Sequence score: sum of the per-step scores, normalised by the number
        # of negative-score steps raised to the configured length penalty,
        # then exponentiated for display.
        output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
        length_penalty = model.generation_config.length_penalty
        reconstructed_scores = transition_scores.cpu().sum(axis=1) / (output_length**length_penalty)
        probabilities = np.exp(reconstructed_scores.cpu().numpy())

        st.markdown(f"**π― Probability of the Results:** <span style='color:green; font-size:24px; font-weight:bold;'>{probabilities[0] * 100:.2f}%</span>", unsafe_allow_html=True)

        predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=resolution_wh)
        detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=resolution_wh)

        # Draw boxes and labels on a copy so the original pixels stay untouched.
        bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
        label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
        image_with_predictions = bounding_box_annotator.annotate(np_image.copy(), detection)
        image_with_predictions = label_annotator.annotate(image_with_predictions, detection)
        annotated_image = Image.fromarray(image_with_predictions.astype(np.uint8))

        # Input and annotated image side by side.
        col1, col2 = st.columns([1, 1])
        with col1:
            st.image(image, caption="Uploaded Image", width=400)
        with col2:
            st.image(annotated_image, caption="Inference Results πΌοΈ", width=400)

        st.write("**Generated Text:**", generated_text)
|
|
|
|
|
|
|
|