File size: 1,397 Bytes
ab687e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from pytorch_caney.config import get_config

import argparse
import unittest


class TestConfig(unittest.TestCase):
    """Exercise ``get_config`` using the YAML file kept with the tests.

    The path stored by ``setUpClass`` is relative, so it is resolved
    against the current working directory when the tests run.
    """

    @classmethod
    def setUpClass(cls):
        """Store the YAML configuration path shared by all tests."""
        cls.config_yaml_path = 'pytorch_caney/tests/config/test_config.yaml'

    def test_default_config(self):
        # Only ``cfg`` is supplied, so no command-line overrides are given;
        # the expected values presumably come from the YAML file itself.
        cli_args = argparse.Namespace(cfg=self.config_yaml_path)
        loaded = get_config(cli_args)

        # DATA section
        data_section = loaded.DATA
        self.assertEqual(data_section.BATCH_SIZE, 128)
        self.assertEqual(data_section.DATASET, 'MODIS')

        # MODEL section
        model_section = loaded.MODEL
        self.assertEqual(model_section.TYPE, 'swinv2')
        self.assertEqual(model_section.NAME, 'test_config')

        # TRAIN section
        self.assertEqual(loaded.TRAIN.EPOCHS, 800)

    def test_custom_config(self):
        # Supply extra arguments alongside ``cfg`` and check that each one
        # shows up on the resulting configuration object.
        overrides = {
            'cfg': self.config_yaml_path,
            'batch_size': 64,
            'dataset': 'CustomDataset',
            'data_paths': ['solongandthanksforallthefish'],
        }
        cli_args = argparse.Namespace(**overrides)
        loaded = get_config(cli_args)

        data_section = loaded.DATA
        self.assertEqual(data_section.BATCH_SIZE, 64)
        self.assertEqual(data_section.DATASET, 'CustomDataset')
        self.assertEqual(
            data_section.DATA_PATHS,
            ['solongandthanksforallthefish'],
        )


# Script entry point: when this file is executed directly (rather than
# imported), hand control to unittest's command-line test runner.
if __name__ == '__main__':
    unittest.main()