File size: 3,032 Bytes
e85fecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""

Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)

Copyright(c) 2023 lyuwenyu. All Rights Reserved.

"""

import os
from typing import Callable, Optional

import torch
import torchvision
import torchvision.transforms.functional as TVF
from PIL import Image
from sympy import im

try:
    from defusedxml.ElementTree import parse as ET_parse
except ImportError:
    from xml.etree.ElementTree import parse as ET_parse

from ...core import register
from .._misc import convert_to_tv_tensor
from ._dataset import DetDataset


@register()
class VOCDetection(torchvision.datasets.VOCDetection, DetDataset):
    """Pascal VOC style detection dataset driven by plain-text index files.

    ``ann_file`` holds one sample per line in the form
    ``<image path> <annotation xml path>`` (separated by a single space, both
    relative to ``root``). ``label_file`` holds one class name per line; the
    zero-based line index is used as the class id.

    The torchvision parent constructor is not invoked: ``images`` and
    ``targets`` are populated directly from ``ann_file``.
    """

    __inject__ = [
        "transforms",
    ]

    def __init__(
        self,
        root: str,
        ann_file: str = "trainval.txt",
        label_file: str = "label_list.txt",
        transforms: Optional[Callable] = None,
    ):
        """Read the sample index and the label list from ``root``.

        Args:
            root: Directory that ``ann_file``, ``label_file`` and every path
                listed inside ``ann_file`` are resolved against.
            ann_file: Index file name, relative to ``root``.
            label_file: Class-name list file name, relative to ``root``.
            transforms: Optional callable invoked as
                ``transforms(image, target, dataset)``; it must return a
                3-tuple whose first two items are the image and the target.
        """
        with open(os.path.join(root, ann_file), "r") as f:
            # Skip blank lines so a stray empty line does not raise IndexError.
            lines = [ln.strip().split(" ") for ln in f if ln.strip()]

        self.images = [os.path.join(root, lin[0]) for lin in lines]
        self.targets = [os.path.join(root, lin[1]) for lin in lines]

        # Join (not concatenate) so ``root`` works with or without a trailing
        # path separator.
        with open(os.path.join(root, label_file), "r") as f:
            labels = [lab.strip() for lab in f.readlines()]

        self.transforms = transforms
        self.labels_map = {lab: i for i, lab in enumerate(labels)}

    def __getitem__(self, index: int):
        """Return ``(image, target)`` for ``index``, applying ``transforms`` if set."""
        image, target = self.load_item(index)
        if self.transforms is not None:
            # The third value returned by the transform pipeline is discarded.
            image, target, _ = self.transforms(image, target, self)
        return image, target

    def load_item(self, index: int):
        """Load one untransformed sample.

        Returns:
            A ``(image, target)`` pair where ``image`` is an RGB PIL image and
            ``target`` is a dict with keys ``image_id``, ``boxes`` (xyxy,
            wrapped by ``convert_to_tv_tensor``), ``labels``, ``area``,
            ``iscrowd`` and ``orig_size`` (``[width, height]``).

        Raises:
            KeyError: If an object name is missing from the label file or a
                bounding box lacks one of ``xmin/ymin/xmax/ymax``.
        """
        # Context manager releases the file handle as soon as pixels are decoded.
        with Image.open(self.images[index]) as img:
            image = img.convert("RGB")

        # ``self.targets`` is what ``__init__`` populates; reading it directly
        # avoids depending on a base-class ``annotations`` alias.
        target = self.parse_voc_xml(ET_parse(self.targets[index]).getroot())

        boxes, names, areas = [], [], []
        for blob in target["annotation"]["object"]:
            bndbox = blob["bndbox"]
            # Read coordinates by name: the XML child order is not guaranteed
            # to be xmin, ymin, xmax, ymax.
            box = [float(bndbox[k]) for k in ("xmin", "ymin", "xmax", "ymax")]
            boxes.append(box)
            names.append(blob["name"])
            areas.append((box[2] - box[0]) * (box[3] - box[1]))

        w, h = image.size
        # ``torch.tensor([])`` would be 1-D, so keep an explicit (0, 4) shape
        # for images without objects.
        box_tensor = torch.tensor(boxes) if boxes else torch.zeros(0, 4)

        output = {}
        output["image_id"] = torch.tensor([index])
        output["area"] = torch.tensor(areas)
        output["boxes"] = convert_to_tv_tensor(
            box_tensor, "boxes", box_format="xyxy", spatial_size=[h, w]
        )
        # Explicit integer dtype: an empty list would otherwise yield float32.
        output["labels"] = torch.tensor(
            [self.labels_map[name] for name in names], dtype=torch.int64
        )
        output["iscrowd"] = torch.zeros(len(names), dtype=torch.int64)
        output["orig_size"] = torch.tensor([w, h])

        return image, output