Caleb Spradlin
initial commit
ab687e7
from pytorch_caney.config import get_config
from pytorch_caney.data.transforms import SimmimTransform
from pytorch_caney.data.transforms import TensorResizeTransform
import argparse
import unittest
import torch
import numpy as np
class TestTransforms(unittest.TestCase):
    """Smoke tests for the pytorch_caney data transforms.

    Each test pushes random channels-last ``numpy`` input through a
    transform and checks only the type and shape of the result; the
    output values themselves are not inspected.
    """

    # Number of bands in the synthetic input images used by every test.
    _NUM_CHANNELS = 7

    def setUp(self):
        """Build ``self.config`` from the shared test YAML before each test.

        NOTE(review): the path is relative, so these tests presumably
        have to be launched from the repository root — confirm.
        """
        cli_args = argparse.Namespace(
            cfg='pytorch_caney/tests/config/test_config.yaml')
        self.config = get_config(cli_args)

    def _random_image(self):
        """Return a random float array of shape (IMG_SIZE, IMG_SIZE, 7)."""
        side = self.config.DATA.IMG_SIZE
        return np.random.randn(side, side, self._NUM_CHANNELS)

    def test_simmim_transform(self):
        """SimmimTransform yields a (C, H, W) tensor and an ndarray mask."""
        side = self.config.DATA.IMG_SIZE
        simmim = SimmimTransform(self.config)

        tensor_out, mask = simmim(self._random_image())

        self.assertIsInstance(tensor_out, torch.Tensor)
        self.assertEqual(tensor_out.shape,
                         (self._NUM_CHANNELS, side, side))
        self.assertIsInstance(mask, np.ndarray)

    def test_tensor_resize_transform(self):
        """TensorResizeTransform turns (H, W, C) and (H, W) input into CHW."""
        side = self.config.DATA.IMG_SIZE
        resize = TensorResizeTransform(self.config)

        image = self._random_image()
        # Integer-valued 2-D target with values drawn from [0, 5).
        target_map = np.random.randint(0, 5, size=(side, side))

        image_out = resize(image)
        target_out = resize(target_map)

        # A 2-D target is expected to come back with a single leading
        # channel axis, while the image keeps its band count.
        expectations = ((image_out, self._NUM_CHANNELS), (target_out, 1))
        for out, channels in expectations:
            self.assertIsInstance(out, torch.Tensor)
            self.assertEqual(out.shape, (channels, side, side))
if __name__ == "__main__":
    # Allow running this module directly instead of via a test collector.
    unittest.main()