Transcendental-Programmer
commited on
Commit
·
341b6b4
1
Parent(s):
e3af1ef
Add unit tests for core modules: latent explorer, attribute directions, and custom loss
Browse files- tests/test_attribute_directions.py +23 -0
- tests/test_custom_loss.py +28 -0
- tests/test_latent_explorer.py +44 -0
tests/test_attribute_directions.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import numpy as np
|
3 |
+
from faceforge_core.attribute_directions import LatentDirectionFinder
|
4 |
+
|
5 |
+
class TestLatentDirectionFinder(unittest.TestCase):
|
6 |
+
def setUp(self):
|
7 |
+
# 100 samples, 5D latent
|
8 |
+
self.latents = np.random.randn(100, 5)
|
9 |
+
self.labels = [0]*50 + [1]*50
|
10 |
+
self.finder = LatentDirectionFinder(self.latents)
|
11 |
+
|
12 |
+
def test_pca_direction(self):
|
13 |
+
components, explained = self.finder.pca_direction(n_components=2)
|
14 |
+
self.assertEqual(components.shape, (2, 5))
|
15 |
+
self.assertEqual(explained.shape, (2,))
|
16 |
+
|
17 |
+
def test_classifier_direction(self):
|
18 |
+
direction = self.finder.classifier_direction(self.labels)
|
19 |
+
self.assertEqual(direction.shape, (5,))
|
20 |
+
self.assertAlmostEqual(np.linalg.norm(direction), 1.0, places=5)
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
unittest.main()
|
tests/test_custom_loss.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import torch
|
3 |
+
from faceforge_core.custom_loss import attribute_preserving_loss
|
4 |
+
|
5 |
+
class TestAttributePreservingLoss(unittest.TestCase):
|
6 |
+
def setUp(self):
|
7 |
+
self.generated = torch.ones((2, 3, 4, 4))
|
8 |
+
self.original = torch.zeros((2, 3, 4, 4))
|
9 |
+
self.y_target = torch.ones((2, 1))
|
10 |
+
self.attr_predictor = lambda x: torch.ones((2, 1))
|
11 |
+
|
12 |
+
def test_loss_value(self):
|
13 |
+
loss = attribute_preserving_loss(
|
14 |
+
self.generated, self.original, self.attr_predictor, self.y_target, lambda_pred=2.0, lambda_recon=3.0
|
15 |
+
)
|
16 |
+
# pred_loss = 0, recon_loss = mean((1-0)^2) = 1
|
17 |
+
self.assertAlmostEqual(loss.item(), 3.0)
|
18 |
+
|
19 |
+
def test_loss_with_nonzero_pred(self):
|
20 |
+
attr_predictor = lambda x: torch.zeros((2, 1))
|
21 |
+
loss = attribute_preserving_loss(
|
22 |
+
self.generated, self.original, attr_predictor, self.y_target, lambda_pred=2.0, lambda_recon=3.0
|
23 |
+
)
|
24 |
+
# pred_loss = mean((0-1)^2) = 1, recon_loss = 1
|
25 |
+
self.assertAlmostEqual(loss.item(), 5.0)
|
26 |
+
|
27 |
+
if __name__ == "__main__":
|
28 |
+
unittest.main()
|
tests/test_latent_explorer.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import numpy as np
|
3 |
+
from faceforge_core.latent_explorer import LatentSpaceExplorer, LatentPoint
|
4 |
+
|
5 |
+
class TestLatentSpaceExplorer(unittest.TestCase):
|
6 |
+
def setUp(self):
|
7 |
+
self.explorer = LatentSpaceExplorer()
|
8 |
+
self.dummy_encoding = np.array([1.0, 2.0])
|
9 |
+
|
10 |
+
def test_add_point(self):
|
11 |
+
self.explorer.add_point("test", self.dummy_encoding, (0.5, 0.5))
|
12 |
+
self.assertEqual(len(self.explorer.points), 1)
|
13 |
+
self.assertEqual(self.explorer.points[0].text, "test")
|
14 |
+
np.testing.assert_array_equal(self.explorer.points[0].encoding, self.dummy_encoding)
|
15 |
+
self.assertEqual(self.explorer.points[0].xy_pos, (0.5, 0.5))
|
16 |
+
|
17 |
+
def test_delete_point(self):
|
18 |
+
self.explorer.add_point("test", self.dummy_encoding)
|
19 |
+
self.explorer.delete_point(0)
|
20 |
+
self.assertEqual(len(self.explorer.points), 0)
|
21 |
+
|
22 |
+
def test_modify_point(self):
|
23 |
+
self.explorer.add_point("test", self.dummy_encoding)
|
24 |
+
new_encoding = np.array([3.0, 4.0])
|
25 |
+
self.explorer.modify_point(0, "new", new_encoding)
|
26 |
+
self.assertEqual(self.explorer.points[0].text, "new")
|
27 |
+
np.testing.assert_array_equal(self.explorer.points[0].encoding, new_encoding)
|
28 |
+
|
29 |
+
def test_sample_encoding_distance(self):
|
30 |
+
self.explorer.add_point("a", np.array([1.0, 0.0]), (0.0, 0.0))
|
31 |
+
self.explorer.add_point("b", np.array([0.0, 1.0]), (1.0, 0.0))
|
32 |
+
sampled = self.explorer.sample_encoding((0.5, 0.0), mode="distance")
|
33 |
+
self.assertIsNotNone(sampled)
|
34 |
+
self.assertEqual(sampled.shape, (2,))
|
35 |
+
|
36 |
+
def test_sample_encoding_circle(self):
|
37 |
+
self.explorer.add_point("a", np.array([1.0, 0.0]), (1.0, 0.0))
|
38 |
+
self.explorer.add_point("b", np.array([0.0, 1.0]), (0.0, 1.0))
|
39 |
+
sampled = self.explorer.sample_encoding((1.0, 1.0), mode="circle")
|
40 |
+
self.assertIsNotNone(sampled)
|
41 |
+
self.assertEqual(sampled.shape, (2,))
|
42 |
+
|
43 |
+
if __name__ == "__main__":
|
44 |
+
unittest.main()
|