|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional
|
|
|
|
import torch
|
|
from jaxtyping import Float
|
|
from torch import Tensor
|
|
|
|
from spar3d.models.utils import BaseModule
|
|
|
|
from .field import RENIField
|
|
|
|
|
|
def _direction_from_coordinate(
    coordinate: Float[Tensor, "*B 2"],
) -> Float[Tensor, "*B 3"]:
    """Convert 2-D ``(u, v)`` coordinates into 3-D unit direction vectors.

    The last axis of ``coordinate`` is split into ``u`` and ``v``, which are
    turned into spherical angles::

        theta = 2 * pi * u - pi    # azimuth
        phi   = pi * v             # polar angle measured from the +Y axis

    and then into the Cartesian vector
    ``(sin(theta) * sin(phi), cos(phi), -cos(theta) * sin(phi))``.
    Consequently ``v = 0`` maps to ``+Y`` and ``(u, v) = (0.5, 0.5)`` maps to
    ``-Z``.

    Args:
        coordinate: Tensor whose last dimension holds ``(u, v)``. Leading
            dimensions are arbitrary and are preserved.
            NOTE(review): values are presumably normalised to ``[0, 1]``;
            nothing here clamps or validates them.

    Returns:
        Tensor with the same leading dimensions and a last dimension of
        size 3 holding the ``(x, y, z)`` components.
    """
    u, v = coordinate.unbind(-1)
    theta = (2 * torch.pi * u) - torch.pi
    phi = torch.pi * v

    # Named ``direction`` rather than ``dir`` so the builtin is not shadowed.
    direction = torch.stack(
        [
            theta.sin() * phi.sin(),
            phi.cos(),
            -1 * theta.cos() * phi.sin(),
        ],
        -1,
    )
    return direction
|
|
|
|
|
|
def _get_sample_coordinates(
    resolution: List[int], device: Optional[torch.device] = None
) -> Float[Tensor, "H W 2"]:
    """Build a grid of cell-centre coordinates in the open interval (0, 1).

    Args:
        resolution: Two-element sequence ``(H, W)`` giving the number of rows
            and columns of the grid.
        device: Device on which the index ranges are created; ``None`` uses
            the PyTorch default.

    Returns:
        Tensor of shape ``(H, W, 2)``. Channel 0 holds the horizontal
        coordinate ``(col + 0.5) / W`` and channel 1 holds the vertical
        coordinate ``(row + 0.5) / H``.
    """
    height, width = resolution[0], resolution[1]

    # Offsetting by 0.5 samples the centre of every cell, not its edge.
    u_centres = (torch.arange(width, device=device) + 0.5) / width
    v_centres = (torch.arange(height, device=device) + 0.5) / height

    # "xy" indexing yields (H, W)-shaped grids with u varying along columns.
    u_grid, v_grid = torch.meshgrid(u_centres, v_centres, indexing="xy")
    return torch.stack((u_grid, v_grid), -1)
|
|
|
|
|
|
class RENIEnvMap(BaseModule):
    """Evaluate a ``RENIField`` on a fixed grid of sampling directions.

    ``configure`` precomputes one direction per cell of a
    ``resolution x (2 * resolution)`` grid; ``forward`` queries the field at
    those directions for every batch element and reshapes each output back
    onto the grid.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Settings forwarded as-is to the RENIField constructor.
        reni_config: dict = field(default_factory=dict)
        # Number of grid rows; the number of columns is twice this value.
        resolution: int = 128

    cfg: Config

    def configure(self):
        """Create the field and precompute the flattened direction grid."""
        self.field = RENIField(self.cfg.reni_config)

        height = self.cfg.resolution
        width = self.cfg.resolution * 2
        grid_directions = _direction_from_coordinate(
            _get_sample_coordinates((height, width))
        )
        # Remember the (H, W) layout so forward() can undo the flattening.
        self.img_shape = grid_directions.shape[:-1]

        # Re-express every direction as (x, -z, y): the former y component
        # becomes the new z component.
        # NOTE(review): presumably converts to the axis convention RENIField
        # expects -- confirm against the field implementation.
        x, y, z = grid_directions.view(-1, 3).unbind(-1)
        remapped = torch.stack([x, -z, y], -1)

        # Stored as a non-trainable parameter of the module.
        self.sample_directions = torch.nn.Parameter(remapped, requires_grad=False)

    def forward(
        self,
        latent_codes: Float[Tensor, "B latent_dim 3"],
        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
        scale: Optional[Float[Tensor, "B"]] = None,
    ) -> Dict[str, Tensor]:
        """Query the field at the precomputed directions for each batch item.

        Args:
            latent_codes: Per-sample latent codes; the first dimension is
                used as the batch size.
            rotation: Optional value passed through to the field as
                ``rotation``.
            scale: Optional value passed through to the field as ``scale``.

        Returns:
            Dict with the same keys as the field output, each value viewed
            as ``(B, *img_shape, -1)``.
        """
        batch_size = latent_codes.shape[0]

        # Give each batch element its own copy of the direction list.
        batched_directions = self.sample_directions.unsqueeze(0).repeat(
            batch_size, 1, 1
        )
        field_outputs = self.field(
            batched_directions,
            latent_codes,
            rotation=rotation,
            scale=scale,
        )

        reshaped = {}
        for name, values in field_outputs.items():
            reshaped[name] = values.view(batch_size, *self.img_shape, -1)
        return reshaped
|
|
|