Spaces:
Running
on
Zero
Running
on
Zero
| from functools import partial | |
| from typing import Any, Dict, Optional, Sequence, Tuple, Union | |
| import torch | |
| from shap_e.models.nerf.model import NeRFModel | |
| from shap_e.models.nerf.ray import RayVolumeIntegral, StratifiedRaySampler, render_rays | |
| from shap_e.models.nn.meta import subdict | |
| from shap_e.models.nn.utils import to_torch | |
| from shap_e.models.query import Query | |
| from shap_e.models.renderer import RayRenderer, render_views_from_rays | |
| from shap_e.models.stf.base import Model | |
| from shap_e.models.stf.renderer import STFRendererBase, render_views_from_stf | |
| from shap_e.models.volume import BoundingBoxVolume, Volume | |
| from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR | |
| from shap_e.util.collections import AttrDict | |
class NeRSTFRenderer(RayRenderer, STFRendererBase):
    """
    Renderer with two rendering paths selected by `options.rendering_mode`:

    - "nerf": volumetric ray rendering in two passes (coarse, then fine),
      see `render_rays`.
    - "stf": rendering through `render_views_from_stf`, see `render_views`.

    The underlying fields come either from a single joint `nerstf` model, or
    from a separate `sdf` / `tf` model pair.
    """

    def __init__(
        self,
        sdf: Optional[Model],
        tf: Optional[Model],
        nerstf: Optional[Model],
        void: NeRFModel,
        volume: Volume,
        grid_size: int,
        n_coarse_samples: int,
        n_fine_samples: int,
        importance_sampling_options: Optional[Dict[str, Any]] = None,
        separate_shared_samples: bool = False,
        texture_channels: Sequence[str] = ("R", "G", "B"),
        channel_scale: Sequence[float] = (255.0, 255.0, 255.0),
        ambient_color: Union[float, Tuple[float, ...]] = BASIC_AMBIENT_COLOR,
        diffuse_color: Union[float, Tuple[float, ...]] = BASIC_DIFFUSE_COLOR,
        specular_color: Union[float, Tuple[float, ...]] = 0.0,
        output_srgb: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        """
        :param sdf: model queried for `density` and `signed_distance`; used
            together with `tf` when `nerstf` is not given.
        :param tf: model queried for `channels`; used together with `sdf`
            when `nerstf` is not given.
        :param nerstf: joint model queried for density, signed distance and
            channels. Either `nerstf` must be given, or both `sdf` and `tf`,
            but not all three.
        :param void: model handed to the ray renderer alongside the volume
            parts (presumably the background model - confirm against
            `shap_e.models.nerf.ray.render_rays`).
        :param volume: region in which points are sampled; must be a
            `BoundingBoxVolume`.
        :param grid_size: forwarded to `render_views_from_stf`.
        :param n_coarse_samples: number of samples for the coarse ray pass.
        :param n_fine_samples: number of samples for the fine ray pass.
        :param importance_sampling_options: forwarded to the coarse ray pass;
            `None` is treated as an empty mapping.
        :param separate_shared_samples: passed negated as `shared` to the ray
            renderer; when true, coarse outputs are additionally returned by
            `render_rays`.
        :param texture_channels: names of the output channels; its length
            determines the trailing dimension used to broadcast
            `channel_scale`.
        :param channel_scale: per-channel multiplier applied to rendered
            channels.
        :param ambient_color: forwarded to `render_views_from_stf`.
        :param diffuse_color: forwarded to `render_views_from_stf`.
        :param specular_color: forwarded to `render_views_from_stf`.
        :param output_srgb: forwarded to `render_views_from_stf`.
        :param device: device that `channel_scale` and this renderer are
            moved to.
        :param kwargs: forwarded to the parent constructor.
        """
        super().__init__(**kwargs)

        assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume"
        assert (nerstf is not None) ^ (
            sdf is not None and tf is not None
        ), "expected either `nerstf`, or both `sdf` and `tf` (not all three)"

        self.sdf = sdf
        self.tf = tf
        self.nerstf = nerstf
        self.void = void
        self.volume = volume
        self.grid_size = grid_size
        self.n_coarse_samples = n_coarse_samples
        self.n_fine_samples = n_fine_samples
        self.importance_sampling_options = AttrDict(importance_sampling_options or {})
        self.separate_shared_samples = separate_shared_samples
        self.texture_channels = texture_channels
        # Stored as a plain attribute, so it is moved to the device explicitly.
        self.channel_scale = to_torch(channel_scale).to(device)
        self.ambient_color = ambient_color
        self.diffuse_color = diffuse_color
        self.specular_color = specular_color
        self.output_srgb = output_srgb
        self.device = device
        self.to(device)

    def _query(
        self,
        query: Query,
        params: AttrDict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> AttrDict:
        """
        Evaluate the underlying model(s) at `query`.

        :param query: the points (and optionally directions) to evaluate.
        :param params: meta parameters; the "nerstf", "sdf" and "tf" sub-dicts
            are routed to the corresponding model.
        :param options: rendering options; in "stf" mode (the default) the
            query must not carry a direction.
        :return: an AttrDict with `density`, `signed_distance`, `channels`
            and an empty `aux_losses` dict.
        """
        # The standalone SDF model is always queried without a view direction.
        no_dir_query = query.copy()
        no_dir_query.direction = None

        if options.get("rendering_mode", "stf") == "stf":
            assert query.direction is None

        if self.nerstf is not None:
            # The joint model provides both geometry and texture outputs.
            sdf = tf = self.nerstf(
                query,
                params=subdict(params, "nerstf"),
                options=options,
            )
        else:
            sdf = self.sdf(no_dir_query, params=subdict(params, "sdf"), options=options)
            tf = self.tf(query, params=subdict(params, "tf"), options=options)

        return AttrDict(
            density=sdf.density,
            signed_distance=sdf.signed_distance,
            channels=tf.channels,
            aux_losses=dict(),
        )

    def render_rays(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[AttrDict] = None,
    ) -> AttrDict:
        """
        Render rays in two passes: coarse stratified samples, then fine
        samples drawn from the sampler returned by the coarse pass.

        :param batch: has
            - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray.
        :param params: meta parameters, passed through `self.update`.
        :param options: Optional[Dict]; copied before use, and
            `rendering_mode` is forced to "nerf" on the copy.
        :return: an AttrDict with the scaled `channels`, `transmittance`,
            `raw_signed_distance`, `raw_density`, `distances`, `t0`, `t1`,
            `intersected` and `aux_losses`. When `separate_shared_samples`
            is set, it also holds `channels_coarse`, `distances_coarse` and
            `transmittance_coarse`, and the coarse aux losses under keys
            suffixed with "_coarse".
        """
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        # Necessary to tell the TF to use specific NeRF channels.
        options.rendering_mode = "nerf"

        # `model` shares the `options` object, so the `nerf_level` updates
        # below are visible to every subsequent query.
        model = partial(self._query, params=params, options=options)

        # First, render rays with coarse, stratified samples.
        options.nerf_level = "coarse"
        parts = [
            RayVolumeIntegral(
                model=model,
                volume=self.volume,
                sampler=StratifiedRaySampler(),
                n_samples=self.n_coarse_samples,
            ),
        ]
        coarse_results, samplers, coarse_raw_outputs = render_rays(
            batch.rays,
            parts,
            self.void,
            shared=not self.separate_shared_samples,
            render_with_direction=options.render_with_direction,
            importance_sampling_options=self.importance_sampling_options,
        )

        # Then, render with additional importance-weighted ray samples.
        options.nerf_level = "fine"
        parts = [
            RayVolumeIntegral(
                model=model,
                volume=self.volume,
                sampler=samplers[0],
                n_samples=self.n_fine_samples,
            ),
        ]
        fine_results, _, raw_outputs = render_rays(
            batch.rays,
            parts,
            self.void,
            shared=not self.separate_shared_samples,
            prev_raw_outputs=coarse_raw_outputs,
            render_with_direction=options.render_with_direction,
        )
        raw = raw_outputs[0]

        aux_losses = fine_results.output.aux_losses.copy()
        if self.separate_shared_samples:
            for key, val in coarse_results.output.aux_losses.items():
                aux_losses[key + "_coarse"] = val

        # Broadcast the per-channel scale over all leading dimensions.
        channels = fine_results.output.channels
        shape = [1] * (channels.ndim - 1) + [len(self.texture_channels)]
        channels = channels * self.channel_scale.view(*shape)

        res = AttrDict(
            channels=channels,
            transmittance=fine_results.transmittance,
            raw_signed_distance=raw.signed_distance,
            raw_density=raw.density,
            distances=fine_results.output.distances,
            t0=fine_results.volume_range.t0,
            t1=fine_results.volume_range.t1,
            intersected=fine_results.volume_range.intersected,
            aux_losses=aux_losses,
        )

        if self.separate_shared_samples:
            res.update(
                dict(
                    channels_coarse=(
                        coarse_results.output.channels * self.channel_scale.view(*shape)
                    ),
                    distances_coarse=coarse_results.output.distances,
                    transmittance_coarse=coarse_results.transmittance,
                )
            )

        return res

    def render_views(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[AttrDict] = None,
    ) -> AttrDict:
        """
        Returns a backproppable rendering of a view

        :param batch: contains either ["poses", "camera"], or ["cameras"]. Can
            optionally contain any of ["height", "width", "query_batch_size"]
        :param params: Meta parameters
        :param options: controls checkpointing, caching, and rendering.
            Can provide a `rendering_mode` in ["stf", "nerf"]; defaults to
            "stf" when absent.
        :return: the output of `render_views_from_rays` ("nerf" mode) or
            `render_views_from_stf` ("stf" mode).
        :raises NotImplementedError: if `rendering_mode` is neither "stf" nor
            "nerf".
        """
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        # Provide a cache for the duration of this call if the caller did not
        # supply one, and remember to drop it again afterwards.
        if options.cache is None:
            created_cache = True
            options.cache = AttrDict()
        else:
            created_cache = False

        rendering_mode = options.get("rendering_mode", "stf")

        if rendering_mode == "nerf":
            output = render_views_from_rays(
                self.render_rays,
                batch,
                params=params,
                options=options,
                device=self.device,
            )
        elif rendering_mode == "stf":
            # Exactly one of `nerstf_fn` or the (`sdf_fn`, `tf_fn`) pair is set.
            sdf_fn = tf_fn = nerstf_fn = None
            if self.nerstf is not None:
                nerstf_fn = partial(
                    self.nerstf.forward_batched,
                    params=subdict(params, "nerstf"),
                    options=options,
                )
            else:
                sdf_fn = partial(
                    self.sdf.forward_batched,
                    params=subdict(params, "sdf"),
                    options=options,
                )
                tf_fn = partial(
                    self.tf.forward_batched,
                    params=subdict(params, "tf"),
                    options=options,
                )
            output = render_views_from_stf(
                batch,
                options,
                sdf_fn=sdf_fn,
                tf_fn=tf_fn,
                nerstf_fn=nerstf_fn,
                volume=self.volume,
                grid_size=self.grid_size,
                channel_scale=self.channel_scale,
                texture_channels=self.texture_channels,
                ambient_color=self.ambient_color,
                diffuse_color=self.diffuse_color,
                specular_color=self.specular_color,
                output_srgb=self.output_srgb,
                device=self.device,
            )
        else:
            raise NotImplementedError(f"unsupported rendering_mode: {rendering_mode!r}")

        if created_cache:
            del options["cache"]

        return output

    def get_signed_distance(
        self,
        query: Query,
        params: Dict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> torch.Tensor:
        """
        Return the signed distance at `query`.

        Uses the standalone `sdf` model when available, otherwise the joint
        `nerstf` model.

        :param query: the points to evaluate.
        :param params: meta parameters; the "sdf" or "nerstf" sub-dict is used.
        :param options: options forwarded to the model.
        :return: the `signed_distance` field of the model output.
        """
        if self.sdf is not None:
            return self.sdf(query, params=subdict(params, "sdf"), options=options).signed_distance
        assert self.nerstf is not None
        return self.nerstf(query, params=subdict(params, "nerstf"), options=options).signed_distance

    def get_texture(
        self,
        query: Query,
        params: Dict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> torch.Tensor:
        """
        Return the texture channels at `query`.

        Uses the standalone `tf` model when available, otherwise the joint
        `nerstf` model.

        :param query: the points to evaluate.
        :param params: meta parameters; the "tf" or "nerstf" sub-dict is used.
        :param options: options forwarded to the model.
        :return: the `channels` field of the model output.
        """
        if self.tf is not None:
            return self.tf(query, params=subdict(params, "tf"), options=options).channels
        assert self.nerstf is not None
        return self.nerstf(query, params=subdict(params, "nerstf"), options=options).channels