File size: 2,368 Bytes
ab687e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from pytorch_caney.config import get_config
from pytorch_caney.data.transforms import SimmimTransform
from pytorch_caney.data.transforms import TensorResizeTransform

import argparse
import unittest
import torch
import numpy as np


class TestTransforms(unittest.TestCase):
    """Shape/type smoke tests for the pytorch_caney data transforms.

    Inputs are synthetic random arrays drawn from a seeded generator so that
    any failure is reproducible from run to run.
    """

    # Channel count of the synthetic channels-last inputs. The assertions
    # below expect the transforms to move this axis to the front
    # (H, W, C) -> (C, H, W).
    NUM_CHANNELS = 7

    # Fixed seed for the per-test random generator (reproducible inputs).
    SEED = 42

    def setUp(self):
        """Load the shared test configuration and create a seeded RNG."""
        # NOTE(review): this path is relative, so it is presumably resolved
        # against the current working directory by get_config; the tests
        # therefore appear to expect being run from the repository root.
        config_path = 'pytorch_caney/' + \
            'tests/config/test_config.yaml'
        args = argparse.Namespace(cfg=config_path)
        self.config = get_config(args)
        self.img_size = self.config.DATA.IMG_SIZE
        # A fresh generator per test keeps each test independent of the
        # order in which the test methods run.
        self.rng = np.random.default_rng(self.SEED)

    def test_simmim_transform(self):
        """SimmimTransform returns a (C, H, W) tensor and a numpy mask."""
        transform = SimmimTransform(self.config)

        # Channels-last float64 numpy array (not a tensor).
        img = self.rng.standard_normal(
            (self.img_size, self.img_size, self.NUM_CHANNELS))

        img_transformed, mask = transform(img)

        self.assertIsInstance(img_transformed, torch.Tensor)
        self.assertEqual(
            tuple(img_transformed.shape),
            (self.NUM_CHANNELS, self.img_size, self.img_size))
        self.assertIsInstance(mask, np.ndarray)

    def test_tensor_resize_transform(self):
        """TensorResizeTransform returns (C, H, W) tensors.

        A 3-D channels-last image keeps its channel count at the front; a
        2-D integer label map gains a leading singleton channel axis.
        """
        transform = TensorResizeTransform(self.config)

        # Channels-last float64 numpy array (not a tensor).
        img = self.rng.standard_normal(
            (self.img_size, self.img_size, self.NUM_CHANNELS))

        # 2-D integer label map with class ids drawn from [0, 5).
        target = self.rng.integers(
            0, 5, size=(self.img_size, self.img_size))

        img_transformed = transform(img)
        target_transformed = transform(target)

        self.assertIsInstance(img_transformed, torch.Tensor)
        self.assertEqual(
            tuple(img_transformed.shape),
            (self.NUM_CHANNELS, self.img_size, self.img_size))

        self.assertIsInstance(target_transformed, torch.Tensor)
        self.assertEqual(
            tuple(target_transformed.shape),
            (1, self.img_size, self.img_size))


if __name__ == '__main__':
    # Let this module be executed directly as a script, in addition to
    # being discovered by a test runner; verbosity=1 is unittest's default.
    unittest.main(verbosity=1)