Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| from functools import partial | |
| from typing import Any, Dict, Optional | |
| import torch | |
| from shap_e.models.nn.meta import subdict | |
| from shap_e.models.renderer import RayRenderer | |
| from shap_e.models.volume import Volume | |
| from shap_e.util.collections import AttrDict | |
| from .model import NeRFModel | |
| from .ray import RayVolumeIntegral, StratifiedRaySampler, render_rays | |
class TwoStepNeRFRenderer(RayRenderer):
    """
    Coarse and fine-grained rendering as proposed by NeRF. This class
    additionally supports background rendering like NeRF++.
    """

    def __init__(
        self,
        n_coarse_samples: int,
        n_fine_samples: int,
        void_model: NeRFModel,
        fine_model: NeRFModel,
        volume: Volume,
        coarse_model: Optional[NeRFModel] = None,
        coarse_background_model: Optional[NeRFModel] = None,
        fine_background_model: Optional[NeRFModel] = None,
        outer_volume: Optional[Volume] = None,
        foreground_stratified_depth_sampling_mode: str = "linear",
        background_stratified_depth_sampling_mode: str = "linear",
        importance_sampling_options: Optional[Dict[str, Any]] = None,
        channel_scale: float = 255,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        """
        :param n_coarse_samples: default number of stratified samples per ray
            in the coarse pass (overridable via ``options["n_coarse_samples"]``).
        :param n_fine_samples: default number of samples per ray in the fine
            pass (overridable via ``options["n_fine_samples"]``).
        :param void_model: model handed to ``render_rays`` as the void model in
            both passes.
        :param fine_model: foreground model of the fine pass; it is also used for
            the coarse pass when ``coarse_model`` is None.
        :param volume: foreground volume integrated over in both passes.
        :param coarse_model: optional separate foreground model for the coarse
            pass. None means the fine models are shared by both passes.
        :param coarse_background_model: optional background model for the coarse
            pass over ``outer_volume``; when given, ``fine_background_model`` and
            ``outer_volume`` must be given too.
        :param fine_background_model: background model for the fine pass over
            ``outer_volume`` (also used for the coarse pass when models are
            shared).
        :param outer_volume: is where distant objects are encoded.
        :param foreground_stratified_depth_sampling_mode: default ``depth_mode``
            of the coarse-pass ``StratifiedRaySampler`` over ``volume``.
        :param background_stratified_depth_sampling_mode: default ``depth_mode``
            of the coarse-pass ``StratifiedRaySampler`` over ``outer_volume``.
        :param importance_sampling_options: options forwarded (as an
            ``AttrDict``) to the coarse ``render_rays`` call; presumably they
            control the samplers returned for the fine pass — TODO confirm in
            ``render_rays``.
        :param channel_scale: factor applied to the rendered ``channels``.
        :param device: device the module is moved to at construction time.
        :param kwargs: forwarded to ``RayRenderer.__init__``.
        :raises AssertionError: if the model/volume combination is inconsistent.
        """
        super().__init__(**kwargs)

        # When the coarse pass reuses the fine models, only one set of
        # background models may be supplied.
        if coarse_model is None:
            assert (
                fine_background_model is None or coarse_background_model is None
            ), "models should be shared for both fg and bg"

        self.n_coarse_samples = n_coarse_samples
        self.n_fine_samples = n_fine_samples
        self.void_model = void_model
        self.coarse_model = coarse_model
        self.fine_model = fine_model
        self.volume = volume
        self.coarse_background_model = coarse_background_model
        self.fine_background_model = fine_background_model
        self.outer_volume = outer_volume
        self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
        self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
        self.importance_sampling_options = AttrDict(importance_sampling_options or {})
        self.channel_scale = channel_scale
        self.device = device
        self.to(device)

        # A coarse background model only makes sense alongside a fine one and
        # a volume for it to integrate over.
        if self.coarse_background_model is not None:
            assert self.fine_background_model is not None
            assert self.outer_volume is not None

    def render_rays(
        self,
        batch: Dict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        Render a batch of rays with a coarse stratified pass followed by a fine
        pass that reuses the samplers returned by the coarse pass.

        :param batch: mapping with a ``rays`` entry, handed unchanged to
            ``render_rays`` (its layout is not checked here).
        :param params: optional parameter overrides resolved via ``self.update``;
            each model receives ``subdict(params, <model attribute name>)``.
        :param options: optional rendering options (a plain ``dict`` or an
            ``AttrDict``). Missing keys are filled with defaults:
            ``render_background``, ``render_with_direction``,
            ``n_coarse_samples``, ``n_fine_samples``,
            ``foreground_stratified_depth_sampling_mode``,
            ``background_stratified_depth_sampling_mode``. The complete mapping
            is also forwarded to every model call. The caller's mapping is
            copied and never modified.
        :return: AttrDict with ``channels`` / ``channels_coarse`` (multiplied by
            ``channel_scale``), ``distances``, ``transmittance`` /
            ``transmittance_coarse``, ``t0`` / ``t1`` / ``intersected`` (from the
            fine pass's volume range), and ``aux_losses`` (the fine losses plus
            the coarse ones with a ``_coarse`` suffix).
        """
        params = self.update(params)

        batch = AttrDict(batch)
        # Copy into an AttrDict: the code below uses attribute access, which a
        # plain dict does not support, and filling in defaults must not leak
        # into the caller's mapping.
        options = AttrDict() if options is None else AttrDict(options)
        options.setdefault("render_background", True)
        options.setdefault("render_with_direction", True)
        options.setdefault("n_coarse_samples", self.n_coarse_samples)
        options.setdefault("n_fine_samples", self.n_fine_samples)
        options.setdefault(
            "foreground_stratified_depth_sampling_mode",
            self.foreground_stratified_depth_sampling_mode,
        )
        options.setdefault(
            "background_stratified_depth_sampling_mode",
            self.background_stratified_depth_sampling_mode,
        )

        shared = self.coarse_model is None

        # First, render rays using the coarse models with stratified ray samples.
        coarse_model, coarse_key = (
            (self.fine_model, "fine_model") if shared else (self.coarse_model, "coarse_model")
        )
        coarse_model = partial(
            coarse_model,
            params=subdict(params, coarse_key),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=coarse_model,
                volume=self.volume,
                sampler=StratifiedRaySampler(
                    depth_mode=options.foreground_stratified_depth_sampling_mode,
                ),
                n_samples=options.n_coarse_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            coarse_background_model, coarse_background_key = (
                (self.fine_background_model, "fine_background_model")
                if shared
                else (self.coarse_background_model, "coarse_background_model")
            )
            coarse_background_model = partial(
                coarse_background_model,
                params=subdict(params, coarse_background_key),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=coarse_background_model,
                    volume=self.outer_volume,
                    sampler=StratifiedRaySampler(
                        depth_mode=options.background_stratified_depth_sampling_mode,
                    ),
                    n_samples=options.n_coarse_samples,
                )
            )
        # The coarse pass returns one sampler per part; the fine pass consumes
        # them positionally (foreground first, then background).
        coarse_results, samplers, coarse_raw_outputs = render_rays(
            batch.rays,
            parts,
            partial(self.void_model, options=options),
            shared=shared,
            render_with_direction=options.render_with_direction,
            importance_sampling_options=AttrDict(self.importance_sampling_options),
        )

        # Then, render rays using the fine models with importance-weighted ray samples.
        fine_model = partial(
            self.fine_model,
            params=subdict(params, "fine_model"),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=fine_model,
                volume=self.volume,
                sampler=samplers[0],
                n_samples=options.n_fine_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            fine_background_model = partial(
                self.fine_background_model,
                params=subdict(params, "fine_background_model"),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=fine_background_model,
                    volume=self.outer_volume,
                    sampler=samplers[1],
                    n_samples=options.n_fine_samples,
                )
            )
        fine_results, *_ = render_rays(
            batch.rays,
            parts,
            partial(self.void_model, options=options),
            shared=shared,
            # Presumably lets the shared-model fine pass reuse the coarse
            # evaluations — TODO confirm in render_rays.
            prev_raw_outputs=coarse_raw_outputs,
            render_with_direction=options.render_with_direction,
        )

        # Combine results
        aux_losses = fine_results.output.aux_losses.copy()
        for key, val in coarse_results.output.aux_losses.items():
            aux_losses[key + "_coarse"] = val

        return AttrDict(
            channels=fine_results.output.channels * self.channel_scale,
            channels_coarse=coarse_results.output.channels * self.channel_scale,
            distances=fine_results.output.distances,
            transmittance=fine_results.transmittance,
            transmittance_coarse=coarse_results.transmittance,
            t0=fine_results.volume_range.t0,
            t1=fine_results.volume_range.t1,
            intersected=fine_results.volume_range.intersected,
            aux_losses=aux_losses,
        )
class OneStepNeRFRenderer(RayRenderer):
    """
    Renders rays using stratified sampling only, unlike vanilla NeRF (there is
    no importance-sampled second pass). The same setup as NeRF++.
    """

    def __init__(
        self,
        n_samples: int,
        void_model: NeRFModel,
        foreground_model: NeRFModel,
        volume: Volume,
        background_model: Optional[NeRFModel] = None,
        outer_volume: Optional[Volume] = None,
        foreground_stratified_depth_sampling_mode: str = "linear",
        background_stratified_depth_sampling_mode: str = "linear",
        channel_scale: float = 255,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        """
        :param n_samples: default number of stratified samples per ray in each
            volume (overridable via ``options["n_samples"]``).
        :param void_model: model handed to ``render_rays`` as the void model.
        :param foreground_model: model integrated over ``volume``.
        :param volume: foreground volume.
        :param background_model: model integrated over ``outer_volume``; needed
            whenever ``outer_volume`` is set and background rendering is on.
        :param outer_volume: optional volume for the background part.
        :param foreground_stratified_depth_sampling_mode: default ``depth_mode``
            of the ``StratifiedRaySampler`` over ``volume``.
        :param background_stratified_depth_sampling_mode: default ``depth_mode``
            of the ``StratifiedRaySampler`` over ``outer_volume``.
        :param channel_scale: factor applied to the rendered ``channels``.
        :param device: device the module is moved to at construction time.
        :param kwargs: forwarded to ``RayRenderer.__init__``.
        """
        super().__init__(**kwargs)
        self.n_samples = n_samples
        self.void_model = void_model
        self.foreground_model = foreground_model
        self.volume = volume
        self.background_model = background_model
        self.outer_volume = outer_volume
        self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
        self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
        self.channel_scale = channel_scale
        self.device = device
        self.to(device)

    def render_rays(
        self,
        batch: Dict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        Render a batch of rays with a single stratified pass over the
        foreground volume and, optionally, the outer (background) volume.

        :param batch: mapping with a ``rays`` entry, handed unchanged to
            ``render_rays`` (its layout is not checked here).
        :param params: optional parameter overrides resolved via ``self.update``;
            each model receives ``subdict(params, <model attribute name>)``.
        :param options: optional rendering options (a plain ``dict`` or an
            ``AttrDict``). Missing keys are filled with defaults:
            ``render_background``, ``render_with_direction``, ``n_samples``,
            ``foreground_stratified_depth_sampling_mode``,
            ``background_stratified_depth_sampling_mode``. The complete mapping
            is also forwarded to the foreground, background, and void models.
            The caller's mapping is copied and never modified.
        :return: AttrDict with ``channels`` (multiplied by ``channel_scale``),
            ``distances``, ``transmittance``, ``t0`` / ``t1`` / ``intersected``
            (from the volume range), and ``aux_losses``.
        """
        params = self.update(params)

        batch = AttrDict(batch)
        # Copy into an AttrDict: the code below uses attribute access, which a
        # plain dict does not support, and filling in defaults must not leak
        # into the caller's mapping.
        options = AttrDict() if options is None else AttrDict(options)
        options.setdefault("render_background", True)
        options.setdefault("render_with_direction", True)
        options.setdefault("n_samples", self.n_samples)
        options.setdefault(
            "foreground_stratified_depth_sampling_mode",
            self.foreground_stratified_depth_sampling_mode,
        )
        options.setdefault(
            "background_stratified_depth_sampling_mode",
            self.background_stratified_depth_sampling_mode,
        )

        foreground_model = partial(
            self.foreground_model,
            params=subdict(params, "foreground_model"),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=foreground_model,
                volume=self.volume,
                sampler=StratifiedRaySampler(
                    depth_mode=options.foreground_stratified_depth_sampling_mode
                ),
                n_samples=options.n_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            background_model = partial(
                self.background_model,
                params=subdict(params, "background_model"),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=background_model,
                    volume=self.outer_volume,
                    sampler=StratifiedRaySampler(
                        depth_mode=options.background_stratified_depth_sampling_mode
                    ),
                    n_samples=options.n_samples,
                )
            )

        results, *_ = render_rays(
            batch.rays,
            parts,
            # Forward the options to the void model too, exactly as
            # TwoStepNeRFRenderer does, so every model sees the same options.
            partial(self.void_model, options=options),
            render_with_direction=options.render_with_direction,
        )

        return AttrDict(
            channels=results.output.channels * self.channel_scale,
            distances=results.output.distances,
            transmittance=results.transmittance,
            t0=results.volume_range.t0,
            t1=results.volume_range.t1,
            intersected=results.volume_range.intersected,
            aux_losses=results.output.aux_losses,
        )