File size: 2,569 Bytes
2bf5660 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from abc import abstractmethod
import torch
import numpy as np
from .utils import recursive_find_device, recursive_find_dtype
class EncodingSampler:
    """
    Base class for samplers that blend stored encodings according to
    low dimensional spatial relationships.

    Subclasses implement `__call__` to derive weights from points in the low
    dimensional space and pass them to `apply_coefs`.

    NOTE: this class does not derive from `abc.ABC`, so the `@abstractmethod`
    marker on `__call__` is not enforced when instantiating.
    """

    def __init__(self, encodes):
        # Either a single encoding or a list/tuple of encodings (entries may
        # be None); `apply_coefs` handles both layouts.
        self.encodes = encodes

    def apply_coefs(self, coefs):
        """
        Linear combination of encodings given coefs.

        :param coefs: numpy array of weights (handed to `torch.from_numpy`),
            indexed as 1-D and broadcast along axis 0 of each encoding
        :return: sum over axis 0 of the weighted encoding; when
            `self.encodes` is a list or tuple, a list with one such result
            per entry (None entries stay None)
        :raises ValueError: if an encoding is neither 2- nor 3-dimensional
        """
        target_device = recursive_find_device(self.encodes)
        target_dtype = recursive_find_dtype(self.encodes)

        # NOTE: Cast to the target dtype *before* moving to the target device.
        # The numpy input is float64 and some devices (Apple's, notably) do
        # not support f64, so the order of these two steps matters.
        weights = torch.from_numpy(coefs)
        weights = weights.to(target_dtype)
        weights = weights.to(target_device)

        def weighted_sum(encoding):
            # Missing encodings are passed through untouched.
            if encoding is None:
                return None

            rank = len(encoding.shape)
            if rank not in (2, 3):
                raise ValueError("Encoding Sampler couldn't figure out shape of encodings")

            # Add one singleton axis per trailing dimension of the encoding
            # so the weights broadcast along its leading axis.
            expanded = weights[(slice(None),) + (None,) * (rank - 1)]
            return (expanded * encoding).sum(0)

        if isinstance(self.encodes, (list, tuple)):
            return [weighted_sum(item) for item in self.encodes]
        return weighted_sum(self.encodes)

    @abstractmethod
    def __call__(self, point, other_points):
        """
        Compute the blended encoding for a user supplied point.

        :param point: Point in low space representing user input ([2,] array)
        :param other_points: Points in low space representing existing prompts ([N,2] array)
        """
        return None
class DistanceSampling(EncodingSampler):
    """
    Sampler whose weights fall off with distance between points in low dim
    space.
    """

    def __call__(self, point, other_points):
        """
        Weight each existing point by 1 / (1 + d^2), where d is its Euclidean
        distance to `point`, then blend the encodings with those weights.

        :param point: Point in low space representing user input ([2,] array)
        :param other_points: Points in low space representing existing prompts ([N,2] array)
        :return: whatever `apply_coefs` returns for the computed weights
        """
        # Broadcast the query point against every existing point.
        offsets = point[None, :] - other_points
        distances = np.linalg.norm(offsets, axis=1)

        # The weight is 1 at zero distance and decays towards 0 as the
        # distance grows; the weights are not normalised to sum to 1.
        weights = 1. / (1. + distances ** 2)
        return self.apply_coefs(weights)
class CircleSampling(EncodingSampler):
    """
    Sampler that views all encodings as points on a unit circle.
    """

    def __call__(self, point, other_points):
        """
        Weight each existing point by its dot product with `point`, then
        blend the encodings with those weights.

        :param point: Point in low space representing user input ([2,] array)
        :param other_points: Points in low space representing existing prompts ([N,2] array)
        :return: whatever `apply_coefs` returns for the computed weights
        """
        # Idea: a point in the same direction should get weight 1 and a point
        # in the opposite direction weight 0.
        # NOTE(review): nothing is normalised here, so the dot product equals
        # the cosine similarity only if the inputs already have unit length,
        # and opposite directions yield negative weights - confirm intent.
        stacked = other_points.T  # [N, 2] -> [2, N]
        cos_sims = point @ stacked  # [2] x [2, N] -> [N]

        # Negative values don't work, but something analogous is wanted for
        # "negative signals". A previously tried idea, currently disabled, was
        # to squash the negative similarities through tanh:
        #   cos_sims = np.where(cos_sims < 0, np.tanh(cos_sims), cos_sims)
        return self.apply_coefs(cos_sims)
|