File size: 3,507 Bytes
e9629ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from PIL import Image
from os import path
import os
import hydra
import numpy as np
import torch
from torch.utils.data import Dataset
from loguru import logger
from tqdm.rich import tqdm
import diskcache as dc


class YoloDataset(Dataset):
    """Dataset yielding ``(PIL image, bounding-box tensor)`` pairs.

    The dataset directory is expected to contain ``<phase>/images`` and
    ``<phase>/labels`` sub-directories, where each image has a same-named
    ``.txt`` label file.  The list of valid ``(image path, boxes)`` records is
    computed once and stored in a ``diskcache`` cache under
    ``<dataset path>/.cache``.

    Args:
        dataset_cfg: Configuration object supporting both ``.get(key, default)``
            and ``.path`` attribute access (e.g. the config hydra passes to
            ``main``).  ``dataset_cfg.get(phase, phase)`` maps the phase to its
            directory name.
        phase: Dataset split to load, ``"train"`` by default.
        transform: Stored on the instance; not applied by ``__getitem__`` here.
        mixup: Stored on the instance; not used inside this class.
    """

    def __init__(self, dataset_cfg: dict, phase="train", transform=None, mixup=None):
        phase_name = dataset_cfg.get(phase, phase)

        self.transform = transform
        self.mixup = mixup
        self.data = self.load_data(dataset_cfg.path, phase_name)

    def load_data(self, dataset_path, phase_name):
        """Return the record list for ``phase_name``, building the cache if needed.

        Args:
            dataset_path: Root directory of the dataset.
            phase_name: Sub-directory (and cache key) of the split to load.

        Returns:
            A list of ``(image path, bbox tensor)`` tuples.
        """
        # The context manager closes the cache even if scanning the dataset raises.
        with dc.Cache(path.join(dataset_path, ".cache")) as cache:
            if phase_name not in cache:
                logger.info("Generate {} Cache", phase_name)

                images_path = path.join(dataset_path, phase_name, "images")
                labels_path = path.join(dataset_path, phase_name, "labels")

                cache[phase_name] = self.filter_data(images_path, labels_path)

            logger.info("Load {} Cache", phase_name)
            data = cache[phase_name]

        return data

    def filter_data(self, images_path, labels_path):
        """Collect every image that has a label file with at least one valid box.

        Args:
            images_path: Directory containing the image files.
            labels_path: Directory containing the ``.txt`` label files.

        Returns:
            A list of ``(image path, bbox tensor)`` tuples, ordered by file name.
        """
        data = []
        images_list = sorted(os.listdir(images_path))
        for image_name in tqdm(images_list):
            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            img_path = path.join(images_path, image_name)
            base_name, _ = path.splitext(image_name)
            label_path = path.join(labels_path, base_name + ".txt")

            # Images without a label file are skipped silently.
            if not path.isfile(label_path):
                continue

            labels = self.load_valid_labels(label_path)
            if labels is not None:
                data.append((img_path, labels))
        # Reuse the listing taken above instead of hitting the filesystem again.
        logger.info("Finish Record {}/{}", len(data), len(images_list))
        return data

    def load_valid_labels(self, label_path):
        """Parse one label file into a stacked tensor of bounding boxes.

        Each non-blank line must hold a class id followed by an even number of
        coordinates (at least two points).  Points with any coordinate outside
        ``[0, 1]`` are discarded before the box is computed.

        Args:
            label_path: Path of the ``.txt`` label file.

        Returns:
            A tensor with one row per valid line, laid out as
            ``[cls, max_0, max_1, min_0, min_1]`` (column-wise max, then min, of
            the remaining points), or ``None`` if no line yields a box.
        """
        bboxes = []
        with open(label_path, "r") as file:
            for line in file:
                tokens = line.split()
                if not tokens:
                    # Blank lines (e.g. a trailing newline) carry no annotation.
                    continue
                try:
                    segment = [float(token) for token in tokens]
                except ValueError:
                    logger.warning(f"Warning: Format error in {label_path}")
                    continue
                # Ensure parts length is odd and more than two points
                if len(segment) % 2 != 1 or len(segment) < 5:
                    logger.warning(f"Warning: Format error in {label_path}")
                    continue
                cls = segment[0]
                points = np.array(segment[1:]).reshape(-1, 2)  # change points to n x 2
                # A point is valid only when BOTH coordinates lie inside [0, 1].
                valid_idx = np.all((points >= 0) & (points <= 1), axis=1)
                points = points[valid_idx]  # only keep valid points
                if points.size == 0:
                    # max()/min() on an empty array would raise; skip the line instead.
                    logger.warning(f"Warning: No valid points in {label_path}")
                    continue

                # NOTE(review): the box is stored as max-then-min; confirm that
                # downstream consumers expect this order rather than min-then-max.
                bbox = torch.tensor([cls, *points.max(axis=0), *points.min(axis=0)])
                bboxes.append(bbox)
        if not bboxes:
            logger.warning(f"Warning: No valid BBox in {label_path}")
            return None
        return torch.stack(bboxes)

    def __getitem__(self, idx):
        """Return the RGB image and its bbox tensor for record ``idx``."""
        img_path, bboxes = self.data[idx]
        # convert() yields an independent image, so the file can be closed right away.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        return img, bboxes

    def __len__(self):
        """Return the number of cached ``(image, boxes)`` records."""
        # Records live in ``self.data``; no ``images`` attribute is ever assigned.
        return len(self.data)


@hydra.main(config_path="../config/data", config_name="coco", version_base=None)
def main(cfg):
    """Instantiate a YoloDataset from the hydra-composed ``cfg``.

    The instance is not kept: constructing it is enough to run the dataset
    loading (and cache generation) performed by ``YoloDataset.__init__``.
    """
    YoloDataset(cfg)


if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly.
    import sys

    # "./" must be on the import path BEFORE the `tools` import below
    # (presumably the script is meant to be launched from the repository root — confirm).
    sys.path.append("./")
    from tools.log_helper import custom_logger

    # custom_logger() comes from tools.log_helper; presumably it configures logging — verify there.
    custom_logger()
    main()