Update demo.py
Browse files
demo.py
CHANGED
@@ -26,12 +26,15 @@ image_files = None
|
|
26 |
selectedID = 0
|
27 |
question_dropdown = None
|
28 |
|
|
|
|
|
|
|
29 |
def seed_everything(seed=27):
|
30 |
torch.manual_seed(seed)
|
31 |
-
torch.cuda.manual_seed_all(seed)
|
32 |
os.environ["PYTHONHASHSEED"] = str(seed)
|
33 |
-
torch.backends.cudnn.deterministic = True
|
34 |
-
torch.backends.cudnn.benchmark = False
|
35 |
|
36 |
def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_heads=8, dropout=0.1, emb_dim=300):
|
37 |
"""
|
@@ -43,7 +46,7 @@ def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_h
|
|
43 |
n_heads=n_heads,
|
44 |
num_class=num_class,
|
45 |
)
|
46 |
-
checkpoint = torch.load("checkpoint.tar", map_location=device)
|
47 |
model.load_state_dict(checkpoint["model"])
|
48 |
model.to(device)
|
49 |
model.eval()
|
@@ -55,7 +58,7 @@ def load_surgvlp_encoder(device):
|
|
55 |
"""
|
56 |
config_path = './utils/config_surgvlp.py'
|
57 |
configs = Config.fromfile(config_path)['config']
|
58 |
-
encoder_model, encoder_preprocess = surgvlp.load(configs.model_config, device=device, pretrain='SurgVLP2.pth')
|
59 |
encoder_model.eval()
|
60 |
return encoder_model, encoder_preprocess
|
61 |
|
@@ -73,17 +76,22 @@ LABEL_LIST = [
|
|
73 |
|
74 |
def main():
|
75 |
seed_everything()
|
76 |
-
device = "
|
77 |
tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
|
78 |
visualbert_model = load_visualbert_model(tokenizer, device)
|
79 |
encoder_model, encoder_preprocess = load_surgvlp_encoder(device)
|
|
|
|
|
80 |
|
81 |
# Define the directories containing images and corresponding label files.
|
82 |
global image_files
|
83 |
-
images_dir = "./test_data/images/VID01
|
84 |
labels_dir = "./test_data/labels/VID01/"
|
85 |
image_files = [os.path.join(images_dir, f) for f in sorted(os.listdir(images_dir)) if f.lower().endswith('.png')]
|
86 |
random.shuffle(image_files)
|
|
|
|
|
|
|
87 |
# Get first 20 images.
|
88 |
image_files = image_files[:20]
|
89 |
|
@@ -196,7 +204,9 @@ def main():
|
|
196 |
inputs=[image_gallery, question_dropdown],
|
197 |
outputs=predictions_output
|
198 |
)
|
|
|
|
|
199 |
demo.launch()
|
200 |
|
201 |
if __name__ == "__main__":
|
202 |
-
main()
|
|
|
26 |
selectedID = 0
|
27 |
question_dropdown = None
|
28 |
|
29 |
+
#NO GPU is available
|
30 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
31 |
+
|
32 |
def seed_everything(seed=27):
|
33 |
torch.manual_seed(seed)
|
34 |
+
#torch.cuda.manual_seed_all(seed)
|
35 |
os.environ["PYTHONHASHSEED"] = str(seed)
|
36 |
+
#torch.backends.cudnn.deterministic = True
|
37 |
+
#torch.backends.cudnn.benchmark = False
|
38 |
|
39 |
def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_heads=8, dropout=0.1, emb_dim=300):
|
40 |
"""
|
|
|
46 |
n_heads=n_heads,
|
47 |
num_class=num_class,
|
48 |
)
|
49 |
+
checkpoint = torch.load("./checkpoint.tar", map_location=device)
|
50 |
model.load_state_dict(checkpoint["model"])
|
51 |
model.to(device)
|
52 |
model.eval()
|
|
|
58 |
"""
|
59 |
config_path = './utils/config_surgvlp.py'
|
60 |
configs = Config.fromfile(config_path)['config']
|
61 |
+
encoder_model, encoder_preprocess = surgvlp.load(configs.model_config, device=device, pretrain='./SurgVLP2.pth')
|
62 |
encoder_model.eval()
|
63 |
return encoder_model, encoder_preprocess
|
64 |
|
|
|
76 |
|
77 |
def main():
|
78 |
seed_everything()
|
79 |
+
device = "cpu"
|
80 |
tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
|
81 |
visualbert_model = load_visualbert_model(tokenizer, device)
|
82 |
encoder_model, encoder_preprocess = load_surgvlp_encoder(device)
|
83 |
+
|
84 |
+
print("Models loaded successfully.")
|
85 |
|
86 |
# Define the directories containing images and corresponding label files.
|
87 |
global image_files
|
88 |
+
images_dir = "./test_data/images/VID01"
|
89 |
labels_dir = "./test_data/labels/VID01/"
|
90 |
image_files = [os.path.join(images_dir, f) for f in sorted(os.listdir(images_dir)) if f.lower().endswith('.png')]
|
91 |
random.shuffle(image_files)
|
92 |
+
|
93 |
+
print(f"Found {len(image_files)} images.")
|
94 |
+
|
95 |
# Get first 20 images.
|
96 |
image_files = image_files[:20]
|
97 |
|
|
|
204 |
inputs=[image_gallery, question_dropdown],
|
205 |
outputs=predictions_output
|
206 |
)
|
207 |
+
|
208 |
+
print("Launching the Gradio UI...")
|
209 |
demo.launch()
|
210 |
|
211 |
if __name__ == "__main__":
|
212 |
+
main()
|