# Source: SeqTex/utils/rasterize.py -- commit 1d5bb62 ("init space"), uploaded by yuanze1024 (7.31 kB).
import nvdiffrast.torch as dr
import torch
from torch import Tensor
from jaxtyping import Float, Integer
from typing import Union, Tuple
class NVDiffRasterizerContext:
    """Thin wrapper around an nvdiffrast rasterization context.

    The wrapper owns either a ``dr.RasterizeGLContext`` or a
    ``dr.RasterizeCudaContext`` and exposes small helpers that cast their
    inputs to the dtypes the nvdiffrast calls are given here (``float`` for
    positions/attributes/colors, ``int`` for triangle indices).
    """

    def __init__(self, context_type: str, device: torch.device) -> None:
        """Create the underlying nvdiffrast context.

        Args:
            context_type: ``"gl"`` or ``"cuda"``.
            device: Device handed to the nvdiffrast context constructor and
                stored on the instance as ``self.device``.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        self.device = device
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device: torch.device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        """Build and return the nvdiffrast context selected by ``context_type``.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        elif context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        else:
            raise ValueError(f"Unknown rasterizer context type: {context_type}")

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        """Transform vertices by a batch of 4x4 matrices.

        A homogeneous coordinate of 1 is appended to every vertex and the
        result is multiplied by the transpose of each matrix in ``mvp_mtx``
        (i.e. ``M @ v`` per vertex). CUDA autocast is disabled for the
        computation.
        """
        with torch.amp.autocast("cuda", enabled=False):
            # Allocate the ones column directly with verts' dtype and device
            # instead of creating it on the CPU and copying it over.
            ones = verts.new_ones((verts.shape[0], 1))
            verts_homo = torch.cat([verts, ones], dim=-1)
            verts_clip = torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
        return verts_clip

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Float[Tensor, "B H W 4"], Tensor]:
        """Rasterize a batch of vertex positions that share one topology.

        Returns:
            The two tensors returned by ``dr.rasterize`` (called with
            ``grad_db=True``): the rasterizer output and its companion
            derivative tensor.
        """
        # rasterize in instance mode (single topology)
        return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Float[Tensor, "H W 4"], Tensor]:
        """Rasterize a single mesh under a single viewpoint.

        A batch dimension is added to ``pos`` for the call to
        :meth:`rasterize` and stripped again from both returned tensors.
        """
        # rasterize one single mesh under a single viewpoint
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        """Run ``dr.antialias`` on ``color`` using the given rasterization."""
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate per-vertex attributes over a rasterization.

        Returns:
            The pair returned by ``dr.interpolate``: the interpolated
            attributes and a second tensor (callers in this file discard it
            when ``rast_db``/``diff_attrs`` are not supplied).
        """
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate one un-batched attribute tensor.

        ``attr`` gets a leading batch dimension of 1 before being forwarded
        to :meth:`interpolate`; the return value is that method's tuple.
        """
        return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)
def texture_map_to_rgb(tex_map, uv_coordinates):
    """Sample ``tex_map`` at ``uv_coordinates`` with ``dr.texture``.

    The texture is cast to ``float`` before sampling; ``uv_coordinates`` is
    passed through unchanged.
    """
    texture = tex_map.float()
    return dr.texture(texture, uv_coordinates)
def render_rgb_from_texture_mesh_with_mask(
    ctx,
    mesh,
    tex_map: Float[Tensor, "1 H W C"],
    mvp_matrix: Float[Tensor, "batch 4 4"],
    image_height: int,
    image_width: int,
    background_color: Tensor = torch.tensor([0.0, 0.0, 0.0]),
):
    """Render a textured mesh to RGB images, one per matrix in ``mvp_matrix``.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one`` (e.g.
            ``NVDiffRasterizerContext``).
        mesh: Object exposing ``v_pos``, ``t_pos_idx``, ``_v_tex`` and
            ``_t_tex_idx`` tensors.
        tex_map: Texture to sample. A 3-D tensor gets a leading batch
            dimension added.
        mvp_matrix: Batch of 4x4 transforms; its device is used for all
            mesh tensors and the texture.
        image_height: Output height in pixels.
        image_width: Output width in pixels.
        background_color: 3-element color used where no triangle is hit.

    Returns:
        ``(final_rgb_aa, selector)``: the antialiased image batch and the
        boolean per-pixel coverage mask (``rast[..., 3] > 0``).
    """
    device = mvp_matrix.device
    batch_size = mvp_matrix.shape[0]

    # Move everything onto the matrices' device, as the sibling render
    # helpers do; this is a no-op for tensors that already live there.
    tex_map = tex_map.to(device).contiguous()
    if tex_map.dim() == 3:
        tex_map = tex_map.unsqueeze(0)  # Add batch dimension if missing
    v_pos = mesh.v_pos.to(device)
    t_pos_idx = mesh.t_pos_idx.to(device)
    v_tex = mesh._v_tex.to(device)
    t_tex_idx = mesh._t_tex_idx.to(device)

    vertex_positions_clip = ctx.vertex_transform(v_pos, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, t_pos_idx, (image_height, image_width)
    )

    # The last rasterizer channel is positive where a triangle covers the pixel.
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, t_pos_idx
    )

    interpolated_texture_coords, _ = ctx.interpolate_one(
        v_tex, rasterized_output, t_tex_idx
    )
    rgb_foreground = texture_map_to_rgb(tex_map.float(), interpolated_texture_coords)

    # Buffers are created directly with the foreground's dtype/device rather
    # than on the CPU followed by a copy. The channel count is fixed at 3.
    image_shape = (batch_size, image_height, image_width, 3)
    rgb_foreground_batched = rgb_foreground.new_zeros(image_shape)
    rgb_background_batched = rgb_foreground.new_zeros(image_shape)
    rgb_background_batched += background_color.view(1, 1, 1, 3).to(rgb_foreground)

    # Keep sampled colors only on covered pixels; the rest stay zero.
    selector = mask[..., 0]
    rgb_foreground_batched[selector] = rgb_foreground[selector]

    # Use the anti-aliased mask for blending
    final_rgb = torch.lerp(rgb_background_batched, rgb_foreground_batched, mask_antialiased)
    final_rgb_aa = ctx.antialias(
        final_rgb, rasterized_output, vertex_positions_clip, t_pos_idx
    )
    return final_rgb_aa, selector
