|
"""Copyright (c) Meta Platforms, Inc. and affiliates.""" |
|
|
|
import math |
|
import torch |
|
from geoopt.manifolds import Sphere as geoopt_Sphere |
|
|
|
|
|
class Sphere(geoopt_Sphere):
    """geoopt ``Sphere`` with closed-form parallel transport and a uniform base.

    The ambient dimension is the size of the last tensor axis, so points are
    unit vectors in R^d lying on S^{d-1}.
    """

    def transp(self, x, y, v):
        """Parallel-transport tangent vector ``v`` from ``x`` to ``y``.

        Uses the closed form ``v - <y, v> / (1 + <x, y>) * (x + y)``. When
        ``1 + <x, y> <= 1e-3`` (``x`` and ``y`` nearly antipodal, where the
        transport is ill-conditioned) the result falls back to ``-v``.

        Parameters
        ----------
        x : torch.Tensor
            Source points on the sphere.
        y : torch.Tensor
            Target points on the sphere.
        v : torch.Tensor
            Tangent vectors at ``x``.

        Returns
        -------
        torch.Tensor
            Tangent vectors at ``y``, same shape as ``v``.
        """
        denom = 1 + self.inner(x, x, y, keepdim=True)
        cond = denom.gt(1e-3)
        # Replace tiny denominators *before* dividing: torch.where still
        # back-propagates through the unselected branch, and an inf/NaN there
        # turns into NaN gradients (0 * inf). Forward values are unchanged.
        safe_denom = torch.where(cond, denom, torch.ones_like(denom))
        res = v - self.inner(x, y, v, keepdim=True) / safe_denom * (x + y)
        return torch.where(cond, res, -v)

    def uniform_logprob(self, x):
        """Log-density of the uniform distribution on S^{d-1}, ``d = x.shape[-1]``.

        The density is ``1 / area(S^{d-1})`` with
        ``area = 2 * pi^(d/2) / Gamma(d/2)``, so the log-density is
        ``lgamma(d/2) - log(2) - (d/2) * log(pi)``, independent of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Points on the sphere; only the shape, dtype and device are used.

        Returns
        -------
        torch.Tensor
            Constant tensor of shape ``x.shape[:-1]``.
        """
        dim = x.shape[-1]
        log_density = math.lgamma(dim / 2) - (
            math.log(2) + (dim / 2) * math.log(math.pi)
        )
        return torch.full_like(x[..., 0], log_density)

    def random_base(self, *args, **kwargs):
        """Sample from the base distribution (uniform; forwards to ``random_uniform``)."""
        return self.random_uniform(*args, **kwargs)

    def base_logprob(self, *args, **kwargs):
        """Log-density of the base distribution (forwards to ``uniform_logprob``)."""
        return self.uniform_logprob(*args, **kwargs)
|
|
|
|
|
def geodesic(manifold, start_point, end_point):
    """Build the geodesic from ``start_point`` to ``end_point`` on ``manifold``.

    The curve is ``t -> expmap(start_point, t * logmap(start_point, end_point))``,
    so ``t = 0`` gives ``start_point`` and ``t = 1`` gives ``end_point``
    (whenever ``expmap``/``logmap`` are mutual inverses there).

    Parameters
    ----------
    manifold : object
        Provides ``logmap(x, y)`` and ``expmap(x, u)`` on tensors.
    start_point, end_point : torch.Tensor, shape=[..., dim]
        Endpoints of the geodesic; leading batch dimensions broadcast.

    Returns
    -------
    path : callable
        Function mapping times ``t`` to points on the geodesic.
    """
    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t):
        """Generate parameterized function for geodesic curve.

        Parameters
        ----------
        t : array-like, shape=[n_points,]
            Times at which to compute points of the geodesics.

        Returns
        -------
        torch.Tensor, shape=[..., n_points, dim]
            Points of the geodesic at each requested time.
        """
        # Accept lists/arrays and mismatched dtypes/devices as the docstring
        # promises; einsum requires a tensor matching the tangent vector.
        # as_tensor is a no-op (and keeps autograd) for already-matching input.
        t = torch.as_tensor(
            t, dtype=shooting_tangent_vec.dtype, device=shooting_tangent_vec.device
        )
        # (n_points,) x (..., dim) -> (..., n_points, dim)
        tangent_vecs = torch.einsum("i,...k->...ik", t, shooting_tangent_vec)
        # Insert a time axis on the base point so it broadcasts over n_points.
        points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
        return points_at_time_t

    return path
|
|