|
from pytorch_caney.config import get_config |
|
from pytorch_caney.data.transforms import SimmimTransform |
|
from pytorch_caney.data.transforms import TensorResizeTransform |
|
|
|
import argparse |
|
import unittest |
|
import torch |
|
import numpy as np |
|
|
|
|
|
class TestTransforms(unittest.TestCase):
    """Shape/type smoke tests for the ``pytorch_caney`` data transforms."""

    # Channel count of the synthetic HWC test images fed to the transforms.
    NUM_CHANNELS = 7

    # Fixed seed so the random test inputs are identical on every run.
    SEED = 42

    def setUp(self):
        """Load the test config and create a seeded random generator.

        NOTE(review): ``config_path`` is relative to the current working
        directory, so these tests presumably must be launched from the
        repository root (the directory containing ``pytorch_caney/``).
        """
        config_path = 'pytorch_caney/' + \
            'tests/config/test_config.yaml'
        args = argparse.Namespace(cfg=config_path)
        self.config = get_config(args)
        # Test-local generator: reproducible and leaves global numpy
        # random state untouched.
        self.rng = np.random.default_rng(self.SEED)

    def _random_image(self):
        """Return a random float array of shape (IMG_SIZE, IMG_SIZE, C)."""
        size = self.config.DATA.IMG_SIZE
        return self.rng.standard_normal((size, size, self.NUM_CHANNELS))

    def _expected_shape(self, channels):
        """Return the channels-first shape (channels, IMG_SIZE, IMG_SIZE)."""
        size = self.config.DATA.IMG_SIZE
        return (channels, size, size)

    def test_simmim_transform(self):
        """SimmimTransform returns a channels-first tensor and a numpy mask."""
        transform = SimmimTransform(self.config)

        img_transformed, mask = transform(self._random_image())

        self.assertIsInstance(img_transformed, torch.Tensor)
        self.assertEqual(img_transformed.shape,
                         self._expected_shape(self.NUM_CHANNELS))
        self.assertIsInstance(mask, np.ndarray)

    def test_tensor_resize_transform(self):
        """TensorResizeTransform handles both HWC images and 2-D targets."""
        transform = TensorResizeTransform(self.config)
        size = self.config.DATA.IMG_SIZE

        img = self._random_image()
        # Integer values in [0, 5) on a 2-D grid (presumably a class-label
        # map for segmentation targets).
        target = self.rng.integers(0, 5, size=(size, size))

        img_transformed = transform(img)
        target_transformed = transform(target)

        self.assertIsInstance(img_transformed, torch.Tensor)
        self.assertEqual(img_transformed.shape,
                         self._expected_shape(self.NUM_CHANNELS))

        # A 2-D input is expected to gain a leading singleton channel axis.
        self.assertIsInstance(target_transformed, torch.Tensor)
        self.assertEqual(target_transformed.shape, self._expected_shape(1))
|
|
|
|
|
if __name__ == '__main__':
    # Running this file directly (``python <file>.py``) executes every
    # TestCase defined in this module.
    unittest.main(module=__name__)
|
|