whyun13's picture
Upload folder using huggingface_hub
882f6e2 verified
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Dict, List, Optional

import torch as th
import torch.nn as nn
from pytorch3d.renderer import (
    MeshRasterizer,
    RasterizationSettings,
)
from pytorch3d.renderer.mesh.textures import TexturesUV
from pytorch3d.structures import Meshes
from pytorch3d.utils import cameras_from_opencv_projection


class RenderLayer(nn.Module):
    """Rasterize a batch of UV-textured meshes with PyTorch3D.

    The mesh topology is fixed at construction and shared by every batch
    element: face vertex indices ``vi``, UV coordinates ``vt`` and face UV
    indices ``vti``. Each ``forward`` call supplies vertex positions, a
    texture and OpenCV-style camera parameters.

    Args:
        h: Output image height in pixels.
        w: Output image width in pixels.
        vi: Face vertex indices, used as the ``Meshes`` faces
            (presumably ``(F, 3)`` integer -- TODO confirm).
        vt: UV coordinates, used as ``TexturesUV.verts_uvs``
            (presumably ``(T, 2)`` float -- TODO confirm).
        vti: Face UV indices, used as ``TexturesUV.faces_uvs``
            (presumably ``(F, 3)`` integer -- TODO confirm).
        flip_uvs: When False (default), ``forward`` converts ``tex`` from
            channels-first to channels-last and flips it along the row axis.
            When True, ``tex`` is handed to ``TexturesUV`` unchanged.
    """

    def __init__(self, h, w, vi, vt, vti, flip_uvs=False):
        super().__init__()
        # Topology buffers move with the module (.to / .cuda) but are not
        # written to the state dict.
        self.register_buffer("vi", vi, persistent=False)
        self.register_buffer("vt", vt, persistent=False)
        self.register_buffer("vti", vti, persistent=False)
        raster_settings = RasterizationSettings(image_size=(h, w))
        self.rasterizer = MeshRasterizer(raster_settings=raster_settings)
        self.flip_uvs = flip_uvs
        # (height, width) order, as cameras_from_opencv_projection expects.
        image_size = th.as_tensor([h, w], dtype=th.int32)
        self.register_buffer("image_size", image_size)

    def forward(
        self,
        verts: th.Tensor,
        tex: th.Tensor,
        K: th.Tensor,
        Rt: th.Tensor,
        background: Optional[th.Tensor] = None,
        output_filters: Optional[List[str]] = None,
    ) -> Dict[str, th.Tensor]:
        """Render the textured meshes as seen from the given cameras.

        Args:
            verts: Per-batch vertex positions passed to ``Meshes``
                (presumably ``(B, V, 3)`` -- TODO confirm).
            tex: Texture maps. With ``flip_uvs=False`` this is channels-first
                ``(B, C, H, W)``. With ``flip_uvs=True`` it is passed to
                ``TexturesUV`` unchanged.
            K: OpenCV intrinsic matrices, ``(B, 3, 3)``.
            Rt: OpenCV extrinsics ``[R | t]``. Only the top three rows are
                read, so ``(B, 3, 4)`` and homogeneous ``(B, 4, 4)`` inputs
                both work.
            background: Not supported; must be None.
            output_filters: Not supported; must be None.

        Returns:
            ``{'render': rgb}`` with ``rgb`` of shape ``(B, C, h, w)``.
            Pixels not covered by any face are 0.

        Raises:
            AssertionError: If ``background`` or ``output_filters`` is given.
        """
        assert output_filters is None
        assert background is None

        # Everything is placed on the device of the vertices; .to() is a
        # no-op for tensors already there.
        device = verts.device
        B = verts.shape[0]
        K = K.to(device)
        Rt = Rt.to(device)

        image_size = th.repeat_interleave(self.image_size[None], B, dim=0).to(device)
        # Slice the rotation as the top-left 3x3 so homogeneous 4x4
        # extrinsics are accepted as well as 3x4 ones.
        cameras = cameras_from_opencv_projection(Rt[:, :3, :3], Rt[:, :3, 3], K, image_size)

        # Shared topology, repeated for every batch element.
        faces = self.vi[None].repeat(B, 1, 1).to(device)
        faces_uvs = self.vti[None].repeat(B, 1, 1).to(device)
        verts_uvs = self.vt[None].repeat(B, 1, 1).to(device)

        tex = tex.to(device)
        if not self.flip_uvs:
            # Convert (B, C, H, W) to channels-last (B, H, W, C), then flip
            # rows (dim 1). permute makes a view and flip makes a copy, so
            # the caller's tensor is never modified.
            tex = tex.permute(0, 2, 3, 1).flip((1,))
        # NOTE(review): with flip_uvs=True the texture is neither permuted
        # nor flipped, so callers must already pass maps in TexturesUV's
        # channels-last layout. Confirm against call sites.

        textures = TexturesUV(
            maps=tex,
            faces_uvs=faces_uvs,
            verts_uvs=verts_uvs,
        )
        meshes = Meshes(verts, faces, textures=textures)
        fragments = self.rasterizer(meshes, cameras=cameras)

        # sample_textures yields (B, h, w, faces_per_pixel, C); keep the
        # nearest face per pixel.
        rgb = meshes.sample_textures(fragments)[:, :, :, 0]
        # pix_to_face == -1 marks pixels that no face covers; zero them
        # (out-of-place, broadcasting the mask over the channel axis).
        empty = fragments.pix_to_face[..., 0] == -1
        rgb = rgb.masked_fill(empty[..., None], 0.0)
        return {'render': rgb.permute(0, 3, 1, 2)}