Spaces:
Running
Running
| import torch | |
| from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN | |
def test_unetdiscriminatorsn():
    """Test arch: UNetDiscriminatorSN.

    Runs a forward pass on a random ``(1, 3, 32, 32)`` float32 tensor and
    asserts that the output has shape ``(1, 1, 32, 32)``. The check is done on
    the CPU first and repeated on the GPU when CUDA is available.
    """
    expected_shape = (1, 1, 32, 32)

    # Build the model first, then the input, both on the CPU.
    discriminator = UNetDiscriminatorSN(
        num_in_ch=3,
        num_feat=4,
        skip_connection=True,
    )
    sample = torch.rand((1, 3, 32, 32), dtype=torch.float32)

    # Forward pass on the CPU.
    cpu_result = discriminator(sample)
    assert cpu_result.shape == expected_shape

    # Nothing more to verify on machines without a CUDA device.
    if not torch.cuda.is_available():
        return

    # Forward pass on the GPU with the same model and input moved to CUDA.
    discriminator.cuda()
    gpu_result = discriminator(sample.cuda())
    assert gpu_result.shape == expected_shape