File size: 2,560 Bytes
7f3c2df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from torch import nn

class Camera(nn.Module):
    """Per-view camera container holding intrinsics, pose and ground-truth data.

    ``K`` and ``c2w`` are always moved to the default CUDA device, while the
    image and the optional supervision tensors (semantics, depth, mask,
    optical image) are moved to ``data_device``.

    Args:
        width, height: Image size as reported by the caller; stored verbatim.
            ``image_width`` / ``image_height`` are derived from ``image``
            itself and are not checked against these.
        image: NumPy array in HWC layout; permuted to CHW, cast to float and
            clamped to [0, 1].
        K: NumPy array of intrinsics (presumably 3x3 -- TODO confirm).
        c2w: NumPy camera-to-world matrix (presumably 4x4 -- TODO confirm).
        image_name: Identifier of the view; stored verbatim.
        data_device: Device string for image-like tensors. If
            ``torch.device`` rejects it, a warning is printed and ``"cuda"``
            is used instead.
        semantic2d: Optional tensor (``.to`` is called on it), moved to
            ``data_device``.
        depth: Optional tensor (``.to`` is called on it), moved to
            ``data_device``.
        mask: Optional NumPy array, converted to a bool tensor on
            ``data_device``.
        timestamp: Stored verbatim; defaults to ``-1``.
        optical_image: Optional NumPy array, stored as ``optical_gt`` on
            ``data_device`` without dtype conversion.
        dynamics: Optional mapping stored as ``self.dynamics``. When omitted,
            a fresh empty dict is created for this instance.
    """

    def __init__(self, width, height, image, K, c2w,
                 image_name, data_device="cuda",
                 semantic2d=None, depth=None, mask=None, timestamp=-1, optical_image=None, dynamics=None
                 ):
        super().__init__()

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")

        self.width = width
        self.height = height
        self.image_name = image_name
        self.timestamp = timestamp
        # Camera matrices go to CUDA unconditionally, independent of data_device.
        self.K = torch.from_numpy(K).float().cuda()
        self.c2w = torch.from_numpy(c2w).float().cuda()
        # A `{}` default would be shared by every instance created without
        # `dynamics`; allocate a fresh dict per camera instead.
        self.dynamics = {} if dynamics is None else dynamics

        # HWC numpy -> CHW float tensor, clamped to [0, 1].
        self.original_image = torch.from_numpy(image).permute(2, 0, 1).float().clamp(0.0, 1.0).to(self.data_device)
        self.semantic2d = semantic2d.to(self.data_device) if semantic2d is not None else None
        self.depth = depth.to(self.data_device) if depth is not None else None
        self.mask = torch.from_numpy(mask).bool().to(self.data_device) if mask is not None else None
        # After the permute, dim 1 is the original H and dim 2 the original W.
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]
        self.optical_gt = torch.from_numpy(optical_image).to(self.data_device) if optical_image is not None else None


def loadCam(args, cam_info):
    """Build a ``Camera`` from one loader-produced ``cam_info`` record.

    Only the first three channels of ``cam_info.image`` are kept, and they are
    divided by 255 (presumably 8-bit input -- TODO confirm). When present,
    ``cam_info.semantic2d`` is converted to a ``long`` tensor with a new
    leading dimension of size 1. All other fields are forwarded unchanged.
    The target device comes from ``args.model.data_device``.
    """
    raw_semantics = cam_info.semantic2d
    semantic2d = (
        None
        if raw_semantics is None
        else torch.from_numpy(raw_semantics).long().unsqueeze(0)
    )

    rgb = cam_info.image[..., :3] / 255.0

    return Camera(
        K=cam_info.K,
        c2w=cam_info.c2w,
        width=cam_info.width,
        height=cam_info.height,
        image=rgb,
        image_name=cam_info.image_name,
        data_device=args.model.data_device,
        semantic2d=semantic2d,
        depth=cam_info.depth,
        mask=cam_info.mask,
        timestamp=cam_info.timestamp,
        optical_image=cam_info.optical_image,
        dynamics=cam_info.dynamics,
    )

def cameraList_from_camInfos(cam_infos, args):
    """Convert each camera-info record into a camera via ``loadCam``.

    Args:
        cam_infos: Iterable of camera-info records, consumed in order.
        args: Configuration object forwarded unchanged to ``loadCam``.

    Returns:
        A list with one ``loadCam`` result per record, in input order.
    """
    return [loadCam(args, cam_info) for cam_info in cam_infos]