dschandra commited on
Commit
977fa4b
·
verified ·
1 Parent(s): d42c2d6

Update detector.py

Browse files
Files changed (1) hide show
  1. detector.py +2 -0
detector.py CHANGED
@@ -5,10 +5,12 @@ import torch.serialization
5
 
6
  class LBWDetector:
7
  def __init__(self, model_path='best.pt'):
 
8
  with torch.serialization.safe_globals([torch.nn.modules.container.Sequential]):
9
  self.model = YOLO(model_path)
10
 
11
  def detect_objects(self, frame):
 
12
  results = self.model.predict(source=frame, conf=0.3, save=False, verbose=False)
13
  detections = results[0].boxes.data.cpu().numpy() # x1, y1, x2, y2, conf, class
14
  return detections, results[0].names
 
5
 
6
  class LBWDetector:
7
  def __init__(self, model_path='best.pt'):
8
+ """Initialize YOLO model with safe globals for PyTorch 2.6+."""
9
  with torch.serialization.safe_globals([torch.nn.modules.container.Sequential]):
10
  self.model = YOLO(model_path)
11
 
12
  def detect_objects(self, frame):
13
+ """Detect objects in a frame and return bounding boxes and class names."""
14
  results = self.model.predict(source=frame, conf=0.3, save=False, verbose=False)
15
  detections = results[0].boxes.data.cpu().numpy() # x1, y1, x2, y2, conf, class
16
  return detections, results[0].names