Transcendental-Programmer commited on
Commit
341b6b4
·
1 Parent(s): e3af1ef

Add unit tests for core modules: latent explorer, attribute directions, and custom loss

Browse files
tests/test_attribute_directions.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ from faceforge_core.attribute_directions import LatentDirectionFinder
4
+
5
+ class TestLatentDirectionFinder(unittest.TestCase):
6
+ def setUp(self):
7
+ # 100 samples, 5D latent
8
+ self.latents = np.random.randn(100, 5)
9
+ self.labels = [0]*50 + [1]*50
10
+ self.finder = LatentDirectionFinder(self.latents)
11
+
12
+ def test_pca_direction(self):
13
+ components, explained = self.finder.pca_direction(n_components=2)
14
+ self.assertEqual(components.shape, (2, 5))
15
+ self.assertEqual(explained.shape, (2,))
16
+
17
+ def test_classifier_direction(self):
18
+ direction = self.finder.classifier_direction(self.labels)
19
+ self.assertEqual(direction.shape, (5,))
20
+ self.assertAlmostEqual(np.linalg.norm(direction), 1.0, places=5)
21
+
22
+ if __name__ == "__main__":
23
+ unittest.main()
tests/test_custom_loss.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import torch
3
+ from faceforge_core.custom_loss import attribute_preserving_loss
4
+
5
+ class TestAttributePreservingLoss(unittest.TestCase):
6
+ def setUp(self):
7
+ self.generated = torch.ones((2, 3, 4, 4))
8
+ self.original = torch.zeros((2, 3, 4, 4))
9
+ self.y_target = torch.ones((2, 1))
10
+ self.attr_predictor = lambda x: torch.ones((2, 1))
11
+
12
+ def test_loss_value(self):
13
+ loss = attribute_preserving_loss(
14
+ self.generated, self.original, self.attr_predictor, self.y_target, lambda_pred=2.0, lambda_recon=3.0
15
+ )
16
+ # pred_loss = 0, recon_loss = mean((1-0)^2) = 1
17
+ self.assertAlmostEqual(loss.item(), 3.0)
18
+
19
+ def test_loss_with_nonzero_pred(self):
20
+ attr_predictor = lambda x: torch.zeros((2, 1))
21
+ loss = attribute_preserving_loss(
22
+ self.generated, self.original, attr_predictor, self.y_target, lambda_pred=2.0, lambda_recon=3.0
23
+ )
24
+ # pred_loss = mean((0-1)^2) = 1, recon_loss = 1
25
+ self.assertAlmostEqual(loss.item(), 5.0)
26
+
27
+ if __name__ == "__main__":
28
+ unittest.main()
tests/test_latent_explorer.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ from faceforge_core.latent_explorer import LatentSpaceExplorer, LatentPoint
4
+
5
+ class TestLatentSpaceExplorer(unittest.TestCase):
6
+ def setUp(self):
7
+ self.explorer = LatentSpaceExplorer()
8
+ self.dummy_encoding = np.array([1.0, 2.0])
9
+
10
+ def test_add_point(self):
11
+ self.explorer.add_point("test", self.dummy_encoding, (0.5, 0.5))
12
+ self.assertEqual(len(self.explorer.points), 1)
13
+ self.assertEqual(self.explorer.points[0].text, "test")
14
+ np.testing.assert_array_equal(self.explorer.points[0].encoding, self.dummy_encoding)
15
+ self.assertEqual(self.explorer.points[0].xy_pos, (0.5, 0.5))
16
+
17
+ def test_delete_point(self):
18
+ self.explorer.add_point("test", self.dummy_encoding)
19
+ self.explorer.delete_point(0)
20
+ self.assertEqual(len(self.explorer.points), 0)
21
+
22
+ def test_modify_point(self):
23
+ self.explorer.add_point("test", self.dummy_encoding)
24
+ new_encoding = np.array([3.0, 4.0])
25
+ self.explorer.modify_point(0, "new", new_encoding)
26
+ self.assertEqual(self.explorer.points[0].text, "new")
27
+ np.testing.assert_array_equal(self.explorer.points[0].encoding, new_encoding)
28
+
29
+ def test_sample_encoding_distance(self):
30
+ self.explorer.add_point("a", np.array([1.0, 0.0]), (0.0, 0.0))
31
+ self.explorer.add_point("b", np.array([0.0, 1.0]), (1.0, 0.0))
32
+ sampled = self.explorer.sample_encoding((0.5, 0.0), mode="distance")
33
+ self.assertIsNotNone(sampled)
34
+ self.assertEqual(sampled.shape, (2,))
35
+
36
+ def test_sample_encoding_circle(self):
37
+ self.explorer.add_point("a", np.array([1.0, 0.0]), (1.0, 0.0))
38
+ self.explorer.add_point("b", np.array([0.0, 1.0]), (0.0, 1.0))
39
+ sampled = self.explorer.sample_encoding((1.0, 1.0), mode="circle")
40
+ self.assertIsNotNone(sampled)
41
+ self.assertEqual(sampled.shape, (2,))
42
+
43
+ if __name__ == "__main__":
44
+ unittest.main()