# hyzhou404's picture
# private scenes
# 7f3c2df
import torch
from torch import nn
class Camera(nn.Module):
    """A single posed view: intrinsics, pose, ground-truth image and optional auxiliary maps.

    Args:
        width, height: Stored verbatim on ``self.width`` / ``self.height``. Note that
            ``self.image_width`` / ``self.image_height`` are derived separately from the
            image array's shape; nothing here checks that the two agree.
        image: NumPy array indexed as (H, W, C). It is permuted to (C, H, W), cast to
            float32, clamped to [0, 1] and moved to ``data_device``.
        K: NumPy intrinsics matrix, cast to float32 and moved to CUDA.
        c2w: NumPy camera-to-world matrix, cast to float32 and moved to CUDA.
        image_name: Identifier, stored verbatim.
        data_device: Device for the image and auxiliary maps. If ``torch.device`` rejects
            it, the error and a warning are printed and ``"cuda"`` is used instead.
        semantic2d: Optional object with a ``.to`` method (presumably a tensor), moved
            to ``data_device``.
        depth: Optional object with a ``.to`` method (presumably a tensor), moved to
            ``data_device``.
        mask: Optional NumPy array, converted to a bool tensor on ``data_device``.
        timestamp: Stored verbatim; defaults to -1.
        optical_image: Optional NumPy array, converted to a tensor on ``data_device``
            with its dtype unchanged.
        dynamics: Stored by reference. When omitted (or None), each instance gets its
            own fresh empty dict.
    """

    def __init__(self, width, height, image, K, c2w,
                 image_name, data_device="cuda",
                 semantic2d=None, depth=None, mask=None, timestamp=-1, optical_image=None, dynamics=None
                 ):
        super().__init__()
        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            # Deliberate best-effort: an unusable device spec falls back to CUDA
            # instead of aborting construction.
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")
        self.width = width
        self.height = height
        self.image_name = image_name
        self.timestamp = timestamp
        # NOTE: K and c2w always go to the current CUDA device via `.cuda()`,
        # independent of `data_device`; a CPU `data_device` only affects the image and
        # auxiliary maps. Presumably intentional so rendering sees them on GPU — confirm.
        self.K = torch.from_numpy(K).float().cuda()
        self.c2w = torch.from_numpy(c2w).float().cuda()
        # None -> a fresh dict per instance, so Cameras never share one default dict.
        self.dynamics = {} if dynamics is None else dynamics
        self.original_image = torch.from_numpy(image).permute(2,0,1).float().clamp(0.0, 1.0).to(self.data_device)
        if semantic2d is not None:
            self.semantic2d = semantic2d.to(self.data_device)
        else:
            self.semantic2d = None
        if depth is not None:
            self.depth = depth.to(self.data_device)
        else:
            self.depth = None
        if mask is not None:
            self.mask = torch.from_numpy(mask).bool().to(self.data_device)
        else:
            self.mask = None
        # After the (2, 0, 1) permute the layout is (C, H, W).
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]
        if optical_image is not None:
            self.optical_gt = torch.from_numpy(optical_image).to(self.data_device)
        else:
            self.optical_gt = None
def loadCam(args, cam_info):
    """Turn one ``cam_info`` record into a :class:`Camera`.

    Args:
        args: Config object; only ``args.model.data_device`` is read.
        cam_info: Record exposing ``K``, ``c2w``, ``width``, ``height``, ``image``,
            ``image_name``, ``semantic2d``, ``depth``, ``mask``, ``timestamp``,
            ``optical_image`` and ``dynamics``.

    Returns:
        Camera: built from the record's fields.
    """
    # Semantic labels (NumPy) become an int64 tensor with a leading singleton axis.
    labels = cam_info.semantic2d
    semantic2d = (
        None if labels is None
        else torch.from_numpy(labels).long()[None, ...]
    )

    # Keep only the first three channels and divide by 255
    # (presumably 8-bit RGB/RGBA input — confirm).
    rgb = cam_info.image[..., :3] / 255.

    return Camera(
        K=cam_info.K,
        c2w=cam_info.c2w,
        width=cam_info.width,
        height=cam_info.height,
        image=rgb,
        image_name=cam_info.image_name,
        data_device=args.model.data_device,
        semantic2d=semantic2d,
        depth=cam_info.depth,
        mask=cam_info.mask,
        timestamp=cam_info.timestamp,
        optical_image=cam_info.optical_image,
        dynamics=cam_info.dynamics,
    )
def cameraList_from_camInfos(cam_infos, args):
    """Convert each record in ``cam_infos`` into a Camera, preserving input order.

    Returns a new list holding ``loadCam(args, info)`` for every ``info``.
    """
    return [loadCam(args, info) for info in cam_infos]