|
|
|
""" |
|
DeepSurg Technologies Ltd. (c) 2025 |
|
Surgical VLLM - v1 |
|
""" |
|
|
|
import os |
|
import torch |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
from transformers import BertTokenizer |
|
|
|
|
|
from models.VisualBertClassification_ssgqa import VisualBertClassification |
|
|
|
|
|
from mmengine.config import Config |
|
from utils.SurgVLP import surgvlp |
|
|
|
import random |
|
|
|
|
|
import gradio as gr |
|
|
|
# Force CPU execution: "-1" matches no CUDA device, so none is visible to
# this process once CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Module-level UI state shared between main() and its Gradio callbacks.
question_dropdown = None  # dropdown component holding the question choices
selectedID = 0            # index of the gallery image currently selected
image_files = None        # paths of the images shown in the gallery
|
|
|
def seed_everything(seed=27):
    """Seed the random number generators this script relies on.

    Args:
        seed: Integer seed applied to Python's ``random`` module and to
            PyTorch. Defaults to 27.
    """
    # main() picks the demo frames with random.shuffle, so the stdlib
    # generator must be seeded as well for the selection to be repeatable.
    random.seed(seed)
    torch.manual_seed(seed)
    # Hash randomisation of the running interpreter is fixed at start-up;
    # this value is only picked up by interpreters launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
|
|
|
|
|
|
|
def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_heads=8, dropout=0.1, emb_dim=300,
                          checkpoint_path="./checkpoint.tar"):
    """Build the VisualBertClassification model and load its checkpoint.

    Args:
        tokenizer: Tokenizer whose length (``len(tokenizer)``) is used as the
            model's vocabulary size.
        device: Device the checkpoint is mapped to and the model is moved to.
        num_class: Number of output classes.
        encoder_layers: Passed to the model as ``layers``.
        n_heads: Passed to the model as ``n_heads``.
        dropout: Accepted for interface compatibility; not used here.
        emb_dim: Accepted for interface compatibility; not used here.
        checkpoint_path: File to load; the weights are read from its
            ``"model"`` entry. Defaults to ``./checkpoint.tar``.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    model = VisualBertClassification(
        vocab_size=len(tokenizer),
        layers=encoder_layers,
        n_heads=n_heads,
        num_class=num_class,
    )
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point checkpoint_path at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()
    return model
|
|
|
def load_surgvlp_encoder(device, config_path='./utils/config_surgvlp.py', pretrain_path='./SurgVLP2.pth'):
    """Load the SurgVLP encoder and its preprocessing function.

    Args:
        device: Device handed to ``surgvlp.load``.
        config_path: mmengine config file; its ``'config'`` entry must expose
            a ``model_config`` attribute. Defaults to the bundled config.
        pretrain_path: Pretrained weight file passed to ``surgvlp.load`` as
            ``pretrain``. Defaults to ``./SurgVLP2.pth``.

    Returns:
        A ``(encoder_model, encoder_preprocess)`` tuple, with the encoder
        switched to evaluation mode.
    """
    configs = Config.fromfile(config_path)['config']
    encoder_model, encoder_preprocess = surgvlp.load(
        configs.model_config,
        device=device,
        pretrain=pretrain_path,
    )
    encoder_model.eval()
    return encoder_model, encoder_preprocess
|
|
|
|
|
# Answer vocabulary of the classifier: predict_image() indexes this list with
# the top-k class indices, so the position of every entry matters. Do not
# reorder, insert or remove entries without retraining the model.
LABEL_LIST = [
    # indices 0-10
    "0", "1", "10", "2", "3", "4", "5", "6", "7", "8", "9",
    # indices 11-20
    "False", "True", "abdominal_wall_cavity", "adhesion", "anatomy",
    "aspirate", "bipolar", "blood_vessel", "blue", "brown",
    # indices 21-30
    "clip", "clipper", "coagulate", "cut", "cystic_artery",
    "cystic_duct", "cystic_pedicle", "cystic_plate", "dissect", "fluid",
    # indices 31-40
    "gallbladder", "grasp", "grasper", "gut", "hook",
    "instrument", "irrigate", "irrigator", "liver", "omentum",
    # indices 41-50
    "pack", "peritoneum", "red", "retract", "scissors",
    "silver", "specimen_bag", "specimenbag", "white", "yellow",
]
|
|
|
def _label_path_for(image_path, labels_dir):
    """Return the label file for an image, or None if its name is not numeric.

    The label file is ``<labels_dir>/<int(stem)>.txt``, so leading zeros in
    the image name are dropped (``0012.png`` -> ``12.txt``).
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    try:
        image_id = int(stem)
    except ValueError:
        print(f"Skipping labels for {image_path}: file name is not a number.")
        return None
    return os.path.join(labels_dir, f"{image_id}.txt")


def _read_questions(label_path):
    """Return the distinct questions found in one label file, in file order.

    The text before the first ``|`` of every line is taken as the question;
    blank entries are ignored. An unreadable file yields an empty list.
    """
    try:
        with open(label_path, 'r') as f:
            lines = f.readlines()
    except (OSError, UnicodeDecodeError) as err:
        print(f"Could not read labels from {label_path}: {err}")
        return []

    found = []
    for line in lines:
        text = line.split("|")[0].strip()
        if text and text not in found:
            found.append(text)
    return found


