"""

Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)

Copyright(c) 2023 lyuwenyu. All Rights Reserved.

"""

import random

import torch
import torchvision
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F
from PIL import Image

from ...core import register
from .._misc import convert_to_tv_tensor

torchvision.disable_beta_transforms_warning()


@register()
class Mosaic(T.Transform):
    def __init__(self, size, max_size=None) -> None:
        super().__init__()
        self.resize = T.Resize(size=size, max_size=max_size)
        self.crop = T.RandomCrop(size=max_size if max_size else size)

        # TODO: add arg `output_size` for the affine transform
        # self.random_perspective = T.RandomPerspective(distortion_scale=0.5, p=1., )
        # Translation/scale jitter applied to the assembled 2x2 canvas;
        # fill=114 is the gray padding value commonly used in YOLO-style pipelines.
        self.random_affine = T.RandomAffine(
            degrees=0, translate=(0.1, 0.1), scale=(0.5, 1.5), fill=114
        )

    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        image, target, dataset = inputs

        # Use the incoming sample as the first tile and draw three more at
        # random, so all four quadrants of the 2x2 canvas get filled.
        image, target = self.resize(image, target)
        images = [image]
        targets = [target]
        indices = random.choices(range(len(dataset)), k=3)
        for i in indices:
            image, target = dataset.load_item(i)
            image, target = self.resize(image, target)
            images.append(image)
            targets.append(target)

        # Lay the tiles out on a 2x2 grid sized from the first tile.
        # Note: F.get_spatial_size was renamed to F.get_size in torchvision >= 0.16.
        h, w = F.get_spatial_size(images[0])
        offset = [[0, 0], [w, 0], [0, h], [w, h]]
        image = Image.new(mode=images[0].mode, size=(w * 2, h * 2), color=0)
        for i, im in enumerate(images):
            image.paste(im, offset[i])

        # Shift each tile's boxes by its quadrant's top-left corner; repeat(1, 2)
        # turns each [x, y] offset into [x, y, x, y] so it broadcasts over xyxy boxes.
        offset = torch.tensor([[0, 0], [w, 0], [0, h], [w, h]]).repeat(1, 2)
        target = {}
        for k in targets[0]:
            if k == "boxes":
                v = [t[k] + offset[i] for i, t in enumerate(targets)]
            else:
                v = [t[k] for t in targets]

            # Merge per-tile tensors (boxes, labels, ...) into a single tensor.
            if isinstance(v[0], torch.Tensor):
                v = torch.cat(v, dim=0)

            target[k] = v

        if "boxes" in target:
            # target['boxes'] = target['boxes'].clamp(0, 640 * 2 - 1)
            w, h = image.size
            target["boxes"] = convert_to_tv_tensor(
                target["boxes"], "boxes", box_format="xyxy", spatial_size=[h, w]
            )

        if "masks" in target:
            target["masks"] = convert_to_tv_tensor(target["masks"], "masks")

        # Jitter the assembled canvas, then crop back to the configured size.
        image, target = self.random_affine(image, target)
        # image, target = self.resize(image, target)
        image, target = self.crop(image, target)

        return image, target, dataset
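

if __name__ == "__main__":
    # Minimal usage sketch (illustration only; run from the package root with
    # `python -m ...` so the relative imports above resolve). `_ToyDataset` is
    # a hypothetical stand-in for the dataset interface `forward` assumes:
    # `__len__` plus `load_item(i)` returning an (image, target) pair. Tiles
    # are created already at the target size, so the plain-tensor boxes stay
    # valid through `self.resize`; real datasets return tv_tensor boxes that
    # the v2 transforms resize along with the image.
    class _ToyDataset:
        def __len__(self):
            return 16

        def load_item(self, i):
            img = Image.new("RGB", (320, 320), color=(30 * i % 256, 64, 128))
            target = {
                "boxes": torch.tensor([[10.0, 10.0, 100.0, 100.0]]),
                "labels": torch.tensor([i % 3]),
            }
            return img, target

    ds = _ToyDataset()
    img, tgt = ds.load_item(0)
    out_img, out_tgt, _ = Mosaic(size=320)(img, tgt, ds)
    print(out_img.size, out_tgt["boxes"].shape)  # (320, 320), torch.Size([4, 4])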