# Scraped from a Hugging Face Space page (status shown: "Starting on T4").
# File size: 7,568 bytes; revision 7f3c2df.
import json
import os
from typing import NamedTuple, Optional

import numpy as np
import torch
import torch.nn.functional as F
from imageio.v2 import imread
from plyfile import PlyData, PlyElement

from scene.gaussian_model import BasicPointCloud
from utils.sh_utils import SH2RGB
class CameraInfo(NamedTuple):
    """Per-frame camera record built by ``readHUGSIMCameras``.

    Fields filled from optional side files (semantics, flow, depth, mask) are
    ``None`` when the corresponding file does not exist.
    """

    K: np.ndarray  # intrinsics, built from frame['intrinsics']
    c2w: np.ndarray  # camera-to-world pose, built from frame['camtoworld']
    image: np.ndarray  # decoded RGB image
    image_path: str  # full path of the RGB image on disk
    image_name: str  # "<parent directory>_<file stem>"
    width: int  # image.shape[1]
    height: int  # image.shape[0]
    semantic2d: Optional[np.ndarray]  # label map (labels 14/15 remapped to 13) or None
    optical_image: Optional[np.ndarray]  # optical-flow array or None
    depth: Optional[torch.Tensor]  # depth loaded from a .pt file or None
    mask: Optional[np.ndarray]  # mask array or None
    timestamp: int  # frame['timestamp'], or -1 when absent
    dynamics: dict  # per-object entries from frame['dynamics'] as CUDA tensors
# Bundle returned by a scene loader: the initial point cloud, the train and
# test camera lists, the normalization dict, the path of the PLY file the
# point cloud came from, and the per-key vertex arrays.
SceneInfo = NamedTuple(
    'SceneInfo',
    [
        ('point_cloud', BasicPointCloud),
        ('train_cameras', list),
        ('test_cameras', list),
        ('nerf_normalization', dict),
        ('ply_path', str),
        ('verts', dict),
    ],
)
def getNerfppNorm(cam_info, data_type):
    """Return the scene normalization used by the HUGSIM loader.

    Args:
        cam_info: camera records. Not consulted; kept for interface
            compatibility.
        data_type: dataset identifier. Not consulted; kept for interface
            compatibility.

    Returns:
        A new ``{'radius': 10}`` dict on every call. The radius is a fixed
        constant and does not depend on the camera positions.
    """
    # NOTE(review): a NeRF++-style radius would be derived from the spread of
    # camera centres (c2w[:3, 3]); this loader deliberately returns a constant.
    return {'radius': 10}
def fetchPly(path):
    """Read the ``vertex`` element of a PLY file into a ``BasicPointCloud``.

    Positions come from the ``x``/``y``/``z`` properties. Colours come from
    ``red``/``green``/``blue`` scaled by 1/255 when those properties exist.
    Otherwise, ``SH2RGB`` is applied to a constant array of 0.5 SH
    coefficients. Despite the printed "random" message, that input is
    constant, so presumably every point gets the same colour.

    Normals are always returned as zeros; any normals stored in the file are
    ignored.
    """
    vertex = PlyData.read(path)['vertex']
    xyz = np.vstack([vertex[axis] for axis in ('x', 'y', 'z')]).T
    n_points = xyz.shape[0]

    if 'red' not in vertex:
        print('Create random colors')
        rgb = SH2RGB(np.ones((n_points, 3)) * 0.5)
    else:
        rgb = np.vstack([vertex[channel] for channel in ('red', 'green', 'blue')]).T / 255.0

    return BasicPointCloud(points=xyz, colors=rgb, normals=np.zeros((n_points, 3)))
def storePly(path, xyz, rgb):
    """Write a coloured point cloud to ``path`` as a PLY ``vertex`` element.

    Each vertex record holds float32 ``x y z``, float32 normals ``nx ny nz``
    (always written as zeros) and uint8 ``red green blue``.

    Args:
        path: output file path.
        xyz: point positions, expected shape (N, 3).
        rgb: per-point colours, expected shape (N, 3). They are stored as
            uint8, so values are presumably in 0-255.
    """
    vertex_dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ]
    # Normals are not tracked, so zeros fill their three columns.
    rows = np.concatenate((xyz, np.zeros_like(xyz), rgb), axis=1)
    records = np.empty(xyz.shape[0], dtype=vertex_dtype)
    records[:] = [tuple(row) for row in rows]
    PlyData([PlyElement.describe(records, 'vertex')]).write(path)
