File size: 3,330 Bytes
55f226f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from typing import Optional

import torch
import pytorch3d


from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
	look_at_view_transform,
	FoVPerspectiveCameras, 
	AmbientLights,
	PointLights, 
	DirectionalLights, 
	Materials, 
	RasterizationSettings, 
	MeshRenderer, 
	MeshRasterizer,  
	SoftPhongShader,
	SoftSilhouetteShader,
	HardPhongShader,
	TexturesVertex,
	TexturesUV,
	Materials,

)
from pytorch3d.renderer.blending import BlendParams,hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties

from pytorch3d.renderer.lighting import AmbientLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading
from pytorch3d.renderer.mesh.rasterizer import Fragments


'''
	Customized version of PyTorch3D's hard flat shader that supports N-channel flat shading.
'''
class HardNChannelFlatShader(ShaderBase):
	"""
	Hard flat shader generalized from PyTorch3D's ``HardFlatShader`` to
	produce images with an arbitrary number of colour channels.

	Shading is delegated to ``flat_shading`` (per-face lighting using the
	face position and face normal) and blending to ``hard_rgb_blend``,
	which assigns each pixel the colour of its closest face.

	To keep every colour-carrying input at the same channel width, the
	constructor substitutes defaults:

	* ``lights`` that are not ``AmbientLights`` with a ``channels``-wide
	  ``ambient_color`` are replaced by white ambient light;
	* ``materials`` that are ``None`` or whose ``ambient_color`` is not
	  ``channels`` wide are replaced by a purely ambient material
	  (ambient 1, diffuse 0, specular 0, shininess 0);
	* ``blend_params`` that are not a ``BlendParams`` or whose background
	  colour is not ``channels`` wide are replaced by ``BlendParams`` with a
	  white background.

	With these defaults the diffuse and specular material terms are zero, so
	presumably only the ambient term contributes to the shaded colour
	(confirm against the installed PyTorch3D ``flat_shading``).

	Example:

	.. code-block::

		shader = HardNChannelFlatShader(device=torch.device("cuda:0"), channels=5)
	"""

	def __init__(
		self,
		device = "cpu",
		cameras: Optional[TensorProperties] = None,
		lights: Optional[TensorProperties] = None,
		materials: Optional[Materials] = None,
		blend_params: Optional[BlendParams] = None,
		channels: int = 3,
	):
		"""
		Args:
			device: device used for the default lights/materials and
				forwarded to ``ShaderBase``.
			cameras: cameras used for shading; presumably overridable by a
				``cameras=`` keyword to ``forward`` (resolved by
				``ShaderBase._get_cameras``).
			lights: kept only if it is ``AmbientLights`` whose
				``ambient_color`` has ``channels`` entries.
			materials: kept only if not ``None`` and its ``ambient_color`` has
				``channels`` entries.
			blend_params: kept only if it is a ``BlendParams`` whose
				``background_color`` has ``channels`` entries.
			channels: number of colour channels to shade; must be >= 1.

		Raises:
			ValueError: if ``channels`` is smaller than 1.
		"""
		if channels < 1:
			raise ValueError(f"channels must be >= 1, got {channels!r}")

		ones = ((1.0,) * channels,)
		zeros = ((0.0,) * channels,)

		# Only AmbientLights are kept because their colour width can be checked
		# here; presumably other light types assume RGB (3 channels) — confirm.
		if not isinstance(lights, AmbientLights) or lights.ambient_color.shape[-1] != channels:
			lights = AmbientLights(
				ambient_color=ones,
				device=device,
			)

		if materials is None or materials.ambient_color.shape[-1] != channels:
			materials = Materials(
				device=device,
				diffuse_color=zeros,
				ambient_color=ones,
				specular_color=zeros,
				shininess=0.0,
			)

		if not isinstance(blend_params, BlendParams) or not self._background_has_channels(
			blend_params.background_color, channels
		):
			blend_params = BlendParams(background_color=(1.0,) * channels)

		super().__init__(
			device,
			cameras,
			lights,
			materials,
			blend_params,
		)
		self.channels = channels

	@staticmethod
	def _background_has_channels(background_color, channels: int) -> bool:
		"""
		Return whether ``background_color`` is compatible with ``channels``.

		Tuples/lists must have exactly ``channels`` entries and tensors a last
		dimension of size ``channels`` (a 0-dim tensor has no channel axis and
		never matches). Any other type is accepted as-is, leaving validation to
		the blending function.
		"""
		if isinstance(background_color, torch.Tensor):
			return background_color.dim() > 0 and background_color.shape[-1] == channels
		if isinstance(background_color, (tuple, list)):
			return len(background_color) == channels
		return True

	def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
		"""
		Flat-shade the rasterized fragments and hard-blend them into an image.

		Args:
			fragments: rasterizer output for ``meshes``.
			meshes: meshes whose textures are sampled at each fragment.
			**kwargs: optional per-call overrides of ``lights``, ``materials``
				and ``blend_params`` (cameras are resolved via
				``_get_cameras``). NOTE(review): these overrides are NOT checked
				for channel consistency the way constructor arguments are.

		Returns:
			The tensor produced by ``hard_rgb_blend``; presumably
			(N, H, W, channels + 1) with an alpha channel appended — confirm
			against the installed PyTorch3D.
		"""
		cameras = super()._get_cameras(**kwargs)
		texels = meshes.sample_textures(fragments)
		lights = kwargs.get("lights", self.lights)
		materials = kwargs.get("materials", self.materials)
		blend_params = kwargs.get("blend_params", self.blend_params)
		colors = flat_shading(
			meshes=meshes,
			fragments=fragments,
			texels=texels,
			lights=lights,
			cameras=cameras,
			materials=materials,
		)
		return hard_rgb_blend(colors, fragments, blend_params)