def render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width):
    """Render per-pixel position and normal images of a mesh.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one`` (e.g.
            ``NVDiffRasterizerContext``).
        mesh: Object exposing ``v_pos``, ``v_normal`` and ``t_pos_idx``
            tensors.
        mvp_matrix: Batch of 4x4 transforms; mesh tensors are moved to its
            device.
        image_height: Output height in pixels.
        image_width: Output width in pixels.

    Returns:
        ``(final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased)``: the
        antialiased position image, the antialiased normal image, and the
        antialiased coverage mask. Uncovered pixels are zero in both images.
    """
    device = mvp_matrix.device
    batch_size = mvp_matrix.shape[0]

    # Transfer each mesh tensor once instead of on every use.
    v_pos = mesh.v_pos.to(device)
    v_normal = mesh.v_normal.to(device).contiguous()
    t_pos_idx = mesh.t_pos_idx.to(device)

    vertex_positions_clip = ctx.vertex_transform(v_pos, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, t_pos_idx, (image_height, image_width)
    )
    interpolated_positions, _ = ctx.interpolate_one(v_pos, rasterized_output, t_pos_idx)
    interpolated_normals, _ = ctx.interpolate_one(v_normal, rasterized_output, t_pos_idx)

    # The last rasterizer channel is positive where a triangle covers the pixel.
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, t_pos_idx
    )

    # Allocate with the interpolated positions' dtype/device directly rather
    # than creating CPU tensors and copying them.
    image_shape = (batch_size, image_height, image_width, 3)
    rgb_foreground_pos_batched = interpolated_positions.new_zeros(image_shape)
    rgb_foreground_norm_batched = interpolated_positions.new_zeros(image_shape)
    rgb_background_batched = interpolated_positions.new_zeros(image_shape)

    # Keep interpolated values only on covered pixels.
    selector = mask[..., 0]
    rgb_foreground_pos_batched[selector] = interpolated_positions[selector]
    rgb_foreground_norm_batched[selector] = interpolated_normals[selector]

    # Blend against the zero background with the antialiased mask, then
    # antialias the blended images themselves.
    final_pos_rgb = torch.lerp(rgb_background_batched, rgb_foreground_pos_batched, mask_antialiased)
    final_norm_rgb = torch.lerp(rgb_background_batched, rgb_foreground_norm_batched, mask_antialiased)
    final_pos_rgb_aa = ctx.antialias(
        final_pos_rgb, rasterized_output, vertex_positions_clip, t_pos_idx
    )
    final_norm_rgb_aa = ctx.antialias(
        final_norm_rgb, rasterized_output, vertex_positions_clip, t_pos_idx
    )
    return final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased
def rasterize_position_and_normal_maps(ctx, mesh, rasterize_height, rasterize_width):
    """Bake mesh positions and normals into texture (UV) space.

    The mesh's UV coordinates are rasterized as if they were clip-space
    positions, and the 3-D vertex positions and normals are then
    interpolated over that rasterization using the position-face indices.

    Args:
        ctx: Rasterizer wrapper exposing ``device``, ``rasterize`` and
            ``interpolate_one`` (e.g. ``NVDiffRasterizerContext``).
        mesh: Object exposing ``v_pos``, ``v_normal``, ``t_pos_idx``,
            ``_v_tex`` and ``_t_tex_idx`` tensors.
        rasterize_height: Height of the output maps.
        rasterize_width: Width of the output maps.

    Returns:
        ``(position_map, normal_map, rasterization_mask)``; the mask is
        ``True`` where the last rasterizer channel is positive.
    """
    target = ctx.device

    # Bring every mesh tensor onto the rasterizer's device.
    positions = mesh.v_pos.to(target)
    normals = mesh.v_normal.to(target).contiguous()
    pos_faces = mesh.t_pos_idx.to(target).int()
    uv = mesh._v_tex.to(target)
    uv_faces = mesh._t_tex_idx.to(target).int()

    # Map UVs from [0, 1] to [-1, 1] and pad with z = 0, w = 1 so they can be
    # rasterized like clip-space vertices (batch dimension of 1).
    uv_ndc = uv[None, ...] * 2.0 - 1.0
    z_column = torch.zeros_like(uv_ndc[..., :1])
    w_column = torch.ones_like(uv_ndc[..., :1])
    uv_clip_padded = torch.cat((uv_ndc, z_column, w_column), dim=-1)

    resolution = (rasterize_height, rasterize_width)
    rasterized_output, _ = ctx.rasterize(uv_clip_padded, uv_faces, resolution)

    # Interpolate 3-D attributes over the UV-space rasterization.
    position_map, _ = ctx.interpolate_one(positions, rasterized_output, pos_faces)
    normal_map, _ = ctx.interpolate_one(normals, rasterized_output, pos_faces)

    rasterization_mask = rasterized_output[..., 3:4] > 0
    return position_map, normal_map, rasterization_mask