✨ [Add] A stream dataloader for webcam
Browse files- requirements.txt +1 -0
- yolo/config/config.py +1 -0
- yolo/config/task/dataset/demo.yaml +0 -3
- yolo/config/task/inference.yaml +2 -3
- yolo/tools/data_loader.py +100 -2
- yolo/tools/drawer.py +2 -1
- yolo/tools/solver.py +22 -10
- yolo/utils/bounding_box_utils.py +1 -1
requirements.txt
CHANGED
|
@@ -3,6 +3,7 @@ graphviz
|
|
| 3 |
hydra-core
|
| 4 |
loguru
|
| 5 |
numpy
|
|
|
|
| 6 |
Pillow
|
| 7 |
pytest
|
| 8 |
pyyaml
|
|
|
|
| 3 |
hydra-core
|
| 4 |
loguru
|
| 5 |
numpy
|
| 6 |
+
opencv-python
|
| 7 |
Pillow
|
| 8 |
pytest
|
| 9 |
pyyaml
|
yolo/config/config.py
CHANGED
|
@@ -113,6 +113,7 @@ class NMSConfig:
|
|
| 113 |
@dataclass
|
| 114 |
class InferenceConfig:
|
| 115 |
task: str
|
|
|
|
| 116 |
nms: NMSConfig
|
| 117 |
|
| 118 |
|
|
|
|
| 113 |
@dataclass
|
| 114 |
class InferenceConfig:
|
| 115 |
task: str
|
| 116 |
+
source: Union[str, int]
|
| 117 |
nms: NMSConfig
|
| 118 |
|
| 119 |
|
yolo/config/task/dataset/demo.yaml
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
path: demo
|
| 2 |
-
|
| 3 |
-
auto_download:
|
|
|
|
|
|
|
|
|
|
|
|
yolo/config/task/inference.yaml
CHANGED
|
@@ -1,11 +1,10 @@
|
|
| 1 |
task: inference
|
| 2 |
-
|
| 3 |
-
- dataset: demo
|
| 4 |
data:
|
| 5 |
batch_size: 16
|
| 6 |
shuffle: False
|
| 7 |
pin_memory: True
|
| 8 |
data_augment: {}
|
| 9 |
nms:
|
| 10 |
-
min_confidence: 0.
|
| 11 |
min_iou: 0.5
|
|
|
|
| 1 |
task: inference
|
| 2 |
+
source: demo/images/inference/image.png
|
|
|
|
| 3 |
data:
|
| 4 |
batch_size: 16
|
| 5 |
shuffle: False
|
| 6 |
pin_memory: True
|
| 7 |
data_augment: {}
|
| 8 |
nms:
|
| 9 |
+
min_confidence: 0.1
|
| 10 |
min_iou: 0.5
|
yolo/tools/data_loader.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
| 1 |
import os
|
| 2 |
from os import path
|
| 3 |
-
from
|
|
|
|
|
|
|
| 4 |
|
|
|
|
| 5 |
import hydra
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
from loguru import logger
|
| 9 |
from PIL import Image
|
| 10 |
from rich.progress import track
|
|
|
|
| 11 |
from torch.utils.data import DataLoader, Dataset
|
| 12 |
from torchvision.transforms import functional as TF
|
| 13 |
-
from tqdm.rich import tqdm
|
| 14 |
|
| 15 |
from yolo.config.config import Config, TrainConfig
|
| 16 |
from yolo.tools.data_augmentation import (
|
|
@@ -199,12 +202,107 @@ class YoloDataLoader(DataLoader):
|
|
| 199 |
|
| 200 |
|
| 201 |
def create_dataloader(config: Config):
|
|
|
|
|
|
|
|
|
|
| 202 |
if config.task.dataset.auto_download:
|
| 203 |
prepare_dataset(config.task.dataset)
|
| 204 |
|
| 205 |
return YoloDataLoader(config)
|
| 206 |
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
| 209 |
def main(cfg):
|
| 210 |
dataloader = create_dataloader(cfg)
|
|
|
|
| 1 |
import os
|
| 2 |
from os import path
|
| 3 |
+
from queue import Empty, Queue
|
| 4 |
+
from threading import Event, Thread
|
| 5 |
+
from typing import Generator, List, Optional, Tuple, Union
|
| 6 |
|
| 7 |
+
import cv2
|
| 8 |
import hydra
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
from loguru import logger
|
| 12 |
from PIL import Image
|
| 13 |
from rich.progress import track
|
| 14 |
+
from torch import Tensor
|
| 15 |
from torch.utils.data import DataLoader, Dataset
|
| 16 |
from torchvision.transforms import functional as TF
|
|
|
|
| 17 |
|
| 18 |
from yolo.config.config import Config, TrainConfig
|
| 19 |
from yolo.tools.data_augmentation import (
|
|
|
|
| 202 |
|
| 203 |
|
| 204 |
def create_dataloader(config: Config):
    """Return the dataloader appropriate for the configured task.

    Inference uses the streaming loader (single image, folder, video file,
    or live webcam/RTMP source); every other task uses the standard
    dataset-backed YOLO loader, downloading the dataset first if requested.
    """
    if config.task.task == "inference":
        return StreamDataLoader(config)

    if config.task.dataset.auto_download:
        prepare_dataset(config.task.dataset)
    return YoloDataLoader(config)
|
| 212 |
|
| 213 |
|
| 214 |
+
class StreamDataLoader:
    """Iterate model-ready frame tensors from an image, folder, video file, or live stream.

    Two modes, selected from ``config.task.source``:

    * **Stream** (``source`` is an int camera index or an ``rtmp://`` URL):
      frames are pulled synchronously from ``cv2.VideoCapture`` in
      ``__next__``.
    * **File** (single image, image folder, or video file): a background
      thread decodes frames into ``self.queue`` and ``__next__`` drains it,
      stopping after the queue stays empty for one second.

    Each yielded tensor has shape ``(1, 3, H, W)`` (batch of one), produced
    by the same ``AugmentationComposer`` resize used in training.
    """

    def __init__(self, config: Config):
        self.source = config.task.source
        self.running = True
        # int source == webcam index; otherwise only rtmp:// counts as a live stream
        self.is_stream = isinstance(self.source, int) or self.source.lower().startswith("rtmp://")

        # Empty augmentation list: inference only resizes/pads to the model size.
        self.transform = AugmentationComposer([], config.image_size[0])
        self.stop_event = Event()

        if self.is_stream:
            self.cap = cv2.VideoCapture(self.source)
        else:
            self.queue = Queue()
            # daemon=True: an abandoned loader must not block interpreter exit
            # (fix — the original non-daemon thread could hang shutdown).
            self.thread = Thread(target=self.load_source, daemon=True)
            self.thread.start()

    def load_source(self):
        """Producer entry point: dispatch on what kind of path ``self.source`` is."""
        if os.path.isdir(self.source):  # image folder
            self.load_image_folder(self.source)
        elif any(self.source.lower().endswith(ext) for ext in [".mp4", ".avi", ".mkv"]):  # Video file
            self.load_video_file(self.source)
        else:  # Single image
            self.process_image(self.source)

    def load_image_folder(self, folder):
        """Walk *folder* recursively and enqueue every supported image."""
        for root, _, files in os.walk(folder):
            for file in files:
                if self.stop_event.is_set():
                    # Fix: `break` only exited the inner loop and os.walk kept
                    # traversing the rest of the tree; return stops the walk.
                    return
                if any(file.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".bmp"]):
                    self.process_image(os.path.join(root, file))

    def process_image(self, image_path):
        """Load one image file and enqueue its tensor.

        Raises:
            ValueError: if the file cannot be opened or decoded.
        """
        # Fix: Image.open never returns None — it raises on failure — so the
        # old `if image is None` check was dead code. Translate the PIL error
        # into the ValueError the original intended to raise.
        try:
            image = Image.open(image_path).convert("RGB")
        except (OSError, ValueError) as e:
            raise ValueError(f"Error loading image: {image_path}") from e
        self.process_frame(image)

    def load_video_file(self, video_path):
        """Decode *video_path* frame by frame until EOF or ``stop()``."""
        cap = cv2.VideoCapture(video_path)
        while self.running:
            ret, frame = cap.read()
            if not ret:
                break
            self.process_frame(frame)
        cap.release()

    def cv2_to_tensor(self, frame: np.ndarray) -> Tensor:
        """Convert a BGR OpenCV frame to a normalized (1, 3, H, W) float tensor."""
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_float = frame_rgb.astype("float32") / 255.0
        return torch.from_numpy(frame_float).permute(2, 0, 1)[None]

    def process_frame(self, frame):
        """Resize/pad one frame and hand it to the consumer (queue or current slot)."""
        if isinstance(frame, np.ndarray):
            # OpenCV delivers BGR; PIL/transform pipeline expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
        # Dummy (0, 5) box tensor: the composer API requires boxes, inference has none.
        frame, _ = self.transform(frame, torch.zeros(0, 5))
        frame = TF.to_tensor(frame)[None]
        if not self.is_stream:
            self.queue.put(frame)
        else:
            self.current_frame = frame

    def __iter__(self) -> Generator[Tensor, None, None]:
        return self

    def __next__(self) -> Tensor:
        if self.is_stream:
            ret, frame = self.cap.read()
            if not ret:
                self.stop()
                raise StopIteration
            self.process_frame(frame)
            return self.current_frame
        else:
            try:
                # 1 s timeout doubles as end-of-input detection once the
                # producer thread has finished filling the queue.
                return self.queue.get(timeout=1)
            except Empty:
                raise StopIteration from None

    def stop(self):
        """Stop producing frames and release the capture / producer thread."""
        self.running = False
        if self.is_stream:
            self.cap.release()
        else:
            # Fix: also signal stop_event so load_image_folder actually stops;
            # previously join(timeout=1) could return with the walk still running.
            self.stop_event.set()
            self.thread.join(timeout=1)

    def __len__(self):
        # Queue depth for file sources; streams are unbounded, report 0.
        return self.queue.qsize() if not self.is_stream else 0
|
| 304 |
+
|
| 305 |
+
|
| 306 |
@hydra.main(config_path="../config", config_name="config", version_base=None)
|
| 307 |
def main(cfg):
|
| 308 |
dataloader = create_dataloader(cfg)
|
yolo/tools/drawer.py
CHANGED
|
@@ -14,6 +14,7 @@ def draw_bboxes(
|
|
| 14 |
*,
|
| 15 |
scaled_bbox: bool = True,
|
| 16 |
save_path: str = "",
|
|
|
|
| 17 |
):
|
| 18 |
"""
|
| 19 |
Draw bounding boxes on an image.
|
|
@@ -46,7 +47,7 @@ def draw_bboxes(
|
|
| 46 |
draw.rectangle(shape, outline="red", width=3)
|
| 47 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
| 48 |
|
| 49 |
-
save_image_path = os.path.join(save_path,
|
| 50 |
img.save(save_image_path) # Save the image with annotations
|
| 51 |
logger.info(f"💾 Saved visualize image at {save_image_path}")
|
| 52 |
return img
|
|
|
|
| 14 |
*,
|
| 15 |
scaled_bbox: bool = True,
|
| 16 |
save_path: str = "",
|
| 17 |
+
save_name: str = "visualize.png",
|
| 18 |
):
|
| 19 |
"""
|
| 20 |
Draw bounding boxes on an image.
|
|
|
|
| 47 |
draw.rectangle(shape, outline="red", width=3)
|
| 48 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
| 49 |
|
| 50 |
+
save_image_path = os.path.join(save_path, save_name)
|
| 51 |
img.save(save_image_path) # Save the image with annotations
|
| 52 |
logger.info(f"💾 Saved visualize image at {save_image_path}")
|
| 53 |
return img
|
yolo/tools/solver.py
CHANGED
|
@@ -7,6 +7,7 @@ from torch.cuda.amp import GradScaler, autocast
|
|
| 7 |
|
| 8 |
from yolo.config.config import Config, TrainConfig
|
| 9 |
from yolo.model.yolo import YOLO
|
|
|
|
| 10 |
from yolo.tools.drawer import draw_bboxes
|
| 11 |
from yolo.tools.loss_functions import get_loss_function
|
| 12 |
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
|
|
@@ -103,15 +104,26 @@ class ModelTester:
|
|
| 103 |
self.nms = cfg.task.nms
|
| 104 |
self.save_path = save_path
|
| 105 |
|
| 106 |
-
def solve(self, dataloader):
|
| 107 |
logger.info("👀 Start Inference!")
|
| 108 |
|
| 109 |
-
|
| 110 |
-
images
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from yolo.config.config import Config, TrainConfig
|
| 9 |
from yolo.model.yolo import YOLO
|
| 10 |
+
from yolo.tools.data_loader import StreamDataLoader
|
| 11 |
from yolo.tools.drawer import draw_bboxes
|
| 12 |
from yolo.tools.loss_functions import get_loss_function
|
| 13 |
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
|
|
|
|
| 104 |
self.nms = cfg.task.nms
|
| 105 |
self.save_path = save_path
|
| 106 |
|
| 107 |
+
def solve(self, dataloader: StreamDataLoader):
    """Run inference on every frame from *dataloader* and save annotated images.

    For each frame: move it to the model device, forward pass under
    ``no_grad``, decode anchors to boxes, apply NMS, and write the drawn
    result as ``frame{idx:03d}.png`` under ``self.save_path``.

    The dataloader is always stopped exactly once, on success, interrupt,
    or error (fix: the original called ``dataloader.stop()`` in both except
    handlers *and* unconditionally afterwards, stopping twice).
    """
    logger.info("👀 Start Inference!")

    try:
        for idx, images in enumerate(dataloader):
            images = images.to(self.device)
            with torch.no_grad():
                raw_output = self.model(images)
            # raw_output[0][3:] holds the prediction heads; with_logits applies sigmoid.
            predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
            nms_out = bbox_nms(predict, self.nms)
            draw_bboxes(
                images[0], nms_out[0], scaled_bbox=False, save_path=self.save_path, save_name=f"frame{idx:03d}.png"
            )
    except KeyboardInterrupt:
        logger.error("Interrupted by user")
        dataloader.stop_event.set()
    except Exception as e:
        logger.error(e)
        dataloader.stop_event.set()
        # Fix: bare `raise` instead of `raise e` preserves the original traceback.
        raise
    finally:
        dataloader.stop()
|
yolo/utils/bounding_box_utils.py
CHANGED
|
@@ -303,7 +303,7 @@ def bbox_nms(predicts: Tensor, nms_cfg: NMSConfig):
|
|
| 303 |
batch_idx, *_ = torch.where(valid_mask)
|
| 304 |
nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
|
| 305 |
predicts_nms = []
|
| 306 |
-
for idx in range(
|
| 307 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
| 308 |
|
| 309 |
predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
|
|
|
|
| 303 |
batch_idx, *_ = torch.where(valid_mask)
|
| 304 |
nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
|
| 305 |
predicts_nms = []
|
| 306 |
+
for idx in range(predicts.size(0)):
|
| 307 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
| 308 |
|
| 309 |
predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
|