yaziciz commited on
Commit
fca2557
·
verified ·
1 Parent(s): 97a9e39

Update demo.py

Browse files
Files changed (1) hide show
  1. demo.py +18 -8
demo.py CHANGED
@@ -26,12 +26,15 @@ image_files = None
26
  selectedID = 0
27
  question_dropdown = None
28
 
 
 
 
29
  def seed_everything(seed=27):
30
  torch.manual_seed(seed)
31
- torch.cuda.manual_seed_all(seed)
32
  os.environ["PYTHONHASHSEED"] = str(seed)
33
- torch.backends.cudnn.deterministic = True
34
- torch.backends.cudnn.benchmark = False
35
 
36
  def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_heads=8, dropout=0.1, emb_dim=300):
37
  """
@@ -43,7 +46,7 @@ def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_h
43
  n_heads=n_heads,
44
  num_class=num_class,
45
  )
46
- checkpoint = torch.load("checkpoint.tar", map_location=device)
47
  model.load_state_dict(checkpoint["model"])
48
  model.to(device)
49
  model.eval()
@@ -55,7 +58,7 @@ def load_surgvlp_encoder(device):
55
  """
56
  config_path = './utils/config_surgvlp.py'
57
  configs = Config.fromfile(config_path)['config']
58
- encoder_model, encoder_preprocess = surgvlp.load(configs.model_config, device=device, pretrain='SurgVLP2.pth')
59
  encoder_model.eval()
60
  return encoder_model, encoder_preprocess
61
 
@@ -73,17 +76,22 @@ LABEL_LIST = [
73
 
74
  def main():
75
  seed_everything()
76
- device = "cuda" if torch.cuda.is_available() else "cpu"
77
  tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
78
  visualbert_model = load_visualbert_model(tokenizer, device)
79
  encoder_model, encoder_preprocess = load_surgvlp_encoder(device)
 
 
80
 
81
  # Define the directories containing images and corresponding label files.
82
  global image_files
83
- images_dir = "./test_data/images/VID01/"
84
  labels_dir = "./test_data/labels/VID01/"
85
  image_files = [os.path.join(images_dir, f) for f in sorted(os.listdir(images_dir)) if f.lower().endswith('.png')]
86
  random.shuffle(image_files)
 
 
 
87
  # Get first 20 images.
88
  image_files = image_files[:20]
89
 
@@ -196,7 +204,9 @@ def main():
196
  inputs=[image_gallery, question_dropdown],
197
  outputs=predictions_output
198
  )
 
 
199
  demo.launch()
200
 
201
  if __name__ == "__main__":
202
- main()
 
26
  selectedID = 0
27
  question_dropdown = None
28
 
29
+ #NO GPU is available
30
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
31
+
32
  def seed_everything(seed=27):
33
  torch.manual_seed(seed)
34
+ #torch.cuda.manual_seed_all(seed)
35
  os.environ["PYTHONHASHSEED"] = str(seed)
36
+ #torch.backends.cudnn.deterministic = True
37
+ #torch.backends.cudnn.benchmark = False
38
 
39
  def load_visualbert_model(tokenizer, device, num_class=51, encoder_layers=6, n_heads=8, dropout=0.1, emb_dim=300):
40
  """
 
46
  n_heads=n_heads,
47
  num_class=num_class,
48
  )
49
+ checkpoint = torch.load("./checkpoint.tar", map_location=device)
50
  model.load_state_dict(checkpoint["model"])
51
  model.to(device)
52
  model.eval()
 
58
  """
59
  config_path = './utils/config_surgvlp.py'
60
  configs = Config.fromfile(config_path)['config']
61
+ encoder_model, encoder_preprocess = surgvlp.load(configs.model_config, device=device, pretrain='./SurgVLP2.pth')
62
  encoder_model.eval()
63
  return encoder_model, encoder_preprocess
64
 
 
76
 
77
  def main():
78
  seed_everything()
79
+ device = "cpu"
80
  tokenizer = BertTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
81
  visualbert_model = load_visualbert_model(tokenizer, device)
82
  encoder_model, encoder_preprocess = load_surgvlp_encoder(device)
83
+
84
+ print("Models loaded successfully.")
85
 
86
  # Define the directories containing images and corresponding label files.
87
  global image_files
88
+ images_dir = "./test_data/images/VID01"
89
  labels_dir = "./test_data/labels/VID01/"
90
  image_files = [os.path.join(images_dir, f) for f in sorted(os.listdir(images_dir)) if f.lower().endswith('.png')]
91
  random.shuffle(image_files)
92
+
93
+ print(f"Found {len(image_files)} images.")
94
+
95
  # Get first 20 images.
96
  image_files = image_files[:20]
97
 
 
204
  inputs=[image_gallery, question_dropdown],
205
  outputs=predictions_output
206
  )
207
+
208
+ print("Launching the Gradio UI...")
209
  demo.launch()
210
 
211
  if __name__ == "__main__":
212
+ main()