helloworld-S commited on
Commit
c31ec0f
·
verified ·
1 Parent(s): 2f2e9f1

Update eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py

Browse files
eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py CHANGED
@@ -60,7 +60,7 @@ class FlorenceSAM:
60
  self.torch_dtype = torch.bfloat16
61
 
62
  FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
63
- SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH')
64
  SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
65
 
66
  self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
@@ -127,7 +127,7 @@ class FlorenceSAM:
127
 
128
  def segmentation(self, image, input_boxes, seg_model="sam"):
129
  if seg_model == "sam":
130
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float32):
131
  sam2_predictor = self.sam2_predictor
132
  sam2_predictor.set_image(np.array(image))
133
  masks, scores, logits = sam2_predictor.predict(
 
60
  self.torch_dtype = torch.bfloat16
61
 
62
  FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
63
+ SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH', "facebook/sam2-hiera-large")
64
  SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
65
 
66
  self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
 
127
 
128
  def segmentation(self, image, input_boxes, seg_model="sam"):
129
  if seg_model == "sam":
130
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
131
  sam2_predictor = self.sam2_predictor
132
  sam2_predictor.set_image(np.array(image))
133
  masks, scores, logits = sam2_predictor.predict(