Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,226 Bytes
46ff99b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from torchvision import transforms
from .transforms import GaussianBlur, make_normalize_transform
logger = logging.getLogger("dinov2")
class DataAugmentationDINO:
    """Multi-crop augmentation producing two global views and N local views of one image.

    Every view is built in two steps.
      1. A geometric step: a random resized crop (bicubic) followed by a horizontal flip
         with p=0.5.
      2. A photometric step, which always begins with the same shared color-jitter /
         random-grayscale stage and always ends with ``self.normalize``.

    The middle of the photometric step differs per view type:
      * global view 1: Gaussian blur with p=1.0
      * global view 2: Gaussian blur with p=0.1, then solarization (threshold=128, p=0.2)
      * local views:   Gaussian blur with p=0.5

    Args:
        global_crops_scale: ``scale`` range forwarded to ``RandomResizedCrop`` for global crops.
        local_crops_scale: ``scale`` range forwarded to ``RandomResizedCrop`` for local crops.
        local_crops_number: how many local crops ``__call__`` produces per image.
        global_crops_size: output size of each global crop (default 224).
        local_crops_size: output size of each local crop (default 96).
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        # Keep the raw configuration on the instance.
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        # Log the configuration between two banner lines.
        separator = "###################################"
        logger.info(separator)
        logger.info("Using data augmentation parameters:")
        for label, setting in (
            ("global_crops_scale", global_crops_scale),
            ("local_crops_scale", local_crops_scale),
            ("local_crops_number", local_crops_number),
            ("global_crops_size", global_crops_size),
            ("local_crops_size", local_crops_size),
        ):
            logger.info(f"{label}: {setting}")
        logger.info(separator)

        def crop_then_flip(size, scale):
            # Geometric step shared by global and local views; only size/scale differ.
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        size,
                        scale=scale,
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomHorizontalFlip(p=0.5),
                ]
            )

        # Random resized crop and flip.
        self.geometric_augmentation_global = crop_then_flip(
            global_crops_size, global_crops_scale
        )
        self.geometric_augmentation_local = crop_then_flip(
            local_crops_size, local_crops_scale
        )

        # Color distortions / blurring. The same Compose object is reused by all three
        # photometric pipelines below.
        jitter = transforms.ColorJitter(
            brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
        )
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply([jitter], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        blur_always = GaussianBlur(p=1.0)
        blur_rarely_then_solarize = transforms.Compose(
            [
                GaussianBlur(p=0.1),
                transforms.RandomSolarize(threshold=128, p=0.2),
            ]
        )
        blur_half_the_time = GaussianBlur(p=0.5)

        # Normalization: tensor conversion, then the project's normalize transform
        # (its statistics come from make_normalize_transform, not visible here).
        self.normalize = transforms.Compose(
            [transforms.ToTensor(), make_normalize_transform()]
        )

        self.global_transfo1 = transforms.Compose(
            [color_jittering, blur_always, self.normalize]
        )
        self.global_transfo2 = transforms.Compose(
            [color_jittering, blur_rarely_then_solarize, self.normalize]
        )
        self.local_transfo = transforms.Compose(
            [color_jittering, blur_half_the_time, self.normalize]
        )

    def __call__(self, image):
        """Return every augmented view of ``image``.

        ``image`` is presumably a PIL image, since ``ToTensor`` runs last in each
        pipeline (TODO: confirm with the dataset that calls this).

        Returns:
            A dict with the following keys, in this order:
              "global_crops": ``[view_1, view_2]`` from ``global_transfo1`` and
                  ``global_transfo2``. Each view comes from its own independent
                  global crop.
              "global_crops_teacher": a separate list holding the same two views.
              "local_crops": ``local_crops_number`` views. Each one is a fresh local
                  crop passed through ``local_transfo``.
              "offsets": always an empty tuple in this implementation.
        """
        # Each view gets its own geometric draw. The order of the random calls is:
        # crop -> photometric for view 1, then the same for view 2, then the local views.
        first_view = self.global_transfo1(self.geometric_augmentation_global(image))
        second_view = self.global_transfo2(self.geometric_augmentation_global(image))
        student_globals = [first_view, second_view]

        views = {
            "global_crops": student_globals,
            # The teacher sees the same two global views, held in a distinct list.
            "global_crops_teacher": list(student_globals),
        }
        views["local_crops"] = [
            self.local_transfo(self.geometric_augmentation_local(image))
            for _ in range(self.local_crops_number)
        ]
        views["offsets"] = ()
        return views
|