“Transcendental-Programmer” commited on
Commit
2bf5660
·
1 Parent(s): c70fcb3

feat: added requirements and sampling

Browse files
latent_space_explorer/sampling.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from .utils import recursive_find_device, recursive_find_dtype
7
+
8
+ class EncodingSampler:
9
+ """
10
+ Class to sample encodings given low dimensional spatial relationships.
11
+ """
12
+ def __init__(self, encodes):
13
+ self.encodes = encodes
14
+
15
+ def apply_coefs(self, coefs):
16
+ """
17
+ Linear combination of encodings given coefs
18
+ """
19
+ device = recursive_find_device(self.encodes)
20
+ dtype = recursive_find_dtype(self.encodes)
21
+ # NOTE: Convert from float64 first to `dtype` and *then* to `device` to
22
+ # prevent issues with certain devices not supporting f64
23
+ # (*cough cough* Apple)
24
+ coefs = torch.from_numpy(coefs).to(dtype).to(device)
25
+
26
+ def single_apply(encodes):
27
+ if encodes is None:
28
+ return None
29
+ elif len(encodes.shape) == 3:
30
+ return (coefs[:,None,None] * encodes).sum(0)
31
+ elif len(encodes.shape) == 2:
32
+ return (coefs[:,None] * encodes).sum(0)
33
+ else:
34
+ raise ValueError("Encoding Sampler couldn't figure out shape of encodings")
35
+
36
+ if isinstance(self.encodes, list) or isinstance(self.encodes, tuple):
37
+ return list(map(single_apply, self.encodes))
38
+ else:
39
+ return single_apply(self.encodes)
40
+
41
+ @abstractmethod
42
+ def __call__(self, point, other_points):
43
+ """
44
+ :param point: Point in low space representing user input ([2,] array)
45
+ :param other_points: Points in low space representing existing prompts ([N,2] array)
46
+ """
47
+ pass
48
+
49
+ class DistanceSampling(EncodingSampler):
50
+ """
51
+ Sample based on distances between points in low dim space
52
+ """
53
+ def __call__(self, point, other_points):
54
+ coefs = 1. / ((1. + np.linalg.norm(point[None,:] - other_points, axis = 1) ** 2))
55
+ return self.apply_coefs(coefs)
56
+
57
+ class CircleSampling(EncodingSampler):
58
+ """
59
+ Sampler that views all encodings as points on a unit circle
60
+ """
61
+ def __call__(self, point, other_points):
62
+ # Idea: weight of points in same direction should be 1
63
+ # weight of points in opposite should be 0
64
+ cos_sims = point @ other_points.transpose() # [2] x [2, N] -> N
65
+
66
+ # Negative values don't work, but we want something analagous for "negative signals"
67
+ # tanh is like -x for low values, but then caps out at 1
68
+ #cos_sims = np.where(cos_sims<0, np.tanh(cos_sims), cos_sims)
69
+ return self.apply_coefs(cos_sims)
latent_space_explorer/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import math
3
+ import torch
4
+
5
+ def random_circle_init(min_r : float = 0.5, on_edge : bool = False):
6
+ theta = random.uniform(0, 2 * math.pi)
7
+ if on_edge:
8
+ r = 1.0
9
+ else:
10
+ r = random.uniform(min_r, 1.0)
11
+ x = r * math.cos(theta)
12
+ y = r * math.sin(theta)
13
+
14
+ return x, y
15
+
16
+ def recursive_find_dtype(x):
17
+ """
18
+ Assuming x is some list/tuple of things that could be tensors, searches for any tensors and returns dtype
19
+ """
20
+ for i in x:
21
+ if isinstance(i, list):
22
+ res = recursive_find_dtype(i)
23
+ if res is None:
24
+ continue
25
+ else:
26
+ return res
27
+ elif isinstance(i, torch.Tensor):
28
+ return i.dtype
29
+
30
+ def recursive_find_device(x):
31
+ """
32
+ Assuming x is some list/tuple of things that could be tensors, searches for any tensors and returns device
33
+ """
34
+ for i in x:
35
+ if isinstance(i, list):
36
+ res = recursive_find_device(i)
37
+ if res is None:
38
+ continue
39
+ return res
40
+ elif isinstance(i, torch.Tensor):
41
+ return i.device
main.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from latent_space_explorer import GameConfig, LatentSpaceExplorer
2
+
3
+ if __name__ == "__main__":
4
+ config = GameConfig(
5
+ call_every = 100
6
+ )
7
+
8
+ explorer = LatentSpaceExplorer(config)
9
+
10
+ explorer.set_prompts(
11
+ [
12
+ "A photo of a cat",
13
+ "A space-aged ferrari",
14
+ "artwork of the titanic hitting an iceberg",
15
+ "a photo of a dog"
16
+ ]
17
+ )
18
+
19
+ while True:
20
+ explorer.update()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ diffusers
2
+ pygame
3
+ torch
4
+ torchvision