Spaces: Starting on T4
import torch | |
from torch import nn | |
class Camera(nn.Module):
    """A single camera view together with its per-view image data.

    Stores the camera matrices, the ground-truth image and optional
    auxiliary targets (semantics, depth, mask, optical image) as tensors.

    Args:
        width, height: Stored verbatim as ``self.width`` / ``self.height``.
            ``image_width`` / ``image_height`` are derived separately from
            the image itself.
        image: Channels-last array of shape (H, W, C). It is converted to a
            float (C, H, W) tensor clamped to [0, 1] on ``data_device``.
        K: NumPy matrix converted to a float32 CUDA tensor. Presumably the
            camera intrinsics; confirm against callers.
        c2w: NumPy matrix converted to a float32 CUDA tensor. Presumably a
            camera-to-world pose; confirm against callers.
        image_name: Stored verbatim.
        data_device: Device for the image and auxiliary targets. If
            ``torch.device`` rejects it, a warning is printed and ``"cuda"``
            is used instead.
        semantic2d, depth, mask, optical_image: Optional NumPy arrays or
            tensors moved to ``data_device``. ``None`` keeps the attribute
            ``None``. ``mask`` is cast to ``bool``. ``optical_image`` is
            stored as ``self.optical_gt``.
        timestamp: Stored verbatim (default ``-1``).
        dynamics: Stored verbatim. Defaults to a fresh empty dict per
            instance.
    """

    def __init__(self, width, height, image, K, c2w,
                 image_name, data_device="cuda",
                 semantic2d=None, depth=None, mask=None, timestamp=-1,
                 optical_image=None, dynamics=None):
        super().__init__()
        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            # Fall back to the default CUDA device when the custom one is rejected.
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")

        self.width = width
        self.height = height
        self.image_name = image_name
        self.timestamp = timestamp
        # Camera matrices are always placed on CUDA, independent of data_device
        # (kept as-is; downstream code presumably relies on this placement).
        self.K = torch.from_numpy(K).float().cuda()
        self.c2w = torch.from_numpy(c2w).float().cuda()
        # A `{}` default would be a single dict shared by every Camera instance.
        self.dynamics = {} if dynamics is None else dynamics

        # (H, W, C) -> (C, H, W), float, values clamped to [0, 1].
        self.original_image = (
            torch.as_tensor(image)
            .permute(2, 0, 1)
            .float()
            .clamp(0.0, 1.0)
            .to(self.data_device)
        )
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]

        # Optional targets: accept NumPy arrays or tensors uniformly.
        self.semantic2d = self._to_device(semantic2d, self.data_device)
        self.depth = self._to_device(depth, self.data_device)
        self.mask = (
            None if mask is None
            else torch.as_tensor(mask).bool().to(self.data_device)
        )
        self.optical_gt = self._to_device(optical_image, self.data_device)

    @staticmethod
    def _to_device(value, device):
        """Return ``value`` as a tensor on ``device``, or ``None`` if ``value`` is ``None``.

        NumPy arrays are converted with their dtype preserved. Tensors are
        moved as-is.
        """
        if value is None:
            return None
        return torch.as_tensor(value).to(device)
def loadCam(args, cam_info):
    """Build a ``Camera`` from a single ``cam_info`` record.

    Only the first three channels of ``cam_info.image`` are kept, and they
    are divided by 255. When ``cam_info.semantic2d`` is present, it becomes
    an int64 tensor with an extra leading axis. All other fields are
    forwarded unchanged. The target device is ``args.model.data_device``.
    """
    raw_semantics = cam_info.semantic2d
    semantics = (
        None if raw_semantics is None
        else torch.from_numpy(raw_semantics).long()[None, ...]
    )
    rgb = cam_info.image[..., :3] / 255.

    return Camera(
        K=cam_info.K,
        c2w=cam_info.c2w,
        width=cam_info.width,
        height=cam_info.height,
        image=rgb,
        image_name=cam_info.image_name,
        data_device=args.model.data_device,
        semantic2d=semantics,
        depth=cam_info.depth,
        mask=cam_info.mask,
        timestamp=cam_info.timestamp,
        optical_image=cam_info.optical_image,
        dynamics=cam_info.dynamics,
    )
def cameraList_from_camInfos(cam_infos, args):
    """Convert every entry of ``cam_infos`` with ``loadCam``, preserving order.

    Returns a new list with one camera per input record.
    """
    return [loadCam(args, info) for info in cam_infos]