#!/usr/bin/env python
"""
DeepSurg Technologies Ltd. (c) 2025
Surgical VLLM - v1
"""
import os
import random

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import BertTokenizer

# Import the VisualBertClassification model (ensure the module is on your PYTHONPATH).
from models.VisualBertClassification_ssgqa import VisualBertClassification
# For the SurgVLP encoder.
from mmengine.config import Config
from utils.SurgVLP import surgvlp
# For the Gradio UI.
import gradio as gr
# Module-level state shared between the Gradio callbacks.
image_files = None
selectedID = 0
# No GPU available: force CPU execution by hiding CUDA devices from PyTorch.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

def seed_everything(seed=27):
    """Seed the Python and PyTorch RNGs for reproducibility."""
    random.seed(seed)  # also seed Python's RNG, used by random.shuffle below
    torch.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed)  # not needed: this demo runs on CPU only
    os.environ["PYTHONHASHSEED"] = str(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_heads=8, dropout=0.1, emb_dim=300):
    """
    Initialize the VisualBertClassification model and load the checkpoint.
    Note: dropout and emb_dim are accepted for API compatibility but are not
    passed to the constructor here.
    """
    model = VisualBertClassification(
        vocab_size=len(tokenizer),
        layers=encoder_layers,
        n_heads=n_heads,
        num_class=num_class,
    )
    checkpoint = torch.load("./checkpoint.tar", map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()
    return model
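
# Sketch of the assumed checkpoint layout (hypothetical save side), matching
# the torch.load call above: a plain dict with the weights under "model".
#   torch.save({"model": model.state_dict()}, "./checkpoint.tar")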

def load_surgvlp_encoder(device):
    """
    Load the SurgVLP encoder and its preprocessing function.
    """
    config_path = './utils/config_surgvlp.py'
    configs = Config.fromfile(config_path)['config']
    encoder_model, encoder_preprocess = surgvlp.load(configs.model_config, device=device, pretrain='./SurgVLP2.pth')
    encoder_model.eval()
    return encoder_model, encoder_preprocess
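
# encoder_preprocess is presumably a torchvision-style transform that maps a
# PIL image to a normalized tensor; predict_image below calls
# encoder_preprocess(pil_image).unsqueeze(0) to build a single-image batch.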

# Label conversion list (mapping model output indices to text labels).
LABEL_LIST = [
    "0", "1", "10", "2", "3", "4", "5", "6", "7", "8", "9",
    "False", "True", "abdominal_wall_cavity", "adhesion", "anatomy",
    "aspirate", "bipolar", "blood_vessel", "blue", "brown", "clip",
    "clipper", "coagulate", "cut", "cystic_artery", "cystic_duct",
    "cystic_pedicle", "cystic_plate", "dissect", "fluid", "gallbladder",
    "grasp", "grasper", "gut", "hook", "instrument", "irrigate", "irrigator",
    "liver", "omentum", "pack", "peritoneum", "red", "retract", "scissors",
    "silver", "specimen_bag", "specimenbag", "white", "yellow"
]
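
# The 51 entries above match num_class=51 and are ordered lexicographically
# (hence "10" before "2"), presumably mirroring how class indices were
# assigned when the classifier was trained: index i must correspond to logit i.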

def main():
    seed_everything()
    device = "cpu"
    tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    visualbert_model = load_visualbert_model(tokenizer, device)
    encoder_model, encoder_preprocess = load_surgvlp_encoder(device)
    print("Models loaded successfully.")

    # Define the directories containing images and corresponding label files.
    global image_files
    images_dir = "./test_data/images/VID"
    labels_dir = "./test_data/labels/VID/"
    image_files = [os.path.join(images_dir, f) for f in sorted(os.listdir(images_dir)) if f.lower().endswith('.png')]
    random.shuffle(image_files)
    print(f"Found {len(image_files)} images.")

    # Keep only the first 20 images.
    image_files = image_files[:20]

    # Build the predefined questions list by reading each image's label file.
    questions = []
    for image_path in image_files:
        image_id = int(os.path.basename(image_path).replace('.png', ''))
        label_path = os.path.join(labels_dir, f"{image_id}.txt")
        try:
            with open(label_path, 'r') as f:
                for line in f:
                    # Each line is expected to look like "question|answer";
                    # keep only the question part.
                    questions.append(line.split("|")[0])
        except OSError:
            # If a label file is missing, skip that image's questions.
            continue

    # Remove duplicates and sort so each question appears once in the dropdown.
    questions = sorted(set(questions))
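
    # Inference pipeline: the selected frame is embedded with the SurgVLP image
    # encoder, and that embedding plus the tokenized question is passed to
    # VisualBERT, which scores the 51 candidate answers in LABEL_LIST.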
    def predict_image(selected_images, question):
        """
        Process the selected image (by file path) along with the surgical question.
        Return a text summary that includes the image file name and the top-3 predictions.
        """
        if not selected_images:
            return "Please select an image from the list."
        if question.strip() == "":
            return "Please select a question from the dropdown."
        # Use the global selectedID to pick the image.
        image_path = image_files[selectedID]
        try:
            pil_image = Image.open(image_path).convert("RGB")
        except Exception as e:
            return f"Could not open image: {str(e)}"
        image_processed = encoder_preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            visual_features = encoder_model(image_processed, None, mode='video')['img_emb']
            visual_features /= visual_features.norm(dim=-1, keepdim=True)
            visual_features = visual_features.unsqueeze(1)
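        # img_emb is L2-normalized and unsqueezed to shape
        # (batch=1, regions=1, dim): the whole frame is effectively presented
        # to VisualBERT as a single visual token.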
        inputs = tokenizer(
            [question],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        )
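        # max_length=77 matches the CLIP-style context length; presumably the
        # checkpoint was trained with questions padded to this same length.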
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = visualbert_model(inputs, visual_features)
        probabilities = F.softmax(outputs, dim=1)
        topk = torch.topk(probabilities, k=3, dim=1)
        topk_scores = topk.values.cpu().numpy().flatten()
        topk_indices = topk.indices.cpu().numpy().flatten()
        top_predictions = [(LABEL_LIST[i], float(score)) for i, score in zip(topk_indices, topk_scores)]
        image_name = os.path.basename(image_path)
        output_str = f"Frame: {image_name}\n\nTop 3 Predictions:\n"
        for rank, (lbl, score) in enumerate(top_predictions, start=1):
            output_str += f"Rank {rank}: {lbl} ({score:.4f})\n"
        print(f"Selected image: {image_name}")
        return output_str

    # Callback to record which gallery image the user selected; rebuilding the
    # question dropdown here would require returning an updated component
    # through the select() event's outputs.
    def update_selected(selection: gr.SelectData):
        global selectedID
        selectedID = selection.index

    with gr.Blocks() as demo:
        gr.Markdown("# DeepSurg Surgical VQA Demo (V1)")
        gr.Markdown("## Cholecystectomy Surgery VLLM")
        gr.Markdown("### Current version supports label-based answers only.")
        # TODO: add a logo here.

        # Gallery from which the user chooses one image; the select event
        # reports the clicked index via gr.SelectData.
        image_gallery = gr.Gallery(
            value=image_files,
            label="Select an Image",
            interactive=True,
            allow_preview=True,
            preview=True,
            columns=[20],
        )
        image_gallery.select(fn=update_selected, inputs=None)

        # Dropdown for selecting a predefined question.
        question_dropdown = gr.Dropdown(
            choices=questions,
            label="Select a Question"
        )
        generate_btn = gr.Button("Generate")
        predictions_output = gr.Textbox(label="Predictions", lines=10)
        generate_btn.click(
            fn=predict_image,
            inputs=[image_gallery, question_dropdown],
            outputs=predictions_output
        )

    print("Launching the Gradio UI...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
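
# Note: server_name="0.0.0.0" and port 7860 follow the Hugging Face Spaces
# convention, so the demo is reachable when run inside a Space container.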

if __name__ == "__main__":
    main()