import os | |
import numpy as np | |
import cv2 as cv | |
from .base_dataloader import _BaseImageLoader | |
from ..factory import DATALOADERS | |
class ClassificationImageLoader(_BaseImageLoader): | |
def __init__(self, **kwargs): | |
super().__init__(**kwargs) | |
self._to_rgb = kwargs.pop('toRGB', False) | |
self._center_crop = kwargs.pop('centerCrop', None) | |
def _toRGB(self, image): | |
return cv.cvtColor(image, cv.COLOR_BGR2RGB) | |
def _centerCrop(self, image): | |
h, w, _ = image.shape | |
w_crop = int((w - self._center_crop) / 2.) | |
assert w_crop >= 0 | |
h_crop = int((h - self._center_crop) / 2.) | |
assert h_crop >= 0 | |
return image[w_crop:w-w_crop, h_crop:h-h_crop, :] | |
def __iter__(self): | |
for filename in self._files: | |
image = cv.imread(os.path.join(self._path, filename)) | |
if self._to_rgb: | |
image = self._toRGB(image) | |
if [0, 0] in self._sizes: | |
yield filename, image | |
else: | |
for size in self._sizes: | |
image = cv.resize(image, size) | |
if self._center_crop: | |
image = self._centerCrop(image) | |
yield filename, image |