yaziciz committed (verified)
Commit a14e3ff · 1 Parent(s): 375707b

Upload demo.py

Files changed (1)
  1. demo.py +202 -0
demo.py ADDED
@@ -0,0 +1,202 @@
+ #!/usr/bin/env python
+ """
+ DeepSurg Technologies Ltd. (c) 2025
+ Surgical VLLM - v1
+ """
+
+ import os
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from transformers import BertTokenizer
+
+ # Import the VisualBertClassification model (ensure the module is in your PYTHONPATH)
+ from models.VisualBertClassification_ssgqa import VisualBertClassification
+
+ # For SurgVLP encoder
+ from mmengine.config import Config
+ from utils.SurgVLP import surgvlp
+
+ import random
+
+ # For Gradio UI
+ import gradio as gr
+
+ image_files = None
+ selectedID = 0
+ question_dropdown = None
+
+ def seed_everything(seed=27):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_heads=8, dropout=0.1, emb_dim=300):
+     """
+     Initialize the VisualBertClassification model and load the checkpoint.
+     """
+     model = VisualBertClassification(
+         vocab_size=len(tokenizer),
+         layers=encoder_layers,
+         n_heads=n_heads,
+         num_class=num_class,
+     )
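+     # "checkpoint.tar" is assumed to sit in the working directory and to store the
+     # fine-tuned weights under the "model" key.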
+     checkpoint = torch.load("checkpoint.tar", map_location=device)
+     model.load_state_dict(checkpoint["model"])
+     model.to(device)
+     model.eval()
+     return model
+
+ def load_surgvlp_encoder(device):
+     """
+     Load the SurgVLP encoder and its preprocessing function.
+     """
+     config_path = './utils/config_surgvlp.py'
+     configs = Config.fromfile(config_path)['config']
+     encoder_model, encoder_preprocess = surgvlp.load(configs.model_config, device=device, pretrain='SurgVLP.pth')
+     encoder_model.eval()
+     return encoder_model, encoder_preprocess
+
+ # Label conversion list (mapping model output indices to text labels)
+ LABEL_LIST = [
+     "0", "1", "10", "2", "3", "4", "5", "6", "7", "8", "9",
+     "False", "True", "abdominal_wall_cavity", "adhesion", "anatomy",
+     "aspirate", "bipolar", "blood_vessel", "blue", "brown", "clip",
+     "clipper", "coagulate", "cut", "cystic_artery", "cystic_duct",
+     "cystic_pedicle", "cystic_plate", "dissect", "fluid", "gallbladder",
+     "grasp", "grasper", "gut", "hook", "instrument", "irrigate", "irrigator",
+     "liver", "omentum", "pack", "peritoneum", "red", "retract", "scissors",
+     "silver", "specimen_bag", "specimenbag", "white", "yellow"
+ ]
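+ # NOTE: the labels are listed in lexicographic order ("0", "1", "10", "2", ...); each
+ # index is assumed to match the class index used when the checkpoint was trained.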
+
+ def main():
+     seed_everything()
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+     visualbert_model = load_visualbert_model(tokenizer, device)
+     encoder_model, encoder_preprocess = load_surgvlp_encoder(device)
+
+     # Define the directories containing images and corresponding label files.
+     global image_files
+     images_dir = "./test_data/images/VID22/"
+     labels_dir = "./test_data/labels/VID22/"
+     image_files = [os.path.join(images_dir, f) for f in sorted(os.listdir(images_dir)) if f.lower().endswith('.png')]
+     random.shuffle(image_files)
+     # Keep the first 20 images of the shuffled list.
+     image_files = image_files[:20]
+
+     # Build the predefined questions per image by reading its label file;
+     # questions[i] holds the questions available for image_files[i].
+     questions = []
+     for image_path in image_files:
+         image_id = int(os.path.basename(image_path).replace('.png', ''))
+         label_path = os.path.join(labels_dir, f"{image_id}.txt")
+         image_questions = []
+         try:
+             with open(label_path, 'r') as f:
+                 for line in f:
+                     # Each label line is "question|answer"; keep the question part.
+                     image_questions.append(line.split("|")[0])
+         except Exception:
+             # If a label file is missing, leave this image without questions.
+             pass
+         # Remove duplicates and sort so the dropdown choices are stable.
+         questions.append(sorted(set(image_questions)))
+
+     def predict_image(selected_images, question):
+         """
+         Processes the selected image (by file path) along with the surgical question.
+         Returns a text summary that includes the image file name and top-3 predictions.
+         """
+         if not selected_images:
+             return "Please select an image from the list."
+         if not question or question.strip() == "":
+             return "Please select a question from the dropdown."
+
+         # Use the global selectedID to pick the image.
+         image_path = image_files[selectedID]
+         try:
+             pil_image = Image.open(image_path).convert("RGB")
+         except Exception as e:
+             return f"Could not open image: {str(e)}"
+
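+         # Encode the frame with SurgVLP: the image embedding is L2-normalised and
+         # given a sequence dimension so it acts as a single visual token for VisualBERT.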
+         image_processed = encoder_preprocess(pil_image).unsqueeze(0).to(device)
+         with torch.no_grad():
+             visual_features = encoder_model(image_processed, None, mode='video')['img_emb']
+             visual_features /= visual_features.norm(dim=-1, keepdim=True)
+             visual_features = visual_features.unsqueeze(1)
+
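+         # Tokenise the question with the clinical BERT tokenizer; every question is
+         # padded/truncated to a fixed length of 77 tokens.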
+         inputs = tokenizer(
+             [question],
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=77,
+         )
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = visualbert_model(inputs, visual_features)
+             probabilities = F.softmax(outputs, dim=1)
+             topk = torch.topk(probabilities, k=3, dim=1)
+
+         topk_scores = topk.values.cpu().numpy().flatten()
+         topk_indices = topk.indices.cpu().numpy().flatten()
+         top_predictions = [(LABEL_LIST[i], float(score)) for i, score in zip(topk_indices, topk_scores)]
+
+         image_name = os.path.basename(image_path)
+         print(f"Selected image: {image_name}")
+         output_str = f"Image: {image_name}\nTop 3 Predictions:\n"
+         for rank, (lbl, score) in enumerate(top_predictions, start=1):
+             output_str += f"Rank {rank}: {lbl} ({score:.4f})\n"
+         return output_str
+
+     # Callback: when the user selects an image in the gallery, remember its index and
+     # refresh the question dropdown with the questions available for that image.
+     def update_selected(selection: gr.SelectData):
+         global selectedID
+         selectedID = selection.index
+         return gr.update(choices=questions[selectedID], value=None)
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# DeepSurg Surgical VQA Demo (V1)")
+         gr.Markdown("## Cholecystectomy Surgery VLLM")
+         gr.Markdown("### Current version supports label-based answers only.")
+
+         # TODO: add a logo here
+         # Gallery of test frames; selecting one fires update_selected via SelectData.
+         image_gallery = gr.Gallery(
+             value=image_files,
+             label="Select an Image",
+             interactive=True,
+             allow_preview=True,
+             preview=True,
+             columns=[20],
+         )
+
+         # Dropdown with the predefined questions for the currently selected image.
+         question_dropdown = gr.Dropdown(
+             choices=questions[selectedID],
+             label="Select a Question"
+         )
+         # Selecting an image updates both the stored index and the dropdown choices.
+         image_gallery.select(fn=update_selected, inputs=None, outputs=question_dropdown)
+
+         generate_btn = gr.Button("Generate")
+         predictions_output = gr.Textbox(label="Predictions", lines=10)
+
+         generate_btn.click(
+             fn=predict_image,
+             inputs=[image_gallery, question_dropdown],
+             outputs=predictions_output
+         )
+     demo.launch()
+
+ if __name__ == "__main__":
+     main()
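
For reference, a sketch of the working-directory layout this demo assumes, inferred from the hard-coded paths in the script (names in your checkout may differ):

    demo.py
    checkpoint.tar                          # VisualBERT VQA checkpoint ("model" state dict)
    SurgVLP.pth                             # SurgVLP pretrained weights
    models/VisualBertClassification_ssgqa.py
    utils/config_surgvlp.py
    utils/SurgVLP/                          # surgvlp loader package
    test_data/images/VID22/<frame_id>.png
    test_data/labels/VID22/<frame_id>.txt   # one "question|answer" pair per line

With these assets in place, the demo launches with: python demo.py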