def _split_kitti360(idx, n_frames):
    """KITTI-360: first 20 frames train; afterwards frames 16-19 of every 20 test."""
    if idx < 20 or idx % 20 < 16:
        return 'train'
    return 'test'


def _split_kitti(idx, n_frames):
    """KITTI: first 10 and last 4 train; otherwise per 4 frames 2 train, 1 test, 1 dropped."""
    if idx < 10 or idx >= n_frames - 4:
        return 'train'
    phase = idx % 4
    if phase < 2:
        return 'train'
    if phase == 2:
        return 'test'
    return None


def _split_nuscenes(idx, n_frames):
    """nuScenes: frames 24-29 of every 30 test, the rest train."""
    return 'test' if idx % 30 >= 24 else 'train'


def _split_waymo(idx, n_frames):
    """Waymo: frames 12-14 of every 15 test, the rest train."""
    return 'test' if idx % 15 >= 12 else 'train'


def _split_pandaset(idx, n_frames):
    """PandaSet: like nuScenes, but none of the first 31 frames is a test frame."""
    return 'test' if idx > 30 and idx % 30 >= 24 else 'train'


# Train/test split rule per supported ``data_type``. Each rule maps
# (frame index, number of frames) to 'train', 'test' or None (drop the frame).
_SPLIT_RULES = {
    'kitti360': _split_kitti360,
    'kitti': _split_kitti,
    'nuscenes': _split_nuscenes,
    'waymo': _split_waymo,
    'pandaset': _split_pandaset,
}


def _sibling_path(rgb_path, folder, suffix):
    """Return the side-file path that belongs to the image at ``rgb_path``.

    The last occurrence of ``images`` in the directory part is replaced by
    ``folder``, and the file extension is replaced by ``suffix``. The file
    name is never searched for ``images``, and other ``images`` substrings
    earlier in the path are left alone.
    """
    directory, filename = os.path.split(rgb_path)
    head, found, tail = directory.rpartition('images')
    if found:
        directory = head + folder + tail
    stem = os.path.splitext(filename)[0]
    return os.path.join(directory, stem + suffix)


def _load_if_exists(file_path, loader):
    """Return ``loader(file_path)`` if the file exists, else ``None``."""
    return loader(file_path) if os.path.exists(file_path) else None


def _load_depth(file_path):
    """Load a ``.pt`` depth file with ``torch.load(weights_only=True)``."""
    return torch.load(file_path, weights_only=True)


def _load_camera(path, frame, ignore_dynamic):
    """Build a ``CameraInfo`` for one frame entry of ``meta_data.json``."""
    c2w = np.array(frame['camtoworld'])
    # normpath drops a leading './' (and redundant '.' segments) from the
    # relative path while keeping any '../' prefix intact.
    rgb_path = os.path.join(path, os.path.normpath(frame['rgb_path']))
    parent_dir = os.path.basename(os.path.dirname(rgb_path))
    stem = os.path.splitext(os.path.basename(rgb_path))[0]
    image_name = f'{parent_dir}_{stem}'
    image = imread(rgb_path)

    semantic_2d = _load_if_exists(_sibling_path(rgb_path, 'semantics', '.npy'), np.load)
    if semantic_2d is not None:
        # Merge labels 14 and 15 into 13.
        # NOTE(review): label meanings are not visible here; confirm against
        # the dataset's class map.
        semantic_2d[(semantic_2d == 14) | (semantic_2d == 15)] = 13
    optical_image = _load_if_exists(_sibling_path(rgb_path, 'flow', '_flow.npy'), np.load)
    depth = _load_if_exists(_sibling_path(rgb_path, 'depth', '.pt'), _load_depth)
    mask = _load_if_exists(_sibling_path(rgb_path, 'masks', '.npy'), np.load)

    dynamics = {}
    if 'dynamics' in frame and not ignore_dynamic:
        # Each per-frame dynamic-object entry is moved to the GPU (needs CUDA).
        dynamics = {iid: torch.tensor(value).cuda() for iid, value in frame['dynamics'].items()}

    return CameraInfo(
        K=np.array(frame['intrinsics']), c2w=c2w, image=np.array(image),
        image_path=rgb_path, image_name=image_name,
        height=image.shape[0], width=image.shape[1],
        semantic2d=semantic_2d, optical_image=optical_image, depth=depth,
        mask=mask, timestamp=frame.get('timestamp', -1), dynamics=dynamics,
    )


