Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,979 Bytes
2ac1c2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from collections.abc import Sequence
from typing import Optional

import torch

import pytorch3d
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.lighting import AmbientLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
from pytorch3d.renderer.mesh.rasterizer import Fragments
"""
Customized the original pytorch3d hard flat shader to support N channel flat shading
"""
class HardNChannelFlatShader(ShaderBase):
    """
    Per-face lighting shader with an arbitrary number of color channels.

    The lighting model is applied using the average face position and the
    face normal. The blending function hard-assigns the color of the
    closest face for each pixel.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = HardNChannelFlatShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        """
        Args:
            device: torch device for the default lights/materials.
            cameras: cameras used to shade.
            lights: replaced by an all-ones ``AmbientLights`` unless they are
                already ambient lights with exactly ``channels`` channels.
            materials: replaced by channel-matched defaults (pure ambient,
                no diffuse/specular) when missing or of the wrong width.
            blend_params: the background color is reset to all-ones of
                ``channels`` entries when missing or channel-mismatched.
            channels: number of color channels to shade (default 3 = RGB).
        """
        self.channels = channels
        ones = ((1.0,) * channels,)
        zeros = ((0.0,) * channels,)

        # Force ambient lighting with a channel count matching `channels`.
        if (
            not isinstance(lights, AmbientLights)
            or lights.ambient_color.shape[-1] != channels
        ):
            lights = AmbientLights(
                ambient_color=ones,
                device=device,
            )

        # Force materials that pass the texture color through unchanged
        # (ambient-only) with the requested channel count.
        if not materials or materials.ambient_color.shape[-1] != channels:
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        # Reset the background color whenever its channel count mismatches.
        # BUGFIX: the original used `isinstance(x, Sequence[float])`, which
        # raises TypeError at runtime (isinstance does not accept subscripted
        # generics) and referenced `Sequence` without importing it.
        blend_params_new = BlendParams(background_color=(1.0,) * channels)
        if not isinstance(blend_params, BlendParams):
            blend_params = blend_params_new
        else:
            background_color_ = blend_params.background_color
            if (
                isinstance(background_color_, Sequence)
                and len(background_color_) != channels
            ) or (
                isinstance(background_color_, torch.Tensor)
                and background_color_.shape[-1] != channels
            ):
                blend_params = blend_params_new

        super().__init__(
            device,
            cameras,
            lights,
            materials,
            blend_params,
        )

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Flat-shade the rasterized fragments and hard-blend them into an
        (N, H, W, channels+1) image tensor (last channel is alpha from the
        hard blend).

        Args:
            fragments: rasterizer output for the batch.
            meshes: the meshes being rendered; textures are sampled per pixel.
            **kwargs: optional per-call overrides for `cameras`, `lights`,
                `materials`, and `blend_params`.
        """
        cameras = super()._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        images = hard_rgb_blend(colors, fragments, blend_params)
        return images
|