File size: 4,647 Bytes
55f226f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
import torch.nn.functional as F

from pytorch3d.ops import interpolate_face_attributes

from pytorch3d.renderer import (
	look_at_view_transform,
	FoVPerspectiveCameras, 
	AmbientLights,
	PointLights, 
	DirectionalLights, 
	Materials, 
	RasterizationSettings, 
	MeshRenderer, 
	MeshRasterizer,  
	SoftPhongShader,
	SoftSilhouetteShader,
	HardPhongShader,
	TexturesVertex,
	TexturesUV,
	Materials,

)
from pytorch3d.renderer.blending import BlendParams,hard_rgb_blend
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties
from pytorch3d.renderer.mesh.shader import ShaderBase


def get_cos_angle(
	points, normals, camera_position
):
	'''
		Cosine between each surface normal and the direction from the surface
		point toward the camera.

		Args:
			points: tensor of surface positions. The first dim is treated as the
				batch dim and the last dim must have size 3.
			normals: tensor with exactly the same shape as points. It need not be
				unit length; it is L2-normalized here (eps=1e-6).
			camera_position: camera center(s). It is broadcast to the batch size
				of points via convert_to_tensors_and_broadcast.

		Returns:
			Tensor shaped like points but with a trailing dim of size 1, holding
			the cosine clamped to [0, 1]. Surfaces facing away from the camera
			therefore get 0.

		Raises:
			ValueError: if points and normals do not have the same shape.
	'''
	if points.shape != normals.shape:
		raise ValueError(
			"Expected points and normals to have the same shape: got %r, %r"
			% (points.shape, normals.shape)
		)

	# Bring the camera position onto the batch dimension of points.
	_broadcast_points, camera_position = convert_to_tensors_and_broadcast(
		points, camera_position, device=points.device
	)

	# Unless it already lines up with normals, reshape the per-batch camera
	# position to (N, 1, ..., 1, 3) so it broadcasts over every intermediate
	# dimension of points.
	if camera_position.shape != normals.shape:
		n_middle_dims = len(points.shape[1:-1])
		camera_position = camera_position.view((-1,) + (1,) * n_middle_dims + (3,))

	unit_normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
	to_camera = F.normalize(camera_position - points, p=2, dim=-1, eps=1e-6)

	# Dot product of two unit vectors == cosine of the angle between them.
	cosine = (to_camera * unit_normals).sum(dim=-1, keepdim=True)
	return cosine.clamp(0, 1)


def _geometry_shading_with_pixels(
	meshes, fragments, lights, cameras, materials, texels
):
	"""
	Interpolate per-pixel geometry buffers from a rasterized mesh batch.

	Only meshes, fragments and cameras are read. lights, materials and texels
	are accepted but never used; presumably they are kept so the signature
	mirrors PyTorch3D's shading helpers (NOTE(review): confirm no caller
	relies on them).

	Args:
		meshes: Batch of meshes.
		fragments: Fragments named tuple with the outputs of rasterization.
		lights: unused.
		cameras: Cameras object; only get_camera_center() is called on it.
		materials: unused.
		texels: unused.

	Returns:
		A 4-tuple (pixel_coords, pixel_normals, depths, cos_angles):
		pixel_coords: meshes.verts_packed() positions interpolated with the
			fragment barycentric coordinates. No camera transform is applied,
			so these live in the same space as the mesh vertices (despite the
			original local name "pixel_coords_in_camera").
		pixel_normals: interpolated vertex normals. They are not re-normalized
			here, so they are generally not unit length.
		depths: fragments.zbuf with a trailing singleton axis added.
		cos_angles: get_cos_angle(pixel_coords, pixel_normals, camera center).
	"""
	faces = meshes.faces_packed()  # (F, 3) vertex indices
	corner_positions = meshes.verts_packed()[faces]  # (F, 3, 3)
	corner_normals = meshes.verts_normals_packed()[faces]  # (F, 3, 3)

	pixel_coords, pixel_normals = (
		interpolate_face_attributes(
			fragments.pix_to_face, fragments.bary_coords, per_face_values
		)
		for per_face_values in (corner_positions, corner_normals)
	)

	cos_angles = get_cos_angle(
		pixel_coords, pixel_normals, cameras.get_camera_center()
	)
	depths = fragments.zbuf.unsqueeze(-1)

	return pixel_coords, pixel_normals, depths, cos_angles