def readHUGSIMCameras(path, data_type, ignore_dynamic):
    """Load frames from ``<path>/meta_data.json`` and split them into train/test cameras.

    For every kept frame, the RGB image is read. Optional side files are
    looked up by mapping the image path into the sibling folders
    ``semantics``, ``flow``, ``depth`` and ``masks``. A missing side file
    yields ``None`` in the corresponding ``CameraInfo`` field. Frames the
    split rule drops are not loaded at all.

    Args:
        path: scene directory containing ``meta_data.json``.
        data_type: key of ``_SPLIT_RULES`` (kitti360, kitti, nuscenes, waymo,
            pandaset); it selects the train/test split rule.
        ignore_dynamic: when true, ``verts`` and per-frame ``dynamics`` are
            not loaded.

    Returns:
        ``(train_cam_infos, test_cam_infos, verts)``. ``verts`` maps the keys
        of ``meta_data['verts']`` to numpy arrays; it is empty when absent or
        when ``ignore_dynamic`` is true.

    Raises:
        NotImplementedError: if ``data_type`` has no split rule. This is
            checked before any file is read.
    """
    split_rule = _SPLIT_RULES.get(data_type)
    if split_rule is None:
        raise NotImplementedError(f'Unsupported data_type: {data_type!r}')

    with open(os.path.join(path, 'meta_data.json'), encoding='utf-8') as json_file:
        meta_data = json.load(json_file)

    verts = {}
    if 'verts' in meta_data and not ignore_dynamic:
        verts = {k: np.array(v) for k, v in meta_data['verts'].items()}

    frames = meta_data['frames']
    n_frames = len(frames)
    train_cam_infos, test_cam_infos = [], []
    for idx, frame in enumerate(frames):
        split = split_rule(idx, n_frames)
        if split is None:
            # Dropped from both splits, so skip reading its files entirely.
            continue
        cam_info = _load_camera(path, frame, ignore_dynamic)
        if split == 'train':
            train_cam_infos.append(cam_info)
        else:
            test_cam_infos.append(cam_info)
    return train_cam_infos, test_cam_infos, verts
def readHUGSIMInfo(path, data_type, ignore_dynamic):
    """Assemble a ``SceneInfo`` for a HUGSIM scene directory.

    Args:
        path: scene directory containing ``meta_data.json`` and
            ``points3d.ply``.
        data_type: dataset identifier forwarded to ``readHUGSIMCameras`` and
            ``getNerfppNorm``.
        ignore_dynamic: forwarded to ``readHUGSIMCameras``.

    Returns:
        SceneInfo holding the loaded point cloud, the train/test cameras, the
        normalization dict, the PLY path and the dynamic-object vertices.

    Raises:
        FileNotFoundError: if ``points3d.ply`` is missing.
        SystemExit: with status 1 if the PLY file cannot be loaded.
    """
    train_cam_infos, test_cam_infos, verts = readHUGSIMCameras(path, data_type, ignore_dynamic)
    print(f'Loaded {len(train_cam_infos)} train cameras and {len(test_cam_infos)} test cameras')
    nerf_normalization = getNerfppNorm(train_cam_infos, data_type)

    ply_path = os.path.join(path, "points3d.ply")
    if not os.path.exists(ply_path):
        # An explicit raise, unlike `assert`, is not stripped under `python -O`.
        raise FileNotFoundError("Requires for initialize 3d points as inputs")
    try:
        pcd = fetchPly(ply_path)
    except Exception as e:
        print('When loading point clound, meet error:', e)
        # Exit with a non-zero status so the failure is visible to callers.
        raise SystemExit(1) from e

    return SceneInfo(point_cloud=pcd,
                     train_cameras=train_cam_infos,
                     test_cameras=test_cam_infos,
                     nerf_normalization=nerf_normalization,
                     ply_path=ply_path,
                     verts=verts)
# Registry mapping a scene type name to the function that loads it.
sceneLoadTypeCallbacks = dict(HUGSIM=readHUGSIMInfo)