Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,314 Bytes
1d5bb62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import nvdiffrast.torch as dr
import torch
from torch import Tensor
from jaxtyping import Float, Integer
from typing import Union, Tuple
class NVDiffRasterizerContext:
    """Thin wrapper around an nvdiffrast rasterization context.

    Stores the context created for ``device`` and forwards to the nvdiffrast
    operations (rasterize, interpolate, antialias), casting floating-point
    inputs with ``.float()`` and triangle indices with ``.int()`` before
    each call.
    """

    def __init__(self, context_type: str, device: torch.device) -> None:
        """Create the underlying nvdiffrast context.

        Args:
            context_type: ``"gl"`` or ``"cuda"``.
            device: Device the context is created for.

        Raises:
            ValueError: If ``context_type`` is not one of the values above.
        """
        self.device = device
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device: torch.device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        """Build the nvdiffrast context selected by ``context_type``.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        elif context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        else:
            raise ValueError(f"Unknown rasterizer context type: {context_type}")

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        """Transform vertices to clip space with a batch of MVP matrices.

        A homogeneous coordinate of 1 is appended to every vertex and the
        result is multiplied by the transpose of each matrix in ``mvp_mtx``.
        The computation runs with CUDA autocast disabled.
        """
        with torch.amp.autocast("cuda", enabled=False):
            # new_ones allocates directly with verts' dtype and device,
            # avoiding a CPU allocation followed by a device copy.
            ones = verts.new_ones((verts.shape[0], 1))
            verts_homo = torch.cat([verts, ones], dim=-1)
            verts_clip = torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
        return verts_clip

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Float[Tensor, "B H W 4"], Tensor]:
        """Rasterize a batch of clip-space vertex sets sharing one topology.

        Returns:
            The two values returned by ``dr.rasterize`` (called with
            ``grad_db=True``): the rasterization output and its companion
            image-space derivative tensor.
        """
        # rasterize in instance mode (single topology)
        return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Float[Tensor, "H W 4"], Tensor]:
        """Rasterize one mesh under a single viewpoint.

        A batch axis is added to ``pos`` for the call and stripped from both
        returned tensors.
        """
        # rasterize one single mesh under a single viewpoint
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        """Apply ``dr.antialias`` to ``color`` using the rasterization result."""
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate per-vertex attributes over the rasterized image.

        Returns:
            The 2-tuple returned by ``dr.interpolate``: the interpolated
            attributes and the attribute-derivative tensor. Callers that only
            need the attributes unpack it as ``out, _ = ...``.
        """
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate a single (unbatched) attribute tensor.

        A leading batch axis of size 1 is added to ``attr`` before delegating
        to :meth:`interpolate`; the return value is the same 2-tuple.
        """
        return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)
def texture_map_to_rgb(tex_map, uv_coordinates):
    """Sample ``tex_map`` at ``uv_coordinates`` via ``dr.texture``.

    The texture is cast with ``.float()`` before sampling; the UV tensor is
    forwarded unchanged.
    """
    texture = tex_map.float()
    sampled = dr.texture(texture, uv_coordinates)
    return sampled
def render_rgb_from_texture_mesh_with_mask(
    ctx,
    mesh,
    tex_map: Float[Tensor, "1 H W C"],
    mvp_matrix: Float[Tensor, "batch 4 4"],
    image_height: int,
    image_width: int,
    background_color: Tensor = torch.tensor([0.0, 0.0, 0.0]),
):
    """Render a textured mesh to RGB images, one per MVP matrix.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one``.
        mesh: Object exposing ``v_pos``, ``t_pos_idx``, ``_v_tex`` and
            ``_t_tex_idx`` tensors.
        tex_map: Texture image. A 3-dimensional tensor gets a leading batch
            axis added. The output buffers are 3-channel, so the sampled
            texture is expected to have 3 channels.
        mvp_matrix: Batch of model-view-projection matrices; its device is
            used as the render device.
        image_height: Output height in pixels.
        image_width: Output width in pixels.
        background_color: 3-element colour used where no triangle is hit.

    Returns:
        ``(rgb, selector)``: the antialiased image batch and the boolean
        per-pixel coverage mask (without a channel axis).
    """
    # Move every input to the render device up front; `.to` is a no-op for
    # tensors that already live there, so on-device meshes are unaffected.
    device = mvp_matrix.device
    verts = mesh.v_pos.to(device)
    faces = mesh.t_pos_idx.to(device)
    uvs = mesh._v_tex.to(device)
    uv_faces = mesh._t_tex_idx.to(device)

    batch_size = mvp_matrix.shape[0]
    tex_map = tex_map.to(device).contiguous()
    if tex_map.dim() == 3:
        tex_map = tex_map.unsqueeze(0)  # Add batch dimension if missing

    vertex_positions_clip = ctx.vertex_transform(verts, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, faces, (image_height, image_width)
    )

    # Pixels whose last rasterizer channel is positive are treated as covered.
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, faces
    )

    interpolated_texture_coords, _ = ctx.interpolate_one(
        uvs, rasterized_output, uv_faces
    )
    rgb_foreground = texture_map_to_rgb(tex_map.float(), interpolated_texture_coords)

    # Allocate the compositing buffers directly with the foreground's
    # dtype/device instead of creating them on the CPU and copying.
    buffer_shape = (batch_size, image_height, image_width, 3)
    rgb_foreground_batched = rgb_foreground.new_zeros(buffer_shape)
    rgb_background_batched = rgb_foreground.new_zeros(buffer_shape)
    rgb_background_batched += background_color.view(1, 1, 1, 3).to(rgb_foreground)

    # Keep sampled colours only on covered pixels; everything else stays 0.
    selector = mask[..., 0]
    rgb_foreground_batched[selector] = rgb_foreground[selector]

    # Use the anti-aliased mask for blending
    final_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_batched, mask_antialiased
    )
    final_rgb_aa = ctx.antialias(
        final_rgb, rasterized_output, vertex_positions_clip, faces
    )
    return final_rgb_aa, selector
def render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width):
    """Render per-pixel position and normal images of a mesh.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one``.
        mesh: Object exposing ``v_pos``, ``t_pos_idx`` and ``v_normal``
            tensors.
        mvp_matrix: Batch of model-view-projection matrices; its device is
            used as the render device and its first dimension as the batch
            size.
        image_height: Output height in pixels.
        image_width: Output width in pixels.

    Returns:
        ``(positions, normals, mask)``: antialiased images of the
        interpolated vertex positions and vertex normals (zero where no
        triangle is hit), and the antialiased coverage mask.
    """
    # Transfer each mesh tensor once; the original re-issued `.to(device)`
    # at every use, which repeats the copy when the mesh lives elsewhere.
    device = mvp_matrix.device
    verts = mesh.v_pos.to(device)
    faces = mesh.t_pos_idx.to(device)
    normals = mesh.v_normal.to(device).contiguous()

    vertex_positions_clip = ctx.vertex_transform(verts, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, faces, (image_height, image_width)
    )
    interpolated_positions, _ = ctx.interpolate_one(verts, rasterized_output, faces)
    interpolated_normals, _ = ctx.interpolate_one(normals, rasterized_output, faces)

    # Pixels whose last rasterizer channel is positive are treated as covered.
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, faces
    )

    # Allocate buffers with the interpolated tensors' dtype/device directly.
    batch_size = mvp_matrix.shape[0]
    buffer_shape = (batch_size, image_height, image_width, 3)
    rgb_foreground_pos_batched = interpolated_positions.new_zeros(buffer_shape)
    rgb_foreground_norm_batched = interpolated_positions.new_zeros(buffer_shape)
    rgb_background_batched = interpolated_positions.new_zeros(buffer_shape)

    # Copy interpolated values only on covered pixels.
    selector = mask[..., 0]
    rgb_foreground_pos_batched[selector] = interpolated_positions[selector]
    rgb_foreground_norm_batched[selector] = interpolated_normals[selector]

    # Blend against the all-zero background with the antialiased mask, then
    # antialias the blended images themselves.
    final_pos_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_pos_batched, mask_antialiased
    )
    final_norm_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_norm_batched, mask_antialiased
    )
    final_pos_rgb_aa = ctx.antialias(
        final_pos_rgb, rasterized_output, vertex_positions_clip, faces
    )
    final_norm_rgb_aa = ctx.antialias(
        final_norm_rgb, rasterized_output, vertex_positions_clip, faces
    )
    return final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased
def rasterize_position_and_normal_maps(ctx, mesh, rasterize_height, rasterize_width):
    """Bake vertex positions and normals into UV-space maps.

    The mesh's texture coordinates are rasterized as if they were clip-space
    positions, and the 3D positions and normals are then interpolated over
    that rasterization.

    Args:
        ctx: Rasterizer wrapper; its ``device`` attribute selects where the
            mesh tensors are moved.
        mesh: Object exposing ``v_pos``, ``t_pos_idx``, ``_v_tex``,
            ``_t_tex_idx`` and ``v_normal`` tensors.
        rasterize_height: Map height in pixels.
        rasterize_width: Map width in pixels.

    Returns:
        ``(position_map, normal_map, mask)`` where ``mask`` is a boolean
        tensor (single trailing channel) that is true where the last
        rasterizer channel is positive.
    """
    target = ctx.device

    # Gather mesh tensors on the rasterizer's device.
    positions = mesh.v_pos.to(target)
    faces = mesh.t_pos_idx.to(target)
    uvs = mesh._v_tex.to(target)
    uv_faces = mesh._t_tex_idx.to(target)
    normals = mesh.v_normal.to(target).contiguous()

    # Remap UVs with x * 2 - 1 (sends [0, 1] to [-1, 1]) and append
    # z = 0 and w = 1 so they form 4-component positions.
    uv_xy = uvs[None, ...] * 2.0 - 1.0
    uv_z = torch.zeros_like(uv_xy[..., :1])
    uv_w = torch.ones_like(uv_xy[..., :1])
    uv_clip_padded = torch.cat((uv_xy, uv_z, uv_w), dim=-1)

    resolution = (rasterize_height, rasterize_width)
    rast, _ = ctx.rasterize(uv_clip_padded, uv_faces.int(), resolution)

    # Interpolate 3D attributes using the position-index triangles.
    face_indices = faces.int()
    position_map, _ = ctx.interpolate_one(positions, rast, face_indices)
    normal_map, _ = ctx.interpolate_one(normals, rast, face_indices)

    coverage = rast[..., 3:4] > 0
    return position_map, normal_map, coverage