class HardGeometryShader(ShaderBase):
	"""
	Shader that outputs per-pixel geometry buffers instead of a lit image.

	forward() produces interpolated vertex positions, interpolated vertex
	normals, depth, the view/normal cosine (see get_cos_angle), and a color
	image sampled from a synthetic UV-encoding texture (see texel_from_uv).
	Each buffer is passed through PyTorch3D's hard_rgb_blend, which keeps
	the closest face per pixel and appends an alpha channel. The raw
	rasterizer fragments are returned alongside.
	"""

	def forward(self, fragments, meshes, **kwargs):
		"""
		Render the geometry buffers for a batch of meshes.

		Args:
			fragments: rasterizer output for meshes.
			meshes: batch of meshes. Its textures must provide maps_padded(),
				faces_uvs_padded() and verts_uvs_padded() (as TexturesUV does).
			**kwargs: optional per-call overrides. cameras are resolved by
				ShaderBase._get_cameras; "lights", "materials" and
				"blend_params" fall back to the values stored on self.

		Returns:
			(verts, normals, depths, cos_angles, texels, fragments): the first
			five are the hard-blended buffers described on the class.
		"""
		cameras = super()._get_cameras(**kwargs)
		texels = self.texel_from_uv(fragments, meshes)

		lights = kwargs.get("lights", self.lights)
		materials = kwargs.get("materials", self.materials)
		blend_params = kwargs.get("blend_params", self.blend_params)
		verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
			meshes=meshes,
			fragments=fragments,
			texels=texels,
			lights=lights,
			cameras=cameras,
			materials=materials,
		)
		verts, normals, depths, cos_angles, texels = (
			hard_rgb_blend(buffer, fragments, blend_params)
			for buffer in (verts, normals, depths, cos_angles, texels)
		)
		return verts, normals, depths, cos_angles, texels, fragments

	def texel_from_uv(self, fragments, meshes):
		"""
		Sample a per-pixel color that encodes the interpolated UV coordinates.

		The mesh's textures are temporarily replaced by a 2x2 bilinear
		TexturesUV map that reuses the mesh's own faces_uvs / verts_uvs.
		Bilinear interpolation of that map is linear in (u, v). The original
		textures are always restored, even if sampling raises.

		Args:
			fragments: rasterizer output for meshes.
			meshes: batch of meshes with UV-style textures.

		Returns:
			The two sampled channels followed by an all-zero third channel.
			Per PyTorch3D's sample_textures this is (N, H, W, K, 3).
		"""
		original_textures = meshes.textures
		maps = original_textures.maps_padded()
		template = maps[0]  # only used for its dtype/device

		# Corner values of the 2x2 map: one channel varies along rows, the
		# other along columns. NOTE(review): given TexturesUV's v-flip this
		# presumably yields channel 0 ~ v and channel 1 ~ u — confirm.
		uv_map = torch.tensor(
			[[[1, 0], [1, 1]], [[0, 0], [0, 1]]],
			dtype=template.dtype,
			device=template.device,
		)
		uv_textures = TexturesUV(
			[uv_map.clone() for _ in range(len(maps))],  # one map per mesh
			original_textures.faces_uvs_padded(),
			original_textures.verts_uvs_padded(),
			sampling_mode="bilinear",
		)

		meshes.textures = uv_textures
		try:
			texels = meshes.sample_textures(fragments)
		finally:
			# Never leave the caller's mesh holding the synthetic texture.
			meshes.textures = original_textures

		# Pad to three channels so the result can be blended like an RGB image.
		return torch.cat((texels, torch.zeros_like(texels[..., -1:])), dim=-1)