helloworld-S commited on
Commit
bfaf6e1
·
verified ·
1 Parent(s): c31ec0f

Update eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py

Browse files
eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py CHANGED
@@ -10,7 +10,7 @@ import sys
10
 
11
  from eval.grounded_sam.florence2.modeling_florence2 import Florence2ForConditionalGeneration
12
  from eval.grounded_sam.florence2.processing_florence2 import Florence2Processor
13
- from eval.grounded_sam.sam2.build_sam import build_sam2
14
  from eval.grounded_sam.sam2.sam2_image_predictor import SAM2ImagePredictor
15
 
16
 
@@ -61,7 +61,6 @@ class FlorenceSAM:
61
 
62
  FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
63
  SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH', "facebook/sam2-hiera-large")
64
- SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
65
 
66
  self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
67
  FLORENCE2_MODEL_ID,
@@ -70,7 +69,7 @@ class FlorenceSAM:
70
  self.florence2_processor = Florence2Processor.from_pretrained(
71
  FLORENCE2_MODEL_ID,
72
  )
73
- sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=self.device)
74
  self.sam2_predictor = SAM2ImagePredictor(sam2_model)
75
 
76
  def __str__(self):
 
10
 
11
  from eval.grounded_sam.florence2.modeling_florence2 import Florence2ForConditionalGeneration
12
  from eval.grounded_sam.florence2.processing_florence2 import Florence2Processor
13
+ from eval.grounded_sam.sam2.build_sam import build_sam2, build_sam2_hf
14
  from eval.grounded_sam.sam2.sam2_image_predictor import SAM2ImagePredictor
15
 
16
 
 
61
 
62
  FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
63
  SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH', "facebook/sam2-hiera-large")
 
64
 
65
  self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
66
  FLORENCE2_MODEL_ID,
 
69
  self.florence2_processor = Florence2Processor.from_pretrained(
70
  FLORENCE2_MODEL_ID,
71
  )
72
+ sam2_model = build_sam2_hf(SAM2_CHECKPOINT, device=self.device)
73
  self.sam2_predictor = SAM2ImagePredictor(sam2_model)
74
 
75
  def __str__(self):