File size: 3,917 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from math import isqrt
import torch
from e3nn.o3 import matrix_to_angles, wigner_D
from einops import einsum
from jaxtyping import Float
from torch import Tensor
def rotate_sh(
    sh_coefficients: Float[Tensor, "*#batch n"],
    rotations: Float[Tensor, "*#batch 3 3"],
) -> Float[Tensor, "*batch n"]:
    """Rotate real spherical-harmonic coefficients by the given rotation matrices.

    The last axis of ``sh_coefficients`` is processed degree by degree: the
    slice ``[degree**2, (degree + 1)**2)`` is multiplied by the e3nn Wigner-D
    matrix of that degree, and the rotated slices are concatenated back.

    Args:
        sh_coefficients: SH coefficients whose last axis has length ``n``;
            ``n`` must be a perfect square, ``(max_degree + 1) ** 2``.
        rotations: 3x3 rotation matrices whose batch dimensions broadcast
            against those of ``sh_coefficients``. They may use a different
            floating dtype than ``sh_coefficients`` (e.g. float64).

    Returns:
        The rotated coefficients, with the dtype and device of
        ``sh_coefficients``.

    Raises:
        ValueError: If ``n`` is not a perfect square.
    """
    device = sh_coefficients.device
    dtype = sh_coefficients.dtype

    *_, n = sh_coefficients.shape
    num_degrees = isqrt(n)
    if num_degrees * num_degrees != n:
        # isqrt would otherwise silently drop the trailing coefficients.
        raise ValueError(
            f"Expected (max_degree + 1) ** 2 SH coefficients, got {n}."
        )

    # Change the basis from YZX -> XYZ to fit the convention of e3nn.
    # The permutation is built in the *rotations'* dtype/device: torch.matmul
    # does not promote dtypes, so building it in the coefficients' dtype
    # crashes when e.g. float64 rotations are paired with float32 SH.
    P = torch.tensor(
        [[0, 0, 1], [1, 0, 0], [0, 1, 0]],
        dtype=rotations.dtype,
        device=rotations.device,
    )
    # P is a permutation matrix, so its inverse is its transpose.
    inversed_P = P.T
    permuted_rotation_matrix = inversed_P @ rotations @ P
    alpha, beta, gamma = matrix_to_angles(permuted_rotation_matrix)

    result = []
    for degree in range(num_degrees):
        with torch.device(device):
            # NOTE(review): beta is negated here, presumably as part of the
            # YZX -> XYZ convention change above — confirm against e3nn.
            sh_rotations = wigner_D(degree, alpha, -beta, gamma)
        # Angles (and hence the Wigner-D blocks) carry the rotations'
        # dtype/device; bring them to the coefficients' before contracting.
        sh_rotations = sh_rotations.to(device=device, dtype=dtype)
        sh_rotated = einsum(
            sh_rotations,
            sh_coefficients[..., degree**2 : (degree + 1) ** 2],
            "... i j, ... j -> ... i",
        )
        result.append(sh_rotated)
    return torch.cat(result, dim=-1)
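# Previous implementation, kept for reference only. It passes `rotations` to
# e3nn without the YZX -> XYZ basis change and uses `beta` instead of `-beta`.
# def rotate_sh(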
# sh_coefficients: Float[Tensor, "*#batch n"],
# rotations: Float[Tensor, "*#batch 3 3"],
# ) -> Float[Tensor, "*batch n"]:
# device = sh_coefficients.device
# dtype = sh_coefficients.dtype
#
# *_, n = sh_coefficients.shape
# alpha, beta, gamma = matrix_to_angles(rotations)
# result = []
# for degree in range(isqrt(n)):
# with torch.device(device):
# sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype)
# sh_rotated = einsum(
# sh_rotations,
# sh_coefficients[..., degree**2 : (degree + 1) ** 2],
# "... i j, ... j -> ... i",
# )
# result.append(sh_rotated)
#
# return torch.cat(result, dim=-1)
if __name__ == "__main__":
    from pathlib import Path

    import matplotlib.pyplot as plt
    from e3nn.o3 import spherical_harmonics
    from matplotlib import cm
    # Public import path; `scipy.spatial.transform.rotation` is a deprecated
    # private namespace.
    from scipy.spatial.transform import Rotation as R

    # Fall back to the CPU so the demo also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Generate random spherical harmonics coefficients.
    degree = 4
    coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device)

    def plot_sh(sh_coefficients, path: Path) -> None:
        """Render the SH function on a unit-sphere grid and save it to ``path``."""
        phi = torch.linspace(0, torch.pi, 100, device=device)
        theta = torch.linspace(0, 2 * torch.pi, 100, device=device)
        phi, theta = torch.meshgrid(phi, theta, indexing="xy")
        x = torch.sin(phi) * torch.cos(theta)
        y = torch.sin(phi) * torch.sin(theta)
        z = torch.cos(phi)
        xyz = torch.stack([x, y, z], dim=-1)
        sh = spherical_harmonics(list(range(degree + 1)), xyz, True)
        result = einsum(sh, sh_coefficients, "... n, n -> ...")
        # Min-max normalize to [0, 1] so the values can index the colormap.
        result = (result - result.min()) / (result.max() - result.min())

        # Set the aspect ratio to 1 so our sphere looks spherical
        fig = plt.figure(figsize=plt.figaspect(1.0))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(
            x.cpu().numpy(),
            y.cpu().numpy(),
            z.cpu().numpy(),
            rstride=1,
            cstride=1,
            facecolors=cm.seismic(result.cpu().numpy()),
        )
        # Turn off the axis planes
        ax.set_axis_off()
        path.parent.mkdir(exist_ok=True, parents=True)
        fig.savefig(path)
        # Release the figure; otherwise every rendered frame stays open.
        plt.close(fig)

    for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)):
        # scipy returns float64 matrices; match the coefficients' dtype.
        rotation = torch.tensor(
            R.from_euler("x", angle.item()).as_matrix(),
            dtype=coefficients.dtype,
            device=device,
        )
        plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png"))

    print("Done!")
|