Yuantao Feng
Decoupling dataloader and benchmark runner (#16)
0199e9f
import os
import numpy as np
import cv2 as cv
from .base_dataloader import _BaseImageLoader
from ..factory import DATALOADERS
@DATALOADERS.register
class ClassificationImageLoader(_BaseImageLoader):
    """Image loader for classification benchmarks.

    Iterating yields ``(filename, image)`` pairs for every entry of
    ``self._files`` read from ``self._path`` with ``cv.imread``.
    ``self._path``, ``self._files`` and ``self._sizes`` are not set in this
    class; presumably they come from ``_BaseImageLoader`` (not visible here).

    If ``[0, 0]`` is one of ``self._sizes``, each image is yielded once, as
    read (after the optional RGB conversion), with no resize or crop.
    Otherwise, each image is yielded once per size. Every size is resized
    from the *original* image, then optionally center-cropped.

    Keyword arguments handled here (all others go to ``_BaseImageLoader``):
        toRGB: if truthy, convert images from BGR to RGB with
            ``cv.COLOR_BGR2RGB``. Defaults to False.
        centerCrop: side length of a square center crop applied after
            resizing. ``None`` (the default) or 0 disables cropping.
    """

    def __init__(self, **kwargs):
        # Take the subclass-specific options out before forwarding, so the
        # base initializer only receives the options it is meant to handle.
        self._to_rgb = kwargs.pop('toRGB', False)
        self._center_crop = kwargs.pop('centerCrop', None)
        super().__init__(**kwargs)

    def _toRGB(self, image):
        """Return ``image`` converted from BGR to RGB channel order."""
        return cv.cvtColor(image, cv.COLOR_BGR2RGB)

    def _centerCrop(self, image):
        """Return the central ``centerCrop`` x ``centerCrop`` region of ``image``.

        Raises:
            ValueError: if either image dimension is smaller than the crop.
        """
        height, width = image.shape[:2]
        size = int(self._center_crop)
        if height < size or width < size:
            raise ValueError(
                'Cannot center-crop a {}x{} image to {}x{}'.format(
                    width, height, size, size))
        top = (height - size) // 2
        left = (width - size) // 2
        # Axis 0 is rows (height) and axis 1 is columns (width). Ending at
        # start + size keeps the output exactly size x size even when the
        # leftover margin is odd.
        return image[top:top + size, left:left + size]

    def __iter__(self):
        # Loop-invariant: a [0, 0] entry means "yield images as read".
        keep_original = [0, 0] in self._sizes
        for filename in self._files:
            filepath = os.path.join(self._path, filename)
            image = cv.imread(filepath)
            if image is None:
                # cv.imread signals failure by returning None instead of raising.
                raise OSError('Failed to read image: {}'.format(filepath))
            if self._to_rgb:
                image = self._toRGB(image)
            if keep_original:
                yield filename, image
                continue
            for size in self._sizes:
                # Use a separate name so each size is derived from the original
                # image rather than from the previous size's resized/cropped result.
                resized = cv.resize(image, size)
                if self._center_crop:
                    resized = self._centerCrop(resized)
                yield filename, resized