def main():
    """Load the models, collect demo frames and questions, and serve the UI.

    Up to 20 shuffled PNG frames from ``./test_data/images/VID`` are shown in
    a gallery. Selecting a frame fills the question dropdown with the
    questions from that frame's label file; "Generate" shows the top-3
    predicted labels for the selected frame and question.
    """
    global image_files
    global question_dropdown

    seed_everything()
    device = "cpu"
    tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    visualbert_model = load_visualbert_model(tokenizer, device)
    encoder_model, encoder_preprocess = load_surgvlp_encoder(device)

    print("Models loaded successfully.")

    images_dir = "./test_data/images/VID"
    labels_dir = "./test_data/labels/VID/"
    image_files = [
        os.path.join(images_dir, f)
        for f in sorted(os.listdir(images_dir))
        if f.lower().endswith('.png')
    ]
    random.shuffle(image_files)

    print(f"Found {len(image_files)} images.")

    # Keep the gallery small.
    image_files = image_files[:20]

    # questions[i] holds the questions for image_files[i]. An image without
    # usable labels gets an empty list so both lists stay index-aligned.
    questions = []
    for image_path in image_files:
        label_path = _label_path_for(image_path, labels_dir)
        questions.append(_read_questions(label_path) if label_path else [])

    def questions_for(index):
        """Return the question list for a gallery index, or [] if unknown."""
        if isinstance(index, int) and 0 <= index < len(questions):
            return questions[index]
        return []

    def predict_image(selected_images, question):
        """Run the model on the currently selected frame and the question.

        ``selected_images`` is the gallery value and is only used to check
        that the gallery is not empty; the frame itself is looked up through
        the module-level ``selectedID``. Returns a text summary with the
        frame's file name and the top-3 predictions, or a hint/error message.
        """
        if not selected_images:
            return "Please select an image from the list."
        # The dropdown delivers None while nothing has been chosen.
        if not question or not question.strip():
            return "Please select a question from the dropdown."
        if not (isinstance(selectedID, int) and 0 <= selectedID < len(image_files)):
            return "Please select an image from the list."

        image_path = image_files[selectedID]
        try:
            pil_image = Image.open(image_path).convert("RGB")
        except Exception as e:
            return f"Could not open image: {str(e)}"

        image_processed = encoder_preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            visual_features = encoder_model(image_processed, None, mode='video')['img_emb']
            # L2-normalise the embedding along its last dimension.
            visual_features /= visual_features.norm(dim=-1, keepdim=True)
            visual_features = visual_features.unsqueeze(1)

        inputs = tokenizer(
            [question],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = visualbert_model(inputs, visual_features)
            probabilities = F.softmax(outputs, dim=1)
            topk = torch.topk(probabilities, k=3, dim=1)

        topk_scores = topk.values.flatten().tolist()
        topk_indices = topk.indices.flatten().tolist()
        top_predictions = [
            (LABEL_LIST[i], float(score))
            for i, score in zip(topk_indices, topk_scores)
        ]

        image_name = os.path.basename(image_path)
        ranked = "".join(
            f"Rank {rank}: {lbl} ({score:.4f})\t\t\t"
            for rank, (lbl, score) in enumerate(top_predictions, start=1)
        )
        output_str = f"Frame: {image_name}\n\nTop 3 Predictions:\n" + ranked
        print(f"Selected image: {image_name}")
        return output_str

    def update_selected(selection: gr.SelectData):
        """Remember the clicked gallery index and refresh the dropdown."""
        # NOTE(review): selectedID is process-wide, so concurrent visitors
        # share one selection; move it to per-session state if that matters.
        global selectedID
        selectedID = selection.index
        # The returned update is applied to the component listed in the
        # event's `outputs`; clearing the value avoids keeping a question
        # that belongs to the previously selected frame.
        return gr.update(choices=questions_for(selectedID), value=None)

    with gr.Blocks() as demo:
        gr.Markdown("# DeepSurg Surgical VQA Demo (V1)")
        gr.Markdown("## Cholecystectomy Surgery VLLM")
        gr.Markdown("### Current version supports label-based answers only.")

        image_gallery = gr.Gallery(
            value=image_files,
            label="Select an Image",
            interactive=True,
            allow_preview=True,
            preview=True,
            columns=[20],
        )

        # Start with the questions of the frame predict_image would use
        # before any click happened.
        question_dropdown = gr.Dropdown(
            choices=questions_for(selectedID),
            label="Select a Question"
        )

        # The dropdown must be an output of the select event, otherwise the
        # handler's result never reaches the page.
        image_gallery.select(fn=update_selected, inputs=None, outputs=question_dropdown)

        generate_btn = gr.Button("Generate")
        predictions_output = gr.Textbox(label="Predictions", lines=10)

        generate_btn.click(
            fn=predict_image,
            inputs=[image_gallery, question_dropdown],
            outputs=predictions_output
        )

    print("Launching the Gradio UI...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
|
|