File size: 2,368 Bytes
ab687e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from pytorch_caney.config import get_config
from pytorch_caney.data.transforms import SimmimTransform
from pytorch_caney.data.transforms import TensorResizeTransform
import argparse
import unittest
import torch
import numpy as np
class TestTransforms(unittest.TestCase):
    """Shape and type checks for SimmimTransform and TensorResizeTransform.

    Each test builds a synthetic input of spatial size ``DATA.IMG_SIZE``
    (taken from the test config) with the channel axis last. It then checks
    that the transformed output is a ``torch.Tensor`` with the channel axis
    moved to the front.
    """

    # Relative path to the test configuration. It is presumably resolved
    # against the current working directory by get_config, so the tests
    # should be run from the directory that contains ``pytorch_caney/``.
    CONFIG_PATH = 'pytorch_caney/tests/config/test_config.yaml'

    # Number of bands in the synthetic input images.
    NUM_CHANNELS = 7

    # Exclusive upper bound for the synthetic integer target values.
    TARGET_HIGH = 5

    # Fixed seed so the synthetic inputs are reproducible across runs.
    # NOTE(review): this does not seed any randomness used inside the
    # transforms themselves (e.g. mask generation).
    SEED = 42

    def setUp(self):
        """Load the test config and create a seeded input generator."""
        args = argparse.Namespace(cfg=self.CONFIG_PATH)
        self.config = get_config(args)
        self.img_size = self.config.DATA.IMG_SIZE
        # A local Generator keeps test inputs deterministic without
        # mutating NumPy's global random state for other tests.
        self.rng = np.random.default_rng(self.SEED)

    def _random_image(self):
        """Return a float64 ndarray of shape (IMG_SIZE, IMG_SIZE, NUM_CHANNELS)."""
        return self.rng.standard_normal(
            (self.img_size, self.img_size, self.NUM_CHANNELS))

    def test_simmim_transform(self):
        """SimmimTransform returns a (C, H, W) tensor plus an ndarray mask."""
        transform = SimmimTransform(self.config)
        img = self._random_image()

        img_transformed, mask = transform(img)

        self.assertIsInstance(img_transformed, torch.Tensor)
        self.assertEqual(
            img_transformed.shape,
            (self.NUM_CHANNELS, self.img_size, self.img_size))
        self.assertIsInstance(mask, np.ndarray)

    def test_tensor_resize_transform(self):
        """TensorResizeTransform maps (H, W, C) -> (C, H, W) and (H, W) -> (1, H, W)."""
        transform = TensorResizeTransform(self.config)
        # Both inputs are NumPy ndarrays; the transform converts them to tensors.
        img = self._random_image()
        target = self.rng.integers(
            0, self.TARGET_HIGH, size=(self.img_size, self.img_size))

        img_transformed = transform(img)
        target_transformed = transform(target)

        self.assertIsInstance(img_transformed, torch.Tensor)
        self.assertEqual(
            img_transformed.shape,
            (self.NUM_CHANNELS, self.img_size, self.img_size))
        self.assertIsInstance(target_transformed, torch.Tensor)
        # A 2-D target gains a leading singleton axis.
        self.assertEqual(
            target_transformed.shape,
            (1, self.img_size, self.img_size))
if __name__ == '__main__':
    # Running this file directly discovers and runs the TestCase classes
    # defined in this module; ``module='__main__'`` is unittest.main's
    # documented default, spelled out here for clarity.
    unittest.main(module='__main__')
|