mlbench123 commited on
Commit
af7ebad
·
verified ·
1 Parent(s): 2a0abec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -2
app.py CHANGED
@@ -42,6 +42,11 @@ PAPER_SIZES = {
42
  "US Letter": {"width": 215.9, "height": 279.4}
43
  }
44
 
 
 
 
 
 
45
  # Custom Exception Classes
46
  class TimeoutReachedError(Exception):
47
  pass
@@ -520,7 +525,7 @@ def remove_bg_u2netp(image: np.ndarray) -> np.ndarray:
520
  logger.error(f"Error in U2NETP background removal: {e}")
521
  raise
522
 
523
- def remove_bg(image: np.ndarray) -> np.ndarray:
524
  """Remove background using BiRefNet model for main objects"""
525
  try:
526
  birefnet_model = get_birefnet()
@@ -547,6 +552,36 @@ def remove_bg(image: np.ndarray) -> np.ndarray:
547
  except Exception as e:
548
  logger.error(f"Error in BiRefNet background removal: {e}")
549
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
 
551
  # def exclude_paper_area(mask: np.ndarray, paper_contour: np.ndarray, expansion_factor: float = 1.2) -> np.ndarray:
552
  # """
@@ -1239,7 +1274,8 @@ def predict_with_paper(image, paper_size, offset, offset_unit, finger_clearance=
1239
 
1240
  # Remove background from main objects
1241
  orig_size = image.shape[:2]
1242
- objects_mask = remove_bg(image)
 
1243
  processed_size = objects_mask.shape[:2]
1244
 
1245
  # Resize mask to match original image
 
42
  "US Letter": {"width": 215.9, "height": 279.4}
43
  }
44
 
45
+ def get_yolo_world():
46
+ """Lazy load YOLO-World model"""
47
+ yolo_world = YOLOWorld('yolov8s-world.pt') # Smaller model for efficiency
48
+ return yolo_world
49
+
50
  # Custom Exception Classes
51
  class TimeoutReachedError(Exception):
52
  pass
 
525
  logger.error(f"Error in U2NETP background removal: {e}")
526
  raise
527
 
528
+ def remove_bg_original(image: np.ndarray) -> np.ndarray:
529
  """Remove background using BiRefNet model for main objects"""
530
  try:
531
  birefnet_model = get_birefnet()
 
552
  except Exception as e:
553
  logger.error(f"Error in BiRefNet background removal: {e}")
554
  raise
555
+ def remove_bg(image: np.ndarray, paper_contour: np.ndarray) -> np.ndarray:
556
+ """Remove background using either BiRefNet or YOLO-World based on object size"""
557
+ # Calculate paper area percentage
558
+ paper_area = cv2.contourArea(paper_contour)
559
+ image_area = image.shape[0] * image.shape[1]
560
+ paper_ratio = paper_area / image_area
561
+
562
+ # If paper takes up most of the image (small object case)
563
+ if paper_ratio > 0.7:
564
+ logger.info("Using YOLO-World for small object detection")
565
+ try:
566
+ model = get_yolo_world()
567
+ results = model.predict(image, conf=0.3, verbose=False)
568
+
569
+ # Create blank mask
570
+ mask = np.zeros(image.shape[:2], dtype=np.uint8)
571
+
572
+ # Draw all detected objects
573
+ for box in results[0].boxes.xyxy.cpu().numpy():
574
+ x1, y1, x2, y2 = map(int, box)
575
+ mask[y1:y2, x1:x2] = 255
576
+
577
+ return mask
578
+
579
+ except Exception as e:
580
+ logger.warning(f"YOLO-World failed, falling back to BiRefNet: {e}")
581
+
582
+ # Default case - use BiRefNet/U2NET
583
+ logger.info("Using BiRefNet for standard object detection")
584
+ return remove_bg_original(image) # Your existing BiRefNet implementation
585
 
586
  # def exclude_paper_area(mask: np.ndarray, paper_contour: np.ndarray, expansion_factor: float = 1.2) -> np.ndarray:
587
  # """
 
1274
 
1275
  # Remove background from main objects
1276
  orig_size = image.shape[:2]
1277
+ # objects_mask = remove_bg(image)
1278
+ objects_mask = remove_bg(image, paper_contour)
1279
  processed_size = objects_mask.shape[:2]
1280
 
1281
  # Resize mask to match original image