YOLO / utils /dataloader.py
henry000's picture
✨ [Add] Dataset for loading image, labels
e9629ef
raw
history blame
3.51 kB
from PIL import Image
from os import path
import os
import hydra
import numpy as np
import torch
from torch.utils.data import Dataset
from loguru import logger
from tqdm.rich import tqdm
import diskcache as dc
class YoloDataset(Dataset):
def __init__(self, dataset_cfg: dict, phase="train", transform=None, mixup=None):
phase_name = dataset_cfg.get(phase, phase)
self.transform = transform
self.mixup = mixup
self.data = self.load_data(dataset_cfg.path, phase_name)
def load_data(self, dataset_path, phase_name):
cache = dc.Cache(path.join(dataset_path, ".cache"))
if phase_name not in cache:
logger.info("Generate {} Cache", phase_name)
images_path = path.join(dataset_path, phase_name, "images")
labels_path = path.join(dataset_path, phase_name, "labels")
cache[phase_name] = self.filter_data(images_path, labels_path)
logger.info("Load {} Cache", phase_name)
data = cache[phase_name]
cache.close()
return data
def filter_data(self, images_path, labels_path):
data = []
valid_input = 0
images_list = os.listdir(images_path)
images_list.sort()
for image_name in tqdm(images_list):
if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
continue
img_path = path.join(images_path, image_name)
base_name, _ = path.splitext(image_name)
label_name = base_name + ".txt"
label_path = path.join(labels_path, label_name)
if not path.isfile(label_path):
# logger.warning(f"Warning: No label file for {label_path}")
continue
labels = self.load_valid_labels(label_path)
if labels is not None:
data.append((img_path, labels))
valid_input += 1
logger.info("Finish Record {}/{}", valid_input, len(os.listdir(images_path)))
return data
def load_valid_labels(self, label_path):
bboxes = []
with open(label_path, "r") as file:
for line in file:
segment = list(map(float, line.strip().split()))
cls = segment[0]
# Ensure parts length is odd and more than two points
if len(segment) % 2 != 1 or len(segment) < 5:
logger.warning(f"Warning: Format error in {label_path}")
continue
points = np.array(segment[1:]).reshape(-1, 2) # change points to n x 2
valid_idx = np.any((points <= 1) | (points >= 0), axis=1) # filter outlier points
points = points[valid_idx] # only keep valid points
bbox = torch.tensor([cls, *points.max(axis=0), *points.min(axis=0)])
bboxes.append(bbox)
if not bboxes:
logger.warning(f"Warning: No valid BBox in {label_path}")
return None
return torch.stack(bboxes)
def __getitem__(self, idx):
img_path, bboxes = self.data[idx]
img = Image.open(img_path).convert("RGB")
return img, bboxes
def __len__(self):
return len(self.images)
@hydra.main(config_path="../config/data", config_name="coco", version_base=None)
def main(cfg):
dataset = YoloDataset(cfg)
if __name__ == "__main__":
import sys
sys.path.append("./")
from tools.log_helper import custom_logger
custom_logger()
main()