# Source: faceforge/faceforge_core/sampling.py (GitHub file view, 2.57 kB)
# Author: Transcendental-Programmer
# Commit e3af1ef: Refactor core logic: move and modularize all latent space,
#   sampling, and utility code into faceforge_core/
from abc import abstractmethod
import torch
import numpy as np
from .utils import recursive_find_device, recursive_find_dtype
class EncodingSampler:
    """
    Class to sample encodings given low dimensional spatial relationships.

    Subclasses implement ``__call__`` to turn a query point and the points of
    the existing encodings (both in a low-dimensional space) into one
    coefficient per encoding, then delegate to :meth:`apply_coefs` to form the
    weighted sum of the stored encodings.

    NOTE: this class does not inherit from ``abc.ABC``, so ``@abstractmethod``
    on ``__call__`` does not prevent instantiation; the base ``__call__``
    raises ``NotImplementedError`` instead of silently returning ``None``.
    """

    def __init__(self, encodes):
        """
        :param encodes: A single encoding tensor, ``None``, or a list/tuple of
            them. Each tensor must be 2-D or 3-D with the individual encodings
            stacked along dim 0 (one entry per coefficient). Presumably torch
            tensors; device/dtype are read via recursive_find_device/dtype.
        """
        self.encodes = encodes

    def apply_coefs(self, coefs):
        """
        Linear combination of encodings given coefs.

        :param coefs: 1-D coefficients, one per encoding (``np.ndarray``,
            array-like, or ``torch.Tensor``).
        :return: ``sum_i coefs[i] * encodes[i]``. If ``self.encodes`` is a
            list/tuple, a list with one result per element (``None`` entries
            stay ``None``); otherwise a single result.
        :raises ValueError: if coefs is not 1-D, an encoding is not 2-D/3-D,
            or the number of coefs differs from the number of encodings.
        """
        device = recursive_find_device(self.encodes)
        dtype = recursive_find_dtype(self.encodes)
        weights = self._coefs_to_tensor(coefs, dtype, device)
        if isinstance(self.encodes, (list, tuple)):
            return [self._weighted_sum(weights, encodes) for encodes in self.encodes]
        return self._weighted_sum(weights, self.encodes)

    @staticmethod
    def _coefs_to_tensor(coefs, dtype, device):
        """Convert ``coefs`` to a 1-D tensor with the given dtype and device."""
        if isinstance(coefs, torch.Tensor):
            tensor = coefs
        else:
            # torch.from_numpy only accepts ndarrays and rejects negative
            # strides (e.g. ``arr[::-1]``); ascontiguousarray handles both and
            # also accepts plain lists/scalars.
            tensor = torch.from_numpy(np.ascontiguousarray(coefs))
        # NOTE: Convert to `dtype` first and *then* to `device` to prevent
        # issues with certain devices not supporting f64
        # (*cough cough* Apple)
        tensor = torch.atleast_1d(tensor.to(dtype).to(device))
        if tensor.dim() != 1:
            raise ValueError(
                f"Encoding Sampler expected 1-D coefs, got shape {tuple(tensor.shape)}"
            )
        return tensor

    @staticmethod
    def _weighted_sum(coefs, encodes):
        """Return ``sum_i coefs[i] * encodes[i]`` for one encoding (``None`` passes through)."""
        if encodes is None:
            return None
        ndim = len(encodes.shape)
        if ndim not in (2, 3):
            raise ValueError("Encoding Sampler couldn't figure out shape of encodings")
        if coefs.shape[0] != encodes.shape[0]:
            # Without this check a length-1 side would silently broadcast
            # and produce a wrong blend instead of an error.
            raise ValueError(
                f"Encoding Sampler got {coefs.shape[0]} coefs for "
                f"{encodes.shape[0]} encodings"
            )
        # Append singleton dims so coefs[i] scales every element of encodes[i]
        # (equivalent to coefs[:, None] for 2-D and coefs[:, None, None] for 3-D).
        weights = coefs.reshape(-1, *([1] * (ndim - 1)))
        return (weights * encodes).sum(0)

    @abstractmethod
    def __call__(self, point, other_points):
        """
        Compute per-encoding coefficients for ``point`` and blend the encodings.

        :param point: Point in low space representing user input ([2,] array)
        :param other_points: Points in low space representing existing prompts ([N,2] array)
        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")
class DistanceSampling(EncodingSampler):
    """
    Sample based on distances between points in low dim space.

    Each encoding is weighted by ``1 / (1 + d**2)``, where ``d`` is the
    Euclidean distance from the query point to that encoding's point. Weights
    lie in (0, 1], are largest for nearby encodings, and are not normalised
    to sum to 1.
    """

    def __call__(self, point, other_points):
        """
        :param point: Query point, shape ``(D,)``; array-likes are accepted.
        :param other_points: One point per encoding, shape ``(N, D)``.
        :return: The encodings blended by ``apply_coefs`` with the weights above.
        """
        # asarray lets callers pass lists/tuples, which `point[None, :]` rejects.
        point = np.asarray(point)
        other_points = np.asarray(other_points)
        # Squared distances directly; no need to sqrt and then square again.
        sq_dists = np.sum((other_points - point[None, :]) ** 2, axis=1)
        coefs = 1.0 / (1.0 + sq_dists)
        return self.apply_coefs(coefs)
class CircleSampling(EncodingSampler):
    """
    Sampler that views all encodings as points on a unit circle.

    Each encoding is weighted by the dot product of the query point with that
    encoding's point. This equals the cosine similarity only when both are
    unit vectors. No normalisation happens here, so this presumably relies on
    the caller passing points on the unit circle (TODO confirm).
    """

    def __call__(self, point, other_points):
        """
        :param point: Query point, shape ``(2,)``; array-likes are accepted.
        :param other_points: One point per encoding, shape ``(N, 2)``.
        :return: The encodings blended by ``apply_coefs`` with the dot products.
        """
        # asarray lets callers pass lists/tuples, which lack `.transpose()`.
        point = np.asarray(point)
        other_points = np.asarray(other_points)
        # Intent: points in the same direction get weight 1, opposite ones 0.
        # NOTE(review): a plain dot product gives -1 for opposite directions,
        # so those encodings are *subtracted* ("negative signals"). A tanh
        # remap of the negative values (tanh(x) ~ x near 0, saturating at -1)
        # was considered upstream but left disabled.
        cos_sims = point @ other_points.transpose()  # (2,) @ (2, N) -> (N,)
        return self.apply_coefs(cos_sims)