File size: 3,979 Bytes
2ac1c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from typing import Optional

import torch
import pytorch3d


from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    AmbientLights,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    HardPhongShader,
    TexturesVertex,
    TexturesUV,
    Materials,
)
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties

from pytorch3d.renderer.lighting import AmbientLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
from pytorch3d.renderer.mesh.rasterizer import Fragments


"""
Customized version of PyTorch3D's HardFlatShader that supports flat shading
with an arbitrary number (N) of color channels.
"""


class HardNChannelFlatShader(ShaderBase):
    """
    Per-face flat shading with hard blending, generalized to N color channels.

    The lighting model is applied using the average face position and the
    face normal (via ``flat_shading``), and the blending function hard-assigns
    the color of the closest face to each pixel (via ``hard_rgb_blend``).
    Unlike PyTorch3D's ``HardFlatShader``, the default lights, materials and
    background color are sized to ``channels`` entries so that textures with
    an arbitrary number of channels can be shaded.

    Constructor normalization:
      * ``lights`` is kept only if it is an ``AmbientLights`` whose ambient
        color has ``channels`` entries; otherwise it is replaced by a unit
        (all-ones) ambient light.
      * ``materials`` is kept only if its ambient color has ``channels``
        entries; otherwise it is replaced by a purely ambient material
        (ambient 1, diffuse 0, specular 0, shininess 0).
      * ``blend_params`` whose background color length differs from
        ``channels`` gets a background of ``channels`` ones; the caller's
        ``sigma``/``gamma`` are preserved.

    To use the default values, simply initialize the shader with the desired
    device (and channel count), e.g.

    .. code-block::

            shader = HardNChannelFlatShader(device=torch.device("cuda:0"), channels=8)
    """

    @staticmethod
    def _color_channels(color) -> Optional[int]:
        """
        Return the size of the trailing (channel) dimension of ``color``.

        Supports tensors and plain lists/tuples. Returns None for anything
        else (e.g. a scalar or 0-dim tensor) so the caller leaves it as is.
        """
        if isinstance(color, torch.Tensor):
            return color.shape[-1] if color.dim() > 0 else None
        if isinstance(color, (list, tuple)):
            return len(color)
        return None

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        """
        Args:
            device: device on which default lights/materials are created.
            cameras: cameras used for shading; may also be passed to forward().
            lights: used only if it is an AmbientLights with ``channels``
                ambient entries; otherwise replaced by a unit ambient light.
            materials: replaced by a purely ambient material if None or if
                its ambient color does not have ``channels`` entries.
            blend_params: if not a BlendParams, a default one with a
                ``channels``-long all-ones background is used; if its
                background color has the wrong length, that color is reset
                to ``channels`` ones (sigma/gamma kept).
            channels: number of color channels to shade.
        """
        # Shape (1, channels): a batch of one color with N entries.
        ones = ((1.0,) * channels,)
        zeros = ((0.0,) * channels,)

        # Only ambient lighting is set up for N channels here, so any other
        # light type (or a channel-count mismatch) falls back to unit ambient.
        if (
            not isinstance(lights, AmbientLights)
            or lights.ambient_color.shape[-1] != channels
        ):
            lights = AmbientLights(
                ambient_color=ones,
                device=device,
            )

        if materials is None or materials.ambient_color.shape[-1] != channels:
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        white_background = (1.0,) * channels
        if not isinstance(blend_params, BlendParams):
            blend_params = BlendParams(background_color=white_background)
        else:
            bg_channels = self._color_channels(blend_params.background_color)
            if bg_channels is not None and bg_channels != channels:
                # Keep the caller's sigma/gamma; only the background color
                # must be resized to match the number of channels.
                blend_params = BlendParams(
                    sigma=blend_params.sigma,
                    gamma=blend_params.gamma,
                    background_color=white_background,
                )

        super().__init__(
            device,
            cameras,
            lights,
            materials,
            blend_params,
        )
        self.channels = channels

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Shade the rasterized fragments with flat shading and hard-blend them.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: meshes whose textures are sampled to obtain per-face colors.
            **kwargs: may override ``cameras``, ``lights``, ``materials`` and
                ``blend_params`` for this call. NOTE: overrides are not
                checked against ``self.channels``.

        Returns:
            The tensor produced by ``hard_rgb_blend`` — presumably
            (N, H, W, channels + 1) with an alpha channel appended; confirm
            against the installed PyTorch3D version.
        """
        cameras = super()._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        images = hard_rgb_blend(colors, fragments, blend_params)
        return images