Spaces:
Starting
Starting
| from dataclasses import dataclass, field | |
| from typing import Dict, List, Optional | |
| import torch | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from spar3d.models.utils import BaseModule | |
| from .field import RENIField | |
def _direction_from_coordinate(
    coordinate: Float[Tensor, "*B 2"],
) -> Float[Tensor, "*B 3"]:
    """Map equirectangular (u, v) coordinates to direction vectors.

    ``u`` is turned into an azimuth ``theta = 2*pi*u - pi`` and ``v`` into a
    polar angle ``phi = pi*v`` measured from the +Y axis. The resulting vectors
    follow the OpenGL convention: +X right, +Y up, +Z backward, so
    ``(u, v) = (0.5, 0.5)`` maps to ``(0, 0, -1)``.

    Args:
        coordinate: Tensor whose last dimension holds the ``(u, v)`` pair.

    Returns:
        Tensor with the same leading dimensions and a trailing ``(x, y, z)``
        dimension.
    """
    u, v = coordinate.unbind(-1)

    azimuth = (2 * torch.pi * u) - torch.pi
    polar = torch.pi * v

    # sin(polar) scales both horizontal components.
    sin_polar = torch.sin(polar)
    x = torch.sin(azimuth) * sin_polar
    y = torch.cos(polar)
    z = -1 * torch.cos(azimuth) * sin_polar

    return torch.stack((x, y, z), dim=-1)
def _get_sample_coordinates(
    resolution: List[int], device: Optional[torch.device] = None
) -> Float[Tensor, "H W 2"]:
    """Return pixel-centre (u, v) coordinates for an ``H x W`` grid.

    Args:
        resolution: ``[height, width]`` of the grid.
        device: Device on which the coordinate tensor is created.

    Returns:
        Tensor of shape ``(height, width, 2)``. Entry ``[row, col]`` holds
        ``((col + 0.5) / width, (row + 0.5) / height)``, i.e. coordinates in
        the open interval (0, 1) sampled at pixel centres.
    """
    height = resolution[0]
    width = resolution[1]

    # The +0.5 offset places each sample at the centre of its pixel.
    u_axis = (torch.arange(width, device=device) + 0.5) / width
    v_axis = (torch.arange(height, device=device) + 0.5) / height

    # "xy" indexing yields grids of shape (height, width).
    u_grid, v_grid = torch.meshgrid(u_axis, v_axis, indexing="xy")
    return torch.stack((u_grid, v_grid), dim=-1)
class RENIEnvMap(BaseModule):
    """Environment map obtained by sampling a ``RENIField`` on a fixed grid.

    ``configure`` precomputes one direction per pixel of an equirectangular
    grid of shape ``(resolution, 2 * resolution)``. ``forward`` evaluates the
    field along those directions for every latent code in the batch and
    reshapes each output back to image layout.
    """

    # ``@dataclass`` is required for ``field(default_factory=dict)`` to take
    # effect; without it ``reni_config`` would be a raw ``dataclasses.Field``
    # object and neither attribute would be registered as a dataclass field.
    @dataclass
    class Config(BaseModule.Config):
        # Configuration passed unchanged to ``RENIField``.
        reni_config: dict = field(default_factory=dict)
        # Height of the sampled map; its width is twice this value.
        resolution: int = 128

    cfg: Config

    def configure(self):
        """Create the RENI field and precompute the per-pixel directions."""
        self.field = RENIField(self.cfg.reni_config)

        height = self.cfg.resolution
        resolution = (height, height * 2)
        sample_directions = _direction_from_coordinate(
            _get_sample_coordinates(resolution)
        )
        # Image layout (H, W), kept so forward() can undo the flattening.
        self.img_shape = sample_directions.shape[:-1]
        sample_directions_flat = sample_directions.view(-1, 3)

        # Lastly these have y up but reni expects z up. Rotate 90 degrees on
        # the x axis: (x, y, z) -> (x, -z, y).
        x, y, z = sample_directions_flat.unbind(-1)
        sample_directions_flat = torch.stack([x, -z, y], -1)

        # Stored as a frozen Parameter (requires_grad=False): the directions
        # are constants and are never optimised.
        self.sample_directions = torch.nn.Parameter(
            sample_directions_flat, requires_grad=False
        )

    def forward(
        self,
        latent_codes: Float[Tensor, "B latent_dim 3"],
        rotation: Optional[Float[Tensor, "B 3 3"]] = None,
        scale: Optional[Float[Tensor, "B"]] = None,
    ) -> Dict[str, Tensor]:
        """Evaluate the field on the sample grid for each latent code.

        Args:
            latent_codes: One latent code per batch element.
            rotation: Optional per-batch rotation, forwarded to the field.
            scale: Optional per-batch scale, forwarded to the field.

        Returns:
            The field's output dict with every tensor reshaped to
            ``(B, H, W, -1)``, where ``(H, W)`` is the sample grid shape.
        """
        batch_size = latent_codes.shape[0]
        # Every batch element is evaluated on the same set of directions.
        directions = self.sample_directions.unsqueeze(0).repeat(batch_size, 1, 1)
        outputs = self.field(
            directions,
            latent_codes,
            rotation=rotation,
            scale=scale,
        )
        return {
            k: v.view(batch_size, *self.img_shape, -1) for k, v in outputs.items()
        }