File size: 2,804 Bytes
ce2b58f
5567efd
9538100
ce2b58f
b0c7a24
8386bf1
9538100
b0c7a24
8386bf1
b0c7a24
 
 
8386bf1
b0c7a24
ce2b58f
b0c7a24
ce2b58f
b0c7a24
ce2b58f
 
b0c7a24
cdbafa3
ce2b58f
b0c7a24
4cfdbcf
b0c7a24
4cfdbcf
 
b0c7a24
4cfdbcf
0bd515d
0c0955d
ce2b58f
b0c7a24
 
 
 
 
 
 
cdbafa3
b0c7a24
 
e999761
b0c7a24
ce2b58f
b0c7a24
 
 
 
 
 
 
ce2b58f
b0c7a24
 
 
ce2b58f
b0c7a24
ce2b58f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import logging
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
import os
import shutil

# Setup logger.
# NOTE(review): basicConfig configures the *root* logger as a side effect of
# importing this module; consider moving it to the application entry point.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Dev convenience: wipe the local weights cache on import so every run starts
# from a fresh download. The path is relative to the current working
# directory, so this deletes whatever lives there. Set the environment
# variable DETECTION_CLEAR_WEIGHTS_CACHE=0 to skip the wipe (e.g. in
# production, or when cached weights should be reused).
if os.environ.get("DETECTION_CLEAR_WEIGHTS_CACHE", "1") != "0":
    shutil.rmtree("models/detection/weights", ignore_errors=True)

class ObjectDetector:
    """Object detector backed by an Ultralytics YOLO model from the Hugging Face Hub.

    Weights are fetched with ``hf_hub_download`` into
    ``models/detection/weights`` (relative to the current working directory)
    and loaded with ``ultralytics.YOLO``.
    """

    # Supported model keys -> (Hugging Face repo id, weights filename).
    # NOTE(review): Ultralytics names its YOLO11 checkpoints "yolo11{n,s,m,l,x}.pt";
    # "yolov11b.pt" may not exist in that repo, and the repo-id casing differs
    # between entries — confirm both against the Hub before relying on them.
    _HF_MAP = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    def __init__(self, model_key="yolov8n", device="cpu", *, force_download=True):
        """Download YOLO weights from the Hugging Face Hub and load them.

        Args:
            model_key (str): one of 'yolov8n', 'yolov8s', 'yolov8l', 'yolov11b'.
                Matching is case-insensitive and any '.pt' is stripped.
            device (str): device used for inference, e.g. 'cpu' or 'cuda'.
            force_download (bool): re-download the weights even if they are
                already in the cache directory. Defaults to True (the original
                behaviour); pass False to reuse cached weights.

        Raises:
            ValueError: if *model_key* is not a supported model.
        """
        resolved_key = model_key.lower().replace(".pt", "")

        try:
            repo_id, filename = self._HF_MAP[resolved_key]
        except KeyError:
            raise ValueError(f"Unsupported model key: {resolved_key}") from None

        # 📥 Download from HF Hub
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="models/detection/weights",
            force_download=force_download,
        )

        self.device = device
        self.model = YOLO(weights_path)
        # Log only once the model has actually been constructed.
        logger.info(f"✅ Loaded YOLO model: {resolved_key} from {weights_path}")

    def predict(self, image: Image.Image, conf_threshold=0.25):
        """Run object detection on *image*.

        Args:
            image: PIL image to run the model on.
            conf_threshold (float): minimum confidence for a box to be kept;
                passed to Ultralytics as its ``conf`` inference argument.

        Returns:
            list[dict]: one dict per detected box, with keys
            "class_name" (str), "confidence" (float) and
            "bbox" (list of 4 floats in Ultralytics ``xyxy`` order).
        """
        logger.info("Running object detection")
        # Both the threshold and the device were previously accepted but never
        # forwarded, so non-default values silently had no effect.
        results = self.model(image, conf=conf_threshold, device=self.device)
        detections = []
        for r in results:
            for box in r.boxes:
                detections.append({
                    "class_name": r.names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })
        logger.info(f"Detected {len(detections)} objects")
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """Return a copy of *image* with detection boxes and labels blended on top.

        Boxes and "<class_name> <confidence>" labels are drawn in red on a copy
        of the image, which is then combined with the original via
        ``Image.blend``. *alpha* is the weight of the annotated copy: 0 returns
        the original pixels, 1 shows the annotations fully opaque.

        Args:
            image: image to annotate; it is not modified.
            detections: iterable of dicts shaped like :meth:`predict` output
                (keys "bbox", "class_name", "confidence").
            alpha (float): blend factor between original and annotated copy.

        Returns:
            PIL.Image.Image: a new annotated image.
        """
        # Image.blend rejects bilevel ("1") and palette ("P"/"PA") images with
        # ValueError("image has wrong mode"); convert them to RGB up front.
        if image.mode in ("1", "P", "PA"):
            image = image.convert("RGB")
        overlay = image.copy()
        canvas = ImageDraw.Draw(overlay)
        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            canvas.rectangle(bbox, outline="red", width=2)
            canvas.text((bbox[0], bbox[1]), label, fill="red")
        return Image.blend(image, overlay, alpha)