Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
from typing import List, Dict | |
import torch as th | |
import torch.nn as nn | |
from pytorch3d.renderer import ( | |
RasterizationSettings, | |
MeshRasterizer, | |
) | |
from pytorch3d.structures import Meshes | |
from pytorch3d.renderer.mesh.textures import TexturesUV | |
from pytorch3d.utils import cameras_from_opencv_projection | |
class RenderLayer(nn.Module):
    """Rasterize a batch of textured triangle meshes with PyTorch3D from OpenCV-style cameras.

    The mesh topology (face vertex indices, UV coordinates, face UV indices) is
    fixed at construction time and shared by every batch element. Each call
    supplies vertex positions, a texture map and per-sample camera parameters.
    """

    def __init__(self, h, w, vi, vt, vti, flip_uvs=False):
        """
        Args:
            h, w: output image height and width in pixels.
            vi: face -> vertex index tensor, used as the mesh ``faces``
                (presumably (F, 3) integer -- TODO confirm).
            vt: UV coordinates, used as the texture's ``verts_uvs``
                (presumably (V_uv, 2) -- TODO confirm).
            vti: face -> UV index tensor, used as the texture's ``faces_uvs``
                (presumably (F, 3) integer -- TODO confirm).
            flip_uvs: if False (default), ``forward`` permutes the incoming
                texture from channels-first to channels-last and flips its rows.
                If True, the texture is passed to PyTorch3D as given.
        """
        super().__init__()
        # Non-persistent buffers follow .to()/.cuda() but are not written to the state dict.
        self.register_buffer("vi", vi, persistent=False)
        self.register_buffer("vt", vt, persistent=False)
        self.register_buffer("vti", vti, persistent=False)
        raster_settings = RasterizationSettings(image_size=(h, w))
        self.rasterizer = MeshRasterizer(raster_settings=raster_settings)
        self.flip_uvs = flip_uvs
        # (height, width) per image, as cameras_from_opencv_projection expects.
        # This buffer is persistent, i.e. it is part of the state dict.
        image_size = th.as_tensor([h, w], dtype=th.int32)
        self.register_buffer("image_size", image_size)

    def forward(
        self,
        verts: th.Tensor,
        tex: th.Tensor,
        K: th.Tensor,
        Rt: th.Tensor,
        background: th.Tensor = None,
        output_filters: List[str] = None,
    ) -> Dict[str, th.Tensor]:
        """Render the meshes defined by ``verts`` and the stored topology.

        Args:
            verts: batched vertex positions; ``verts.shape[0]`` is the batch
                size B (presumably (B, V, 3) in world space -- TODO confirm).
            tex: texture map. With ``flip_uvs=False`` it is permuted with
                (0, 2, 3, 1), i.e. assumed channels-first (B, C, H_t, W_t).
                With ``flip_uvs=True`` it is used as-is, so it must already be
                in PyTorch3D's channels-last (B, H_t, W_t, C) layout.
            K: OpenCV intrinsic matrices, expected (B, 3, 3).
            Rt: OpenCV world-to-camera extrinsics, (B, 3, 4) or homogeneous
                (B, 4, 4). Only the top 3x4 part is used.
            background: not supported; must be None.
            output_filters: not supported; must be None.

        Returns:
            ``{'render': image}``. ``image`` is the sampled texture color moved
            to channels-first layout, (B, C, h, w). Pixels not covered by any
            face are 0.
        """
        assert output_filters is None, "output_filters is not supported"
        assert background is None, "background is not supported"

        device = verts.device
        B = verts.shape[0]

        image_size = th.repeat_interleave(self.image_size[None], B, dim=0).to(device)
        # Slice rows as well as columns so both 3x4 and 4x4 extrinsics yield a
        # (B, 3, 3) rotation and a (B, 3) translation.
        R = Rt[:, :3, :3]
        tvec = Rt[:, :3, 3]
        cameras = cameras_from_opencv_projection(R, tvec, K, image_size)

        # The same topology is replicated for every batch element.
        faces = self.vi[None].repeat(B, 1, 1).to(device)
        faces_uvs = self.vti[None].repeat(B, 1, 1).to(device)
        verts_uvs = self.vt[None].repeat(B, 1, 1).to(device)

        # Move the texture in both branches so PyTorch3D always receives it on
        # the mesh device. permute/flip return new tensors; the caller's
        # ``tex`` is not modified.
        tex = tex.to(device)
        if not self.flip_uvs:
            # Channels-first -> channels-last, then reverse the row axis (dim 1).
            # NOTE(review): presumably reconciles the texture's v-axis
            # convention with PyTorch3D's -- confirm against the texture source.
            tex = tex.permute(0, 2, 3, 1).flip((1,))

        textures = TexturesUV(
            maps=tex,
            faces_uvs=faces_uvs,
            verts_uvs=verts_uvs,
        )
        meshes = Meshes(verts, faces, textures=textures)

        fragments = self.rasterizer(meshes, cameras=cameras)
        # sample_textures yields (B, h, w, faces_per_pixel, C); keep the closest face only.
        rgb = meshes.sample_textures(fragments)[:, :, :, 0]
        # pix_to_face == -1 marks pixels that no face covers; zero them.
        rgb[fragments.pix_to_face[..., 0] == -1] = 0.0

        return {"render": rgb.permute(0, 3, 1, 2)}