DurgaDeepak commited on
Commit
9538100
·
verified ·
1 Parent(s): ce2b58f

Update models/detection/detector.py

Browse files
Files changed (1) hide show
  1. models/detection/detector.py +8 -6
models/detection/detector.py CHANGED
@@ -1,9 +1,7 @@
1
- import os
2
- import numpy as np
3
- from PIL import Image, ImageDraw
4
  import logging
 
5
  from ultralytics import YOLO
6
- from utils.model_downloader import download_model_if_needed
7
 
8
  logger = logging.getLogger(__name__)
9
 
@@ -12,7 +10,7 @@ class ObjectDetector:
12
  Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.
13
  """
14
 
15
- def __init__(self, model_key="yolov5n-seg", device="cpu"):
16
  """
17
  Initialize the Object Detection model.
18
 
@@ -21,10 +19,14 @@ class ObjectDetector:
21
  weights_dir (str): Directory to store/download model weights.
22
  device (str): Inference device ("cpu" or "cuda").
23
  """
 
24
  repo_map = {
25
  "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
26
  "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
27
- # Add more if needed
 
 
 
28
  }
29
 
30
  if model_key not in repo_map:
 
 
 
 
1
  import logging
2
+ from huggingface_hub import hf_hub_download
3
  from ultralytics import YOLO
4
+
5
 
6
  logger = logging.getLogger(__name__)
7
 
 
10
  Generalized Object Detection Wrapper for YOLOv5, YOLOv8, and future variants.
11
  """
12
 
13
+ def __init__(self, model_key="yolov5n", device="cpu"):
14
  """
15
  Initialize the Object Detection model.
16
 
 
19
  weights_dir (str): Directory to store/download model weights.
20
  device (str): Inference device ("cpu" or "cuda").
21
  """
22
+ model_key = model_key.lower() #to make model_key case-insensitive
23
  repo_map = {
24
  "yolov5n": ("ultralytics/yolov5", "yolov5n.pt"),
25
  "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
26
+ "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
27
+ "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
28
+ "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
29
+ "rtdetr": ("IDEA-Research/RT-DETR", "rtdetr_r50vd_detr.pth")
30
  }
31
 
32
  if model_key not in repo_map: