mlbench123 committed on
Commit
bacea87
·
verified ·
1 Parent(s): 70843b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -39
app.py CHANGED
@@ -86,9 +86,8 @@ paper_model_path = os.path.join(CACHE_DIR, "paper_detector.pt") # You'll need t
86
  u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
87
 
88
  # Global variable for YOLOWorld
89
- yolo_world_global = None
90
- yolo_world_model_path = os.path.join(CACHE_DIR, "yolov8s_world.pt") # Adjust path as needed
91
-
92
 
93
 
94
  # Device configuration
@@ -108,11 +107,7 @@ def ensure_model_files():
108
  shutil.copy("u2netp.pth", u2net_model_path)
109
  else:
110
  raise FileNotFoundError("u2netp.pth model file not found")
111
- if not os.path.exists(yolo_world_model_path):
112
- if os.path.exists("yolov8s_world.pt"): # Adjust to match your file name
113
- shutil.copy("yolov8s_world.pt", yolo_world_model_path)
114
- else:
115
- logger.warning("yolov8s-world.pt model file not found - falling back to full image processing")
116
 
117
  ensure_model_files()
118
 
@@ -134,22 +129,18 @@ def get_paper_detector():
134
  logger.warning("Paper model file not found, using fallback detection")
135
  paper_detector_global = None
136
  return paper_detector_global
137
- def get_yolo_world():
138
- """Lazy load YOLOWorld model"""
139
- global yolo_world_global
140
- if yolo_world_global is None:
141
- logger.info("Loading YOLOWorld model...")
142
- if os.path.exists(yolo_world_model_path):
143
- try:
144
- yolo_world_global = YOLOWorld(yolo_world_model_path)
145
- logger.info("YOLOWorld model loaded successfully")
146
- except Exception as e:
147
- logger.error(f"Failed to load YOLOWorld: {e}")
148
- yolo_world_global = None
149
- else:
150
- logger.warning("YOLOWorld model file not found, will raise error if used")
151
- yolo_world_global = None
152
- return yolo_world_global
153
  def get_u2net():
154
  """Lazy load U2NETP model"""
155
  global u2net_global
@@ -976,46 +967,43 @@ def predict_with_paper(image, paper_size, offset, offset_unit, finger_clearance=
976
  # Mask paper area in input image first
977
  masked_input_image = mask_paper_area_in_image(image, paper_contour)
978
 
979
- # Use YOLOWorld to detect object bounding box
980
- yolo_world = get_yolo_world()
981
- # Lower confidence and add size-based filtering
982
- if yolo_world is None:
983
- logger.warning("YOLOWorld model not available, proceeding with full image")
984
  cropped_image = masked_input_image
985
  crop_offset = (0, 0)
986
  else:
987
- yolo_world.set_classes(["small object", "tool", "item", "component", "part", "piece", "device"])
988
- results = yolo_world.predict(masked_input_image, conf=0.05, verbose=False) # Much lower confidence
989
 
990
  if not results or len(results) == 0 or not hasattr(results[0], 'boxes') or len(results[0].boxes) == 0:
991
- logger.warning("No objects detected by YOLOWorld, proceeding with full image")
992
  cropped_image = masked_input_image
993
  crop_offset = (0, 0)
994
  else:
995
  boxes = results[0].boxes.xyxy.cpu().numpy()
996
  confidences = results[0].boxes.conf.cpu().numpy()
997
 
998
- # Filter out boxes that are too large (likely paper detection)
999
- valid_boxes = []
1000
  image_area = masked_input_image.shape[0] * masked_input_image.shape[1]
 
1001
 
1002
  for i, box in enumerate(boxes):
1003
  x_min, y_min, x_max, y_max = box
1004
  box_area = (x_max - x_min) * (y_max - y_min)
1005
- if box_area < image_area * 0.3: # Reject if larger than 30% of image
 
1006
  valid_boxes.append((i, confidences[i]))
1007
 
1008
  if not valid_boxes:
 
1009
  cropped_image = masked_input_image
1010
  crop_offset = (0, 0)
1011
  else:
1012
  # Get highest confidence valid box
1013
  best_idx = max(valid_boxes, key=lambda x: x[1])[0]
1014
  x_min, y_min, x_max, y_max = map(int, boxes[best_idx])
1015
-
1016
- # Larger margin for small objects
1017
- box_size = min(x_max - x_min, y_max - y_min)
1018
- margin = max(30, int(box_size * 0.3)) # At least 30px margin
1019
 
1020
  # Remove background from cropped image
1021
  orig_size = image.shape[:2]
 
86
  u2net_model_path = os.path.join(CACHE_DIR, "u2netp.pth")
87
 
88
  # Global variable for YOLOWorld
89
+ yolo_v8_global = None
90
+ yolo_v8_model_path = os.path.join(CACHE_DIR, "yolov8s.pt") # Adjust path as needed
 
91
 
92
 
93
  # Device configuration
 
107
  shutil.copy("u2netp.pth", u2net_model_path)
108
  else:
109
  raise FileNotFoundError("u2netp.pth model file not found")
110
+ logger.info("YOLOv8 will auto-download if not present")
 
 
 
 
111
 
112
  ensure_model_files()
113
 
 
129
  logger.warning("Paper model file not found, using fallback detection")
130
  paper_detector_global = None
131
  return paper_detector_global
132
+ def get_yolo_v8():
133
+ """Lazy load YOLOv8 model"""
134
+ global yolo_v8_global
135
+ if yolo_v8_global is None:
136
+ logger.info("Loading YOLOv8 model...")
137
+ try:
138
+ yolo_v8_global = YOLO(yolo_v8_model_path) # Auto-downloads if needed
139
+ logger.info("YOLOv8 model loaded successfully")
140
+ except Exception as e:
141
+ logger.error(f"Failed to load YOLOv8: {e}")
142
+ yolo_v8_global = None
143
+ return yolo_v8_global
 
 
 
 
144
  def get_u2net():
145
  """Lazy load U2NETP model"""
146
  global u2net_global
 
967
  # Mask paper area in input image first
968
  masked_input_image = mask_paper_area_in_image(image, paper_contour)
969
 
970
+ # Use YOLOv8 to detect objects
971
+ yolo_v8 = get_yolo_v8()
972
+ if yolo_v8 is None:
973
+ logger.warning("YOLOv8 model not available, proceeding with full image")
 
974
  cropped_image = masked_input_image
975
  crop_offset = (0, 0)
976
  else:
977
+ # YOLOv8 detects all COCO classes by default
978
+ results = yolo_v8.predict(masked_input_image, conf=0.1, verbose=False)
979
 
980
  if not results or len(results) == 0 or not hasattr(results[0], 'boxes') or len(results[0].boxes) == 0:
981
+ logger.warning("No objects detected by YOLOv8, proceeding with full image")
982
  cropped_image = masked_input_image
983
  crop_offset = (0, 0)
984
  else:
985
  boxes = results[0].boxes.xyxy.cpu().numpy()
986
  confidences = results[0].boxes.conf.cpu().numpy()
987
 
988
+ # Filter out very large boxes (likely paper/background)
 
989
  image_area = masked_input_image.shape[0] * masked_input_image.shape[1]
990
+ valid_boxes = []
991
 
992
  for i, box in enumerate(boxes):
993
  x_min, y_min, x_max, y_max = box
994
  box_area = (x_max - x_min) * (y_max - y_min)
995
+ # Keep boxes that are 5% to 40% of image area
996
+ if 0.05 * image_area < box_area < 0.4 * image_area:
997
  valid_boxes.append((i, confidences[i]))
998
 
999
  if not valid_boxes:
1000
+ logger.warning("No valid objects detected, proceeding with full image")
1001
  cropped_image = masked_input_image
1002
  crop_offset = (0, 0)
1003
  else:
1004
  # Get highest confidence valid box
1005
  best_idx = max(valid_boxes, key=lambda x: x[1])[0]
1006
  x_min, y_min, x_max, y_max = map(int, boxes[best_idx])
 
 
 
 
1007
 
1008
  # Remove background from cropped image
1009
  orig_size = image.shape[:2]