Spaces:
Running
on
Zero
Running
on
Zero
Update eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py
Browse files
eval/grounded_sam/grounded_sam2_florence2_autolabel_pipeline.py
CHANGED
@@ -60,7 +60,7 @@ class FlorenceSAM:
|
|
60 |
self.torch_dtype = torch.bfloat16
|
61 |
|
62 |
FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
|
63 |
-
SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH')
|
64 |
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
65 |
|
66 |
self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
|
@@ -127,7 +127,7 @@ class FlorenceSAM:
|
|
127 |
|
128 |
def segmentation(self, image, input_boxes, seg_model="sam"):
|
129 |
if seg_model == "sam":
|
130 |
-
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.
|
131 |
sam2_predictor = self.sam2_predictor
|
132 |
sam2_predictor.set_image(np.array(image))
|
133 |
masks, scores, logits = sam2_predictor.predict(
|
|
|
60 |
self.torch_dtype = torch.bfloat16
|
61 |
|
62 |
FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH', "microsoft/Florence-2-large")
|
63 |
+
SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH', "facebook/sam2-hiera-large")
|
64 |
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
65 |
|
66 |
self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
|
|
|
127 |
|
128 |
def segmentation(self, image, input_boxes, seg_model="sam"):
|
129 |
if seg_model == "sam":
|
130 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
131 |
sam2_predictor = self.sam2_predictor
|
132 |
sam2_predictor.set_image(np.array(image))
|
133 |
masks, scores, logits = sam2_predictor.predict(
|