Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
7877974
1
Parent(s):
bd1381f
Debug mask images
Browse files
app.py
CHANGED
@@ -225,100 +225,99 @@ def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> N
|
|
225 |
|
226 |
def run_grounded_sam(input_image):
|
227 |
"""Main function to run GroundingDINO and SAM-HQ"""
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
scribble = np.array(input_image["mask"])
|
242 |
-
image_pil = input_image["image"].convert("RGB")
|
243 |
-
else:
|
244 |
-
# Direct image input
|
245 |
-
image_pil = input_image.convert("RGB") if input_image else None
|
246 |
-
scribble = None
|
247 |
-
|
248 |
-
if image_pil is None:
|
249 |
-
logger.error("No input image provided")
|
250 |
-
return [Image.new('RGB', (400, 300), color='gray')]
|
251 |
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
# Apply non-maximum suppression if we have multiple boxes
|
273 |
-
if boxes_filt.size(0) > 1:
|
274 |
-
logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
|
275 |
-
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
|
276 |
-
boxes_filt = boxes_filt[nms_idx]
|
277 |
-
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
|
278 |
-
logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
|
279 |
-
|
280 |
-
# Load SAM model
|
281 |
-
ModelManager.load_model('sam')
|
282 |
-
sam_predictor = ModelManager.get_model('sam_predictor')
|
283 |
-
|
284 |
-
# Set image for SAM
|
285 |
-
image = np.array(image_pil)
|
286 |
-
sam_predictor.set_image(image)
|
287 |
-
|
288 |
-
# Run SAM
|
289 |
-
# Use boxes for these task types
|
290 |
-
if boxes_filt.size(0) == 0:
|
291 |
-
logger.warning("No boxes detected")
|
292 |
-
return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
|
293 |
-
|
294 |
-
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
|
295 |
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
-
|
316 |
-
|
317 |
|
318 |
-
|
319 |
|
320 |
-
except Exception as e:
|
321 |
-
|
322 |
# # Return original image on error
|
323 |
# if isinstance(input_image, dict) and "image" in input_image:
|
324 |
# return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))]
|
|
|
225 |
|
226 |
def run_grounded_sam(input_image):
|
227 |
"""Main function to run GroundingDINO and SAM-HQ"""
|
228 |
+
# Create output directory
|
229 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
230 |
+
text_prompt = 'car'
|
231 |
+
task_type = 'text'
|
232 |
+
box_threshold = 0.3
|
233 |
+
text_threshold = 0.25
|
234 |
+
iou_threshold = 0.8
|
235 |
+
hq_token_only = True
|
236 |
+
|
237 |
+
# Process input image
|
238 |
+
if isinstance(input_image, dict):
|
239 |
+
# Input from gradio sketch component
|
240 |
+
scribble = np.array(input_image["mask"])
|
241 |
+
image_pil = input_image["image"].convert("RGB")
|
242 |
+
else:
|
243 |
+
# Direct image input
|
244 |
+
image_pil = input_image.convert("RGB") if input_image else None
|
245 |
+
scribble = None
|
246 |
|
247 |
+
if image_pil is None:
|
248 |
+
logger.error("No input image provided")
|
249 |
+
return [Image.new('RGB', (400, 300), color='gray')]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
+
# Transform image for GroundingDINO
|
252 |
+
transformed_image = transform_image(image_pil)
|
253 |
+
|
254 |
+
# Load models as needed
|
255 |
+
ModelManager.load_model('groundingdino')
|
256 |
+
size = image_pil.size
|
257 |
+
H, W = size[1], size[0]
|
258 |
+
|
259 |
+
# Run GroundingDINO with provided text
|
260 |
+
boxes_filt, scores, pred_phrases = get_grounding_output(
|
261 |
+
transformed_image, text_prompt, box_threshold, text_threshold
|
262 |
+
)
|
263 |
|
264 |
+
if boxes_filt is not None:
|
265 |
+
# Scale boxes to image dimensions
|
266 |
+
for i in range(boxes_filt.size(0)):
|
267 |
+
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
|
268 |
+
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
|
269 |
+
boxes_filt[i][2:] += boxes_filt[i][:2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
+
# Apply non-maximum suppression if we have multiple boxes
|
272 |
+
if boxes_filt.size(0) > 1:
|
273 |
+
logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
|
274 |
+
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
|
275 |
+
boxes_filt = boxes_filt[nms_idx]
|
276 |
+
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
|
277 |
+
logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
|
278 |
+
|
279 |
+
# Load SAM model
|
280 |
+
ModelManager.load_model('sam')
|
281 |
+
sam_predictor = ModelManager.get_model('sam_predictor')
|
282 |
+
|
283 |
+
# Set image for SAM
|
284 |
+
image = np.array(image_pil)
|
285 |
+
sam_predictor.set_image(image)
|
286 |
+
|
287 |
+
# Run SAM
|
288 |
+
# Use boxes for these task types
|
289 |
+
if boxes_filt.size(0) == 0:
|
290 |
+
logger.warning("No boxes detected")
|
291 |
+
return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
|
292 |
+
|
293 |
+
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
|
294 |
+
|
295 |
+
masks, _, _ = sam_predictor.predict_torch(
|
296 |
+
point_coords=None,
|
297 |
+
point_labels=None,
|
298 |
+
boxes=transformed_boxes,
|
299 |
+
multimask_output=False,
|
300 |
+
hq_token_only=hq_token_only,
|
301 |
+
)
|
302 |
+
|
303 |
+
# Create mask image
|
304 |
+
mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
|
305 |
+
mask_draw = ImageDraw.Draw(mask_image)
|
306 |
+
|
307 |
+
# Draw masks
|
308 |
+
for mask in masks:
|
309 |
+
draw_mask(mask[0].cpu().numpy(), mask_draw)
|
310 |
+
|
311 |
+
# Draw boxes and points on original image
|
312 |
+
image_draw = ImageDraw.Draw(image_pil)
|
313 |
|
314 |
+
for box, label in zip(boxes_filt, pred_phrases):
|
315 |
+
draw_box(box, image_draw, label)
|
316 |
|
317 |
+
return mask_image
|
318 |
|
319 |
+
# except Exception as e:
|
320 |
+
# logger.error(f"Error in run_grounded_sam: {e}")
|
321 |
# # Return original image on error
|
322 |
# if isinstance(input_image, dict) and "image" in input_image:
|
323 |
# return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))]
|