DurgaDeepak commited on
Commit
cdbafa3
·
verified ·
1 Parent(s): b3e594d

Update models/detection/detector.py

Browse files
Files changed (1) hide show
  1. models/detection/detector.py +7 -36
models/detection/detector.py CHANGED
@@ -9,18 +9,13 @@ import shutil
9
  logger = logging.getLogger(__name__)
10
  shutil.rmtree("models/detection/weights", ignore_errors=True)
11
  class ObjectDetector:
12
- """
13
- Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.
14
- """
15
-
16
  def __init__(self, model_key="yolov5n", device="cpu"):
17
  """
18
- Initialize the Object Detection model.
19
 
20
  Args:
21
- model_key (str): Model identifier as defined in model_downloader.py.
22
- weights_dir (str): Directory to store/download model weights.
23
- device (str): Inference device ("cpu" or "cuda").
24
  """
25
  alias_map = {
26
  "yolov5n-seg": "yolov5n",
@@ -28,39 +23,15 @@ class ObjectDetector:
28
  "yolov8s": "yolov8s",
29
  "yolov8l": "yolov8l",
30
  "yolov11b": "yolov11b",
31
- "rtdetr": "rtdetr"
32
  }
33
 
34
  raw_key = model_key.lower()
35
- model_key = alias_map.get(raw_key, raw_key) #to make model_key case-insensitive
36
- repo_map = {
37
- "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
38
- "yolov5s": ("ultralytics/yolov5", "yolov5s.pt"),
39
- "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
40
- "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
41
- "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
42
- "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
43
- "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth")
44
- }
45
-
46
- if model_key not in repo_map:
47
- raise ValueError(f"Unsupported model_key: {model_key}")
48
-
49
- repo_id, filename = repo_map[model_key]
50
-
51
- weights_path = hf_hub_download(
52
- repo_id=repo_id,
53
- filename=filename,
54
- cache_dir="models/detection/weights",
55
- force_download=True #Clear cache
56
- )
57
-
58
 
59
  self.device = device
60
- print("Loading weights from:", weights_path)
61
- self.model = YOLO(weights_path)
62
- print("Model object type:", type(self.model))
63
- print("Model class string:", self.model.__class__)
64
 
65
 
66
  def predict(self, image: Image.Image):
 
9
  logger = logging.getLogger(__name__)
10
  shutil.rmtree("models/detection/weights", ignore_errors=True)
11
  class ObjectDetector:
 
 
 
 
12
  def __init__(self, model_key="yolov5n", device="cpu"):
13
  """
14
+ Initialize the Object Detection model using Ultralytics YOLO registry.
15
 
16
  Args:
17
+ model_key (str): Model name supported by ultralytics, e.g. 'yolov5n', 'yolov8s', etc.
18
+ device (str): 'cpu' or 'cuda'
 
19
  """
20
  alias_map = {
21
  "yolov5n-seg": "yolov5n",
 
23
  "yolov8s": "yolov8s",
24
  "yolov8l": "yolov8l",
25
  "yolov11b": "yolov11b",
 
26
  }
27
 
28
  raw_key = model_key.lower()
29
+ resolved_key = alias_map.get(raw_key, raw_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  self.device = device
32
+ self.model = YOLO(resolved_key)
33
+ logger.info(f" Ultralytics YOLO model '{resolved_key}' initialized on {device}")
34
+
 
35
 
36
 
37
  def predict(self, image: Image.Image):