hyzhou404 committed
Commit 7f3c2df · 1 Parent(s): 394bc84

private scenes

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. Dockerfile +5 -5
  2. code/gaussian_renderer/__init__.py +231 -0
  3. code/gaussian_renderer/__pycache__/__init__.cpython-311.pyc +0 -0
  4. code/scene/__init__.py +111 -0
  5. code/scene/__pycache__/__init__.cpython-311.pyc +0 -0
  6. code/scene/__pycache__/cameras.cpython-311.pyc +0 -0
  7. code/scene/__pycache__/dataset_readers.cpython-311.pyc +0 -0
  8. code/scene/__pycache__/gaussian_model.cpython-311.pyc +0 -0
  9. code/scene/__pycache__/ground_model.cpython-311.pyc +0 -0
  10. code/scene/__pycache__/obj_model.cpython-311.pyc +0 -0
  11. code/scene/cameras.py +71 -0
  12. code/scene/dataset_readers.py +212 -0
  13. code/scene/gaussian_model.py +636 -0
  14. code/scene/ground_model.py +360 -0
  15. code/scene/obj_model.py +567 -0
  16. code/sim/hugsim_env.egg-info/PKG-INFO +4 -0
  17. code/sim/hugsim_env.egg-info/SOURCES.txt +10 -0
  18. code/sim/hugsim_env.egg-info/dependency_links.txt +1 -0
  19. code/sim/hugsim_env.egg-info/requires.txt +1 -0
  20. code/sim/hugsim_env.egg-info/top_level.txt +1 -0
  21. code/sim/hugsim_env/__init__.py +8 -0
  22. code/sim/hugsim_env/__pycache__/__init__.cpython-311.pyc +0 -0
  23. code/sim/hugsim_env/envs/__init__.py +1 -0
  24. code/sim/hugsim_env/envs/__pycache__/__init__.cpython-311.pyc +0 -0
  25. code/sim/hugsim_env/envs/__pycache__/hug_sim.cpython-311.pyc +0 -0
  26. code/sim/hugsim_env/envs/hug_sim.py +333 -0
  27. code/sim/ilqr/__pycache__/lqr.cpython-311.pyc +0 -0
  28. code/sim/ilqr/__pycache__/lqr_solver.cpython-311.pyc +0 -0
  29. code/sim/ilqr/__pycache__/utils.cpython-311.pyc +0 -0
  30. code/sim/ilqr/lqr.py +55 -0
  31. code/sim/ilqr/lqr_solver.py +689 -0
  32. code/sim/ilqr/utils.py +346 -0
  33. code/sim/pyproject.toml +9 -0
  34. code/sim/setup.py +7 -0
  35. code/sim/utils/__pycache__/agent_controller.cpython-311.pyc +0 -0
  36. code/sim/utils/__pycache__/plan.cpython-311.pyc +0 -0
  37. code/sim/utils/__pycache__/score_calculator.cpython-311.pyc +0 -0
  38. code/sim/utils/__pycache__/sim_utils.cpython-311.pyc +0 -0
  39. code/sim/utils/agent_controller.py +323 -0
  40. code/sim/utils/launch_ad.py +30 -0
  41. code/sim/utils/plan.py +238 -0
  42. code/sim/utils/score_calculator.py +562 -0
  43. code/sim/utils/sim_utils.py +122 -0
  44. code/submodules/Pplan/Policy/base.py +16 -0
  45. code/submodules/Pplan/Policy/sampling_planner.py +0 -0
  46. code/submodules/Pplan/Sampling/__init__.py +0 -0
  47. code/submodules/Pplan/Sampling/__pycache__/__init__.cpython-311.pyc +0 -0
  48. code/submodules/Pplan/Sampling/__pycache__/forward_sampler.cpython-311.pyc +0 -0
  49. code/submodules/Pplan/Sampling/__pycache__/spline_planner.cpython-311.pyc +0 -0
  50. code/submodules/Pplan/Sampling/forward_sampler.py +141 -0
Dockerfile CHANGED
@@ -15,9 +15,9 @@ ENV PATH /app/miniconda/bin:$PATH
 
 SHELL ["conda", "run","--no-capture-output", "-p","/app/env", "/bin/bash", "-c"]
 
-COPY --chown=1000:1000 ./web_server.py /app/web_server.py
-COPY --chown=1000:1000 ./docker/web_server_config/scene-0383-medium-00.yaml /app/docker/web_server_config/scene-0383-medium-00.yaml
-COPY --chown=1000:1000 ./download_pre_datas.py /app/download_pre_datas.py
+COPY --chown=1000:1000 ./code /app/code
+COPY --chown=1000:1000 ./web_server.py /app/code/web_server.py
+COPY --chown=1000:1000 ./download_pre_datas.py /app/code/download_pre_datas.py
 
 ENV TCNN_CUDA_ARCHITECTURES 75
 
@@ -32,6 +32,6 @@ RUN ./.pixi/envs/default/bin/python3 -m pip install psutil
 
 RUN ./.pixi/envs/default/bin/python3 -m pip install moviepy
 
-RUN ./.pixi/envs/default/bin/python /app/download_pre_datas.py
+RUN ./.pixi/envs/default/bin/python /app/code/download_pre_datas.py
 
-CMD ["./.pixi/envs/default/bin/python", "web_server.py"]
+CMD ["./.pixi/envs/default/bin/python", "/app/code/web_server.py"]
code/gaussian_renderer/__init__.py ADDED
@@ -0,0 +1,231 @@
import torch
from scene.gaussian_model import GaussianModel
from scene.ground_model import GroundModel
from gsplat.rendering import rasterization
import roma
from scene.cameras import Camera
from torch import Tensor

def euler2matrix(yaw):
    return torch.tensor([
        [torch.cos(-yaw), 0, torch.sin(-yaw)],
        [0, 1, 0],
        [-torch.sin(-yaw), 0, torch.cos(-yaw)]
    ]).cuda()

def cat_bgfg(bg, fg, only_xyz=False):
    if only_xyz:
        if bg.ground_model is None:
            bg_feats = [bg.get_xyz]
        else:
            bg_feats = [bg.get_full_xyz]
    else:
        if bg.ground_model is None:
            bg_feats = [bg.get_xyz, bg.get_opacity, bg.get_scaling, bg.get_rotation, bg.get_features, bg.get_3D_features]
        else:
            bg_feats = [bg.get_full_xyz, bg.get_full_opacity, bg.get_full_scaling, bg.get_full_rotation, bg.get_full_features, bg.get_full_3D_features]

    if len(fg) == 0:
        return bg_feats

    output = []
    for fg_feat, bg_feat in zip(fg, bg_feats):
        if fg_feat is None:
            output.append(bg_feat)
        else:
            if bg_feat.shape[1] != fg_feat.shape[1]:
                fg_feat = fg_feat[:, :bg_feat.shape[1], :]
            output.append(torch.cat((bg_feat, fg_feat), dim=0))

    return output

def concatenate_all(all_fg):
    output = []
    for feat in list(zip(*all_fg)):
        output.append(torch.cat(feat, dim=0))
    return output

def proj_uv(xyz, cam):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    intr = torch.as_tensor(cam.K[:3, :3]).float().to(device)  # (3, 3)
    w2c = torch.linalg.inv(cam.c2w)[:3, :]  # (3, 4)

    c_xyz = (w2c[:3, :3] @ xyz.T).T + w2c[:3, 3]
    i_xyz = (intr @ c_xyz.mT).mT  # (N, 3)
    uv = i_xyz[:, [1, 0]] / i_xyz[:, -1:].clip(1e-3)  # (N, 2)
    return uv


def unicycle_b2w(timestamp, model):
    pred = model(timestamp)
    if pred is None:
        return None
    pred_a, pred_b, pred_v, pitchroll, pred_yaw, pred_h = pred
    rt = torch.eye(4).float().cuda()
    rt[:3, :3] = roma.euler_to_rotmat('xzy', [-pitchroll[0]+torch.pi/2, -pitchroll[1]+torch.pi/2, -pred_yaw+torch.pi/2])
    rt[1, 3], rt[0, 3], rt[2, 3] = pred_h, pred_a, pred_b
    return rt


def render(viewpoint:Camera, prev_viewpoint:Camera, pc:GaussianModel, dynamic_gaussians:dict,
           unicycles:dict, bg_color:Tensor, render_optical=False, planning=[]):
    """
    Render the scene.

    Background tensor (bg_color) must be on GPU!
    """
    timestamp = viewpoint.timestamp

    all_fg = [None, None, None, None, None, None]
    prev_all_fg = [None]

    if unicycles is None or len(unicycles) == 0:
        track_dict = viewpoint.dynamics
        if prev_viewpoint is not None:
            prev_track_dict = prev_viewpoint.dynamics
    else:
        track_dict, prev_track_dict = {}, {}
        for track_id, B2W in viewpoint.dynamics.items():
            if track_id in unicycles:
                B2W = unicycle_b2w(timestamp, unicycles[track_id]['model'])
            track_dict[track_id] = B2W
            if prev_viewpoint is not None:
                prev_B2W = unicycle_b2w(prev_viewpoint.timestamp, unicycles[track_id]['model'])
                prev_track_dict[track_id] = prev_B2W
    if len(planning) > 0:
        for plan_id, B2W in planning[0].items():
            track_dict[plan_id] = B2W
        if prev_viewpoint is not None:
            for plan_id, B2W in planning[1].items():
                prev_track_dict[plan_id] = B2W

    all_fg, prev_all_fg = [], []
    for track_id, B2W in track_dict.items():
        w_dxyz = (B2W[:3, :3] @ dynamic_gaussians[track_id].get_xyz.T).T + B2W[:3, 3]

        drot = roma.quat_wxyz_to_xyzw(dynamic_gaussians[track_id].get_rotation)
        drot = roma.unitquat_to_rotmat(drot)
        w_drot = roma.quat_xyzw_to_wxyz(roma.rotmat_to_unitquat(B2W[:3, :3] @ drot))
        fg = [w_dxyz,
              dynamic_gaussians[track_id].get_opacity,
              dynamic_gaussians[track_id].get_scaling,
              w_drot,
              # dynamic_gaussians[track_id].get_rotation,
              dynamic_gaussians[track_id].get_features,
              dynamic_gaussians[track_id].get_3D_features]

        all_fg.append(fg)

        if render_optical and prev_viewpoint is not None:
            if track_id in prev_track_dict:
                prev_B2W = prev_track_dict[track_id]
                prev_w_dxyz = torch.mm(prev_B2W[:3, :3], dynamic_gaussians[track_id].get_xyz.T).T + prev_B2W[:3, 3]
                prev_all_fg.append([prev_w_dxyz])
            else:
                prev_all_fg.append([w_dxyz])

    all_fg = concatenate_all(all_fg)
    xyz, opacities, scales, rotations, shs, feats3D = cat_bgfg(pc, all_fg)

    if render_optical and prev_viewpoint is not None:
        prev_all_fg = concatenate_all(prev_all_fg)
        prev_xyz = cat_bgfg(pc, prev_all_fg, only_xyz=True)[0]
        uv = proj_uv(xyz, viewpoint)
        prev_uv = proj_uv(prev_xyz, prev_viewpoint)
        delta_uv = prev_uv - uv
        delta_uv = torch.cat([delta_uv, torch.ones_like(delta_uv[:, :1], device=delta_uv.device)], dim=-1)
    else:
        delta_uv = torch.zeros_like(xyz)

    if pc.affine:
        cam_xyz, cam_dir = viewpoint.c2w[:3, 3].cuda(), viewpoint.c2w[:3, 2].cuda()
        o_enc = pc.pos_enc(cam_xyz[None, :] / 60)
        d_enc = pc.dir_enc(cam_dir[None, :])
        appearance = pc.appearance_model(torch.cat([o_enc, d_enc], dim=1)) * 1e-1
        affine_weight, affine_bias = appearance[:, :9].view(3, 3), appearance[:, -3:]
        affine_weight = affine_weight + torch.eye(3, device=appearance.device)

    if render_optical:
        render_mode = 'RGB+ED+S+F'
    else:
        render_mode = 'RGB+ED+S'

    renders, render_alphas, info = rasterization(
        means=xyz,
        quats=rotations,
        scales=scales,
        opacities=opacities[:, 0],
        colors=shs,
        viewmats=torch.linalg.inv(viewpoint.c2w)[None, ...],  # [C, 4, 4]
        Ks=viewpoint.K[None, :3, :3],  # [C, 3, 3]
        width=viewpoint.width,
        height=viewpoint.height,
        smts=feats3D[None, ...],
        flows=delta_uv[None, ...],
        render_mode=render_mode,
        sh_degree=pc.active_sh_degree,
        near_plane=0.01,
        far_plane=500,
        packed=False,
        backgrounds=bg_color[None, :],
    )

    renders = renders[0]
    rendered_image = renders[..., :3].permute(2, 0, 1)
    depth = renders[..., 3][None, ...]
    smt = renders[..., 4:(4+feats3D.shape[-1])].permute(2, 0, 1)

    if pc.affine:
        colors = rendered_image.view(3, -1).permute(1, 0)  # (H*W, 3)
        refined_image = (colors @ affine_weight + affine_bias).clip(0, 1).permute(1, 0).view(*rendered_image.shape)
    else:
        refined_image = rendered_image

    return {"render": refined_image,
            "feats": smt,
            "depth": depth,
            "opticalflow": renders[..., -2:].permute(2, 0, 1) if render_optical else None,
            "alphas": render_alphas,
            "viewspace_points": info["means2d"],
            "info": info,
            }


def render_ground(viewpoint:Camera, pc:GroundModel, bg_color:Tensor):
    xyz, opacities, scales = pc.get_xyz, pc.get_opacity, pc.get_scaling
    rotations, shs, feats3D = pc.get_rotation, pc.get_features, pc.get_3D_features

    K = viewpoint.K[None, :3, :3]
    renders, render_alphas, info = rasterization(
        means=xyz,
        quats=rotations,
        scales=scales,
        opacities=opacities[:, 0],
        colors=shs,
        viewmats=torch.linalg.inv(viewpoint.c2w)[None, ...],  # [C, 4, 4]
        Ks=K,  # [C, 3, 3]
        width=viewpoint.width,
        height=viewpoint.height,
        smts=feats3D[None, ...],
        render_mode='RGB+ED+S',
        sh_degree=pc.active_sh_degree,
        near_plane=0.01,
        far_plane=500,
        packed=False,
        backgrounds=bg_color[None, :],
    )

    renders = renders[0]
    rendered_image = renders[..., :3].permute(2, 0, 1)
    depth = renders[..., 3][None, ...]
    smt = renders[..., 4:(4+feats3D.shape[-1])].permute(2, 0, 1)

    return {"render": rendered_image,
            "feats": smt,
            "depth": depth,
            "opticalflow": None,
            "alphas": render_alphas,
            "viewspace_points": info["means2d"],
            "info": info,
            }
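A minimal driver sketch for `render` above, assuming a trained checkpoint restored through the `Scene` wrapper added in `code/scene/__init__.py` below; the `args` namespace is a hypothetical stand-in for the repo's own config object, and `unicycles={}` falls back to the per-frame poses stored on each camera:

import torch
from scene import Scene
from scene.gaussian_model import GaussianModel
from gaussian_renderer import render

def render_test_views(args):
    # Hypothetical driver: `args` must carry model_path, source_path and
    # model.sh_degree etc., as consumed by Scene and loadCam.
    gaussians = GaussianModel(args.model.sh_degree, affine=True)
    scene = Scene(args, gaussians, load_iteration=-1, shuffle=False,
                  data_type='kitti360')

    bg_color = torch.zeros(3, device="cuda")  # background must live on the GPU
    cams = scene.getTestCameras()
    outs = []
    for prev_cam, cam in zip([None] + cams[:-1], cams):
        out = render(cam, prev_cam, gaussians, scene.dynamic_gaussians,
                     unicycles={}, bg_color=bg_color, render_optical=False)
        outs.append((out["render"],   # (3, H, W) RGB in [0, 1]
                     out["depth"]))   # (1, H, W) expected depth ('ED' channel)
    return outs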
code/gaussian_renderer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (12.9 kB).
 
code/scene/__init__.py ADDED
@@ -0,0 +1,111 @@
import os
import random
import json
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from scene.obj_model import ObjModel
from scene.cameras import cameraList_from_camInfos
import torch
import open3d as o3d
import numpy as np
import shutil


def load_cameras(args, data_type, ignore_dynamic=False):
    train_cameras = {}
    test_cameras = {}
    if os.path.exists(os.path.join(args.source_path, "meta_data.json")):
        print("Found meta_data.json file, assuming HUGSIM format data set!")
        scene_info = sceneLoadTypeCallbacks['HUGSIM'](args.source_path, data_type, ignore_dynamic)
    else:
        assert False, "Could not recognize scene type! " + args.source_path

    print("Loading Training Cameras")
    train_cameras = cameraList_from_camInfos(scene_info.train_cameras, args)
    print("Loading Test Cameras")
    test_cameras = cameraList_from_camInfos(scene_info.test_cameras, args)
    return train_cameras, test_cameras, scene_info

class Scene:

    def __init__(self, args, gaussians:GaussianModel, load_iteration=None, shuffle=True,
                 data_type='kitti360', ignore_dynamic=False, planning=None):
        """
        :param path: Path to colmap scene main folder.
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians
        self.data_type = data_type

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "ckpts"))
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        self.train_cameras, self.test_cameras, scene_info = load_cameras(args, data_type, ignore_dynamic)

        self.dynamic_verts = scene_info.verts
        self.dynamic_gaussians = {}
        for track_id in scene_info.verts:
            self.dynamic_gaussians[track_id] = ObjModel(args.model.sh_degree, feat_mutable=False)
        if planning is not None:
            for plan_id in planning.keys():
                self.dynamic_gaussians[plan_id] = ObjModel(args.model.sh_degree, feat_mutable=False)

        if not self.loaded_iter:
            shutil.copyfile(scene_info.ply_path, os.path.join(self.model_path, "input.ply"))
            shutil.copyfile(os.path.join(args.source_path, 'meta_data.json'), os.path.join(self.model_path, 'meta_data.json'))
            shutil.copyfile(os.path.join(args.source_path, 'ground_param.pkl'), os.path.join(self.model_path, 'ground_param.pkl'))

        if shuffle:
            random.shuffle(scene_info.train_cameras)
            random.shuffle(scene_info.test_cameras)

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        if self.loaded_iter:
            (model_params, first_iter) = torch.load(os.path.join(self.model_path, "ckpts", f"chkpnt{self.loaded_iter}.pth"))
            gaussians.restore(model_params, None)
            for iid, dynamic_gaussian in self.dynamic_gaussians.items():
                if planning is None or iid not in planning:
                    (model_params, first_iter) = torch.load(os.path.join(self.model_path, "ckpts", f"dynamic_{iid}_chkpnt{self.loaded_iter}.pth"))
                    dynamic_gaussian.restore(model_params, None)
                else:
                    (model_params, first_iter) = torch.load(planning[iid])
                    model_params = list(model_params)
                    model_params.append(None)
                    dynamic_gaussian.restore(model_params, None)
            # for iid, unicycle_pkg in self.unicycles.items():
            #     model_params = torch.load(os.path.join(self.model_path, "ckpts", f"unicycle_{iid}_chkpnt{self.loaded_iter}.pth"))
            #     unicycle_pkg['model'].restore(model_params)

        else:
            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
            for track_id in self.dynamic_gaussians.keys():
                vertices = scene_info.verts[track_id]

                # init from template
                l, h, w = vertices[:, 0].max() - vertices[:, 0].min(), vertices[:, 1].max() - vertices[:, 1].min(), vertices[:, 2].max() - vertices[:, 2].min()
                pcd = o3d.io.read_point_cloud(f"utils/vehicle_template/benz_{data_type}.ply")
                points = np.array(pcd.points) * np.array([l, h, w])
                pcd.points = o3d.utility.Vector3dVector(points)
                pcd.colors = o3d.utility.Vector3dVector(np.ones_like(points) * 0.5)

                self.dynamic_gaussians[track_id].create_from_pcd(pcd, self.cameras_extent)

    def save(self, iteration):
        # self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
        point_cloud_vis_path = os.path.join(self.model_path, "point_cloud_vis/iteration_{}".format(iteration))
        self.gaussians.save_vis_ply(os.path.join(point_cloud_vis_path, "point.ply"))
        for iid, dynamic_gaussian in self.dynamic_gaussians.items():
            dynamic_gaussian.save_vis_ply(os.path.join(point_cloud_vis_path, f"dynamic_{iid}.ply"))

    def getTrainCameras(self):
        return self.train_cameras

    def getTestCameras(self):
        return self.test_cameras
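For reference, the on-disk layout that `Scene` reads when `load_iteration` is set and writes during training, reconstructed from the paths used above; `<iter>` and `<track_id>` are placeholders:

<model_path>/
├── input.ply                                   # copy of the initial point cloud
├── meta_data.json                              # copied from <source_path>
├── ground_param.pkl                            # copied from <source_path>
├── ckpts/
│   ├── chkpnt<iter>.pth                        # background GaussianModel state
│   └── dynamic_<track_id>_chkpnt<iter>.pth     # one per dynamic actor
└── point_cloud_vis/iteration_<iter>/
    ├── point.ply                               # written by Scene.save
    └── dynamic_<track_id>.ply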
code/scene/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (8.49 kB).
 
code/scene/__pycache__/cameras.cpython-311.pyc ADDED
Binary file (4.29 kB).
 
code/scene/__pycache__/dataset_readers.cpython-311.pyc ADDED
Binary file (11.6 kB).
 
code/scene/__pycache__/gaussian_model.cpython-311.pyc ADDED
Binary file (50.3 kB).
 
code/scene/__pycache__/ground_model.cpython-311.pyc ADDED
Binary file (28.7 kB).
 
code/scene/__pycache__/obj_model.cpython-311.pyc ADDED
Binary file (44.1 kB).
 
code/scene/cameras.py ADDED
@@ -0,0 +1,71 @@
import torch
from torch import nn

class Camera(nn.Module):
    def __init__(self, width, height, image, K, c2w,
                 image_name, data_device="cuda",
                 semantic2d=None, depth=None, mask=None, timestamp=-1, optical_image=None, dynamics={}
                 ):
        super(Camera, self).__init__()

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, falling back to default CUDA device")
            self.data_device = torch.device("cuda")

        self.width = width
        self.height = height
        self.image_name = image_name
        self.timestamp = timestamp
        self.K = torch.from_numpy(K).float().cuda()
        self.c2w = torch.from_numpy(c2w).float().cuda()
        self.dynamics = dynamics

        self.original_image = torch.from_numpy(image).permute(2, 0, 1).float().clamp(0.0, 1.0).to(self.data_device)
        if semantic2d is not None:
            self.semantic2d = semantic2d.to(self.data_device)
        else:
            self.semantic2d = None
        if depth is not None:
            self.depth = depth.to(self.data_device)
        else:
            self.depth = None
        if mask is not None:
            self.mask = torch.from_numpy(mask).bool().to(self.data_device)
        else:
            self.mask = None
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]
        if optical_image is not None:
            self.optical_gt = torch.from_numpy(optical_image).to(self.data_device)
        else:
            self.optical_gt = None


def loadCam(args, cam_info):

    if cam_info.semantic2d is not None:
        semantic2d = torch.from_numpy(cam_info.semantic2d).long()[None, ...]
    else:
        semantic2d = None

    optical_image = cam_info.optical_image
    mask = cam_info.mask
    depth = cam_info.depth

    gt_image = cam_info.image[..., :3] / 255.

    return Camera(K=cam_info.K, c2w=cam_info.c2w, width=cam_info.width, height=cam_info.height,
                  image=gt_image, image_name=cam_info.image_name, data_device=args.model.data_device,
                  semantic2d=semantic2d, depth=depth, mask=mask,
                  timestamp=cam_info.timestamp, optical_image=optical_image, dynamics=cam_info.dynamics)

def cameraList_from_camInfos(cam_infos, args):
    camera_list = []

    for c in cam_infos:
        camera_list.append(loadCam(args, c))

    return camera_list
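A minimal construction sketch for the `Camera` class above, with synthetic placeholder inputs shaped the way `loadCam` supplies them; a CUDA device is required, since `K` and `c2w` are moved to the GPU unconditionally:

import numpy as np
from scene.cameras import Camera

# Synthetic placeholders: a 640x480 black frame with a simple pinhole K.
H, W = 480, 640
K = np.array([[500.0, 0.0, W / 2], [0.0, 500.0, H / 2], [0.0, 0.0, 1.0]])
c2w = np.eye(4)
image = np.zeros((H, W, 3), dtype=np.float32)  # float RGB in [0, 1], as in loadCam

cam = Camera(width=W, height=H, image=image, K=K, c2w=c2w,
             image_name="cam_0_000000", timestamp=0)
print(cam.image_height, cam.image_width)  # 480 640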
code/scene/dataset_readers.py ADDED
@@ -0,0 +1,212 @@
import os
from typing import NamedTuple
import numpy as np
import json
from plyfile import PlyData, PlyElement
from utils.sh_utils import SH2RGB
from scene.gaussian_model import BasicPointCloud
import torch.nn.functional as F
from imageio.v2 import imread
import torch


class CameraInfo(NamedTuple):
    K: np.array
    c2w: np.array
    image: np.array
    image_path: str
    image_name: str
    width: int
    height: int
    semantic2d: np.array
    optical_image: np.array
    depth: torch.tensor
    mask: np.array
    timestamp: int
    dynamics: dict

class SceneInfo(NamedTuple):
    point_cloud: BasicPointCloud
    train_cameras: list
    test_cameras: list
    nerf_normalization: dict
    ply_path: str
    verts: dict

def getNerfppNorm(cam_info, data_type):
    def get_center_and_diag(cam_centers):
        cam_centers = np.hstack(cam_centers)
        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
        center = avg_cam_center
        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
        diagonal = np.max(dist)
        return center.flatten(), diagonal

    cam_centers = []
    for cam in cam_info:
        cam_centers.append(cam.c2w[:3, 3:4])  # cam_centers in world coordinate

    radius = 10

    return {'radius': radius}

def fetchPly(path):
    plydata = PlyData.read(path)
    vertices = plydata['vertex']
    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
    if 'red' in vertices:
        colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
    else:
        print('Creating random colors')
        shs = np.ones((positions.shape[0], 3)) * 0.5
        colors = SH2RGB(shs)
    normals = np.zeros((positions.shape[0], 3))
    return BasicPointCloud(points=positions, colors=colors, normals=normals)

def storePly(path, xyz, rgb):
    # Define the dtype for the structured array
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    normals = np.zeros_like(xyz)

    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb), axis=1)
    elements[:] = list(map(tuple, attributes))

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, 'vertex')
    ply_data = PlyData([vertex_element])
    ply_data.write(path)

def readHUGSIMCameras(path, data_type, ignore_dynamic):
    train_cam_infos, test_cam_infos = [], []
    with open(os.path.join(path, 'meta_data.json')) as json_file:
        meta_data = json.load(json_file)

    verts = {}
    if 'verts' in meta_data and not ignore_dynamic:
        verts_list = meta_data['verts']
        for k, v in verts_list.items():
            verts[k] = np.array(v)

    frames = meta_data['frames']
    for idx, frame in enumerate(frames):
        c2w = np.array(frame['camtoworld'])

        rgb_path = os.path.join(path, frame['rgb_path'].replace('./', ''))

        rgb_split = rgb_path.split('/')
        image_name = '_'.join([rgb_split[-2], rgb_split[-1][:-4]])
        image = imread(rgb_path)

        semantic_2d = None
        semantic_pth = rgb_path.replace("images", "semantics").replace('.png', '.npy').replace('.jpg', '.npy')
        if os.path.exists(semantic_pth):
            semantic_2d = np.load(semantic_pth)
            semantic_2d[(semantic_2d == 14) | (semantic_2d == 15)] = 13

        optical_path = rgb_path.replace("images", "flow").replace('.png', '_flow.npy').replace('.jpg', '_flow.npy')
        if os.path.exists(optical_path):
            optical_image = np.load(optical_path)
        else:
            optical_image = None

        depth_path = rgb_path.replace("images", "depth").replace('.png', '.pt').replace('.jpg', '.pt')
        if os.path.exists(depth_path):
            depth = torch.load(depth_path, weights_only=True)
        else:
            depth = None

        mask = None
        mask_path = rgb_path.replace("images", "masks").replace('.png', '.npy').replace('.jpg', '.npy')
        if os.path.exists(mask_path):
            mask = np.load(mask_path)

        timestamp = frame.get('timestamp', -1)

        intrinsic = np.array(frame['intrinsics'])

        dynamics = {}
        if 'dynamics' in frame and not ignore_dynamic:
            dynamics_list = frame['dynamics']
            for iid in dynamics_list.keys():
                dynamics[iid] = torch.tensor(dynamics_list[iid]).cuda()

        cam_info = CameraInfo(K=intrinsic, c2w=c2w, image=np.array(image),
                              image_path=rgb_path, image_name=image_name, height=image.shape[0],
                              width=image.shape[1], semantic2d=semantic_2d,
                              optical_image=optical_image, depth=depth, mask=mask, timestamp=timestamp, dynamics=dynamics)

        if data_type == 'kitti360':
            if idx < 20:
                train_cam_infos.append(cam_info)
            elif idx % 20 < 16:
                train_cam_infos.append(cam_info)
            elif idx % 20 >= 16:
                test_cam_infos.append(cam_info)
            else:
                continue

        elif data_type == 'kitti':
            if idx < 10 or idx >= len(frames) - 4:
                train_cam_infos.append(cam_info)
            elif idx % 4 < 2:
                train_cam_infos.append(cam_info)
            elif idx % 4 == 2:
                test_cam_infos.append(cam_info)
            else:
                continue

        elif data_type == "nuscenes":
            if idx % 30 >= 24:
                test_cam_infos.append(cam_info)
            else:
                train_cam_infos.append(cam_info)

        elif data_type == "waymo":
            if idx % 15 >= 12:
                test_cam_infos.append(cam_info)
            else:
                train_cam_infos.append(cam_info)

        elif data_type == "pandaset":
            if idx > 30 and idx % 30 >= 24:
                test_cam_infos.append(cam_info)
            else:
                train_cam_infos.append(cam_info)

        else:
            raise NotImplementedError

    return train_cam_infos, test_cam_infos, verts


def readHUGSIMInfo(path, data_type, ignore_dynamic):
    train_cam_infos, test_cam_infos, verts = readHUGSIMCameras(path, data_type, ignore_dynamic)

    print(f'Loaded {len(train_cam_infos)} train cameras and {len(test_cam_infos)} test cameras')
    nerf_normalization = getNerfppNorm(train_cam_infos, data_type)

    ply_path = os.path.join(path, "points3d.ply")
    if not os.path.exists(ply_path):
        assert False, "Initial 3D points (points3d.ply) are required as input"
    try:
        pcd = fetchPly(ply_path)
    except Exception as e:
        print('Error while loading point cloud:', e)
        exit(0)

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path,
                           verts=verts)
    return scene_info


sceneLoadTypeCallbacks = {
    "HUGSIM": readHUGSIMInfo,
}
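A small structural check for `meta_data.json`, using only the keys that `readHUGSIMCameras` above actually accesses; `verts`, `timestamp`, and per-frame `dynamics` are treated as optional because the reader guards them:

import json
import sys

def check_meta(path):
    # Validate the minimal HUGSIM layout implied by readHUGSIMCameras.
    with open(path) as f:
        meta = json.load(f)
    assert "frames" in meta, "meta_data.json must contain a 'frames' list"
    for i, frame in enumerate(meta["frames"]):
        for key in ("camtoworld", "intrinsics", "rgb_path"):
            assert key in frame, f"frame {i} is missing '{key}'"
    print(f"{len(meta['frames'])} frames look structurally valid")

if __name__ == "__main__":
    check_meta(sys.argv[1])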
code/scene/gaussian_model.py ADDED
@@ -0,0 +1,636 @@
1
+ import torch
2
+ import numpy as np
3
+ from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
4
+ from torch import nn
5
+ import os
6
+ from utils.system_utils import mkdir_p
7
+ from plyfile import PlyData, PlyElement
8
+ from utils.sh_utils import RGB2SH, SH2RGB
9
+ from simple_knn._C import distCUDA2
10
+ from utils.graphics_utils import BasicPointCloud
11
+ from utils.general_utils import strip_symmetric, build_scaling_rotation
12
+ import open3d as o3d
13
+ import tinycudann as tcnn
14
+ from math import sqrt
15
+ from scene.ground_model import GroundModel
16
+ from io import BytesIO
17
+
18
+
19
+ class GaussianModel:
20
+
21
+ def setup_functions(self):
22
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
23
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
24
+ actual_covariance = L @ L.transpose(1, 2)
25
+ symm = strip_symmetric(actual_covariance)
26
+ return symm
27
+
28
+ self.scaling_activation = torch.exp
29
+ self.scaling_inverse_activation = torch.log
30
+
31
+ self.covariance_activation = build_covariance_from_scaling_rotation
32
+
33
+ self.opacity_activation = torch.sigmoid
34
+ self.inverse_opacity_activation = torch.logit
35
+
36
+ self.rotation_activation = torch.nn.functional.normalize
37
+
38
+
39
+ def __init__(self, sh_degree : int, feat_mutable=True, affine=False, ground_args=None):
40
+ self.active_sh_degree = 0
41
+ self.max_sh_degree = sh_degree
42
+ self._xyz = torch.empty(0)
43
+ self._features_dc = torch.empty(0)
44
+ self._features_rest = torch.empty(0)
45
+ self._feats3D = torch.empty(0)
46
+ self._scaling = torch.empty(0)
47
+ self._rotation = torch.empty(0)
48
+ self._opacity = torch.empty(0)
49
+ self.max_radii2D = torch.empty(0)
50
+ self.xyz_gradient_accum = torch.empty(0)
51
+ self.denom = torch.empty(0)
52
+ self.optimizer = None
53
+ self.percent_dense = 0
54
+ self.spatial_lr_scale = 0
55
+ self.feat_mutable = feat_mutable
56
+ self.setup_functions()
57
+
58
+ self.pos_enc = tcnn.Encoding(
59
+ n_input_dims=3,
60
+ encoding_config={"otype": "Frequency", "n_frequencies": 2},
61
+ )
62
+ self.dir_enc = tcnn.Encoding(
63
+ n_input_dims=3,
64
+ encoding_config={
65
+ "otype": "SphericalHarmonics",
66
+ "degree": 3,
67
+ },
68
+ )
69
+
70
+ self.affine = affine
71
+ if affine:
72
+ self.appearance_model = tcnn.Network(
73
+ n_input_dims=self.pos_enc.n_output_dims + self.dir_enc.n_output_dims,
74
+ n_output_dims=12,
75
+ network_config={
76
+ "otype": "FullyFusedMLP",
77
+ "activation": "ReLU",
78
+ "output_activation": "None",
79
+ "n_neurons": 32,
80
+ "n_hidden_layers": 2,
81
+ }
82
+ )
83
+ else:
84
+ self.appearance_model = None
85
+
86
+ if ground_args:
87
+ self.ground_model = GroundModel(sh_degree, model_args=ground_args, finetune=True)
88
+ else:
89
+ self.ground_model = None
90
+
91
+ def capture(self):
92
+ if self.ground_model is not None:
93
+ ground_model_params = self.ground_model.capture()
94
+ else:
95
+ ground_model_params = None
96
+ return (
97
+ self.active_sh_degree,
98
+ self._xyz,
99
+ self._features_dc,
100
+ self._features_rest,
101
+ self._feats3D,
102
+ self._scaling,
103
+ self._rotation,
104
+ self._opacity,
105
+ self.spatial_lr_scale,
106
+ self.appearance_model.state_dict(),
107
+ ground_model_params,
108
+ )
109
+
110
+ def restore(self, model_args, training_args):
111
+ (self.active_sh_degree,
112
+ self._xyz,
113
+ self._features_dc,
114
+ self._features_rest,
115
+ self._feats3D,
116
+ self._scaling,
117
+ self._rotation,
118
+ self._opacity,
119
+ self.spatial_lr_scale,
120
+ appearance_state_dict,
121
+ ground_model_params,
122
+ ) = model_args
123
+ self.appearance_model.load_state_dict(appearance_state_dict, strict=False)
124
+ if training_args is not None:
125
+ self.training_setup(training_args)
126
+ if ground_model_params is not None:
127
+ self.ground_model = GroundModel(self.max_sh_degree, model_args=ground_model_params)
128
+
129
+ @property
130
+ def get_scaling(self):
131
+ return self.scaling_activation(self._scaling)
132
+
133
+ @property
134
+ def get_full_scaling(self):
135
+ assert self.ground_model is not None
136
+ return torch.cat([self.scaling_activation(self._scaling), self.ground_model.get_scaling])
137
+
138
+ @property
139
+ def get_rotation(self):
140
+ return self.rotation_activation(self._rotation)
141
+
142
+ @property
143
+ def get_full_rotation(self):
144
+ assert self.ground_model is not None
145
+ return torch.cat([self.rotation_activation(self._rotation), self.ground_model.get_rotation])
146
+
147
+ @property
148
+ def get_xyz(self):
149
+ return self._xyz
150
+
151
+ @property
152
+ def get_full_xyz(self):
153
+ assert self.ground_model is not None
154
+ return torch.cat([self._xyz, self.ground_model.get_xyz])
155
+
156
+ @property
157
+ def get_features(self):
158
+ features_dc = self._features_dc
159
+ features_rest = self._features_rest
160
+ return torch.cat((features_dc, features_rest), dim=1)
161
+
162
+ @property
163
+ def get_full_features(self):
164
+ assert self.ground_model is not None
165
+ sh = torch.cat((self._features_dc, self._features_rest), dim=1)
166
+ return torch.cat([sh, self.ground_model.get_features])
167
+
168
+ @property
169
+ def get_3D_features(self):
170
+ return torch.softmax(self._feats3D, dim=-1)
171
+
172
+ @property
173
+ def get_full_3D_features(self):
174
+ assert self.ground_model is not None
175
+ return torch.cat([torch.softmax(self._feats3D, dim=-1), self.ground_model.get_3D_features])
176
+
177
+ @property
178
+ def get_opacity(self):
179
+ return self.opacity_activation(self._opacity)
180
+
181
+ @property
182
+ def get_full_opacity(self):
183
+ assert self.ground_model is not None
184
+ return torch.cat([self.opacity_activation(self._opacity), self.ground_model.get_opacity])
185
+
186
+ # def get_covariance(self, scaling_modifier = 1):
187
+ # return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
188
+
189
+ def oneupSHdegree(self):
190
+ if self.active_sh_degree < self.max_sh_degree:
191
+ self.active_sh_degree += 1
192
+
193
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
194
+ # self.spatial_lr_scale = 1
195
+ self.spatial_lr_scale = spatial_lr_scale
196
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
197
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
198
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
199
+ features[:, :3, 0 ] = fused_color
200
+ features[:, 3:, 1:] = 0.0
201
+
202
+ if self.feat_mutable:
203
+ feats3D = torch.rand(fused_color.shape[0], 20).float().cuda()
204
+ self._feats3D = nn.Parameter(feats3D.requires_grad_(True))
205
+ else:
206
+ feats3D = torch.zeros(fused_color.shape[0], 20).float().cuda()
207
+ feats3D[:, 13] = 1
208
+ self._feats3D = feats3D
209
+
210
+ print("Number of points at initialization : ", fused_point_cloud.shape[0])
211
+
212
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
213
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
214
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
215
+ rots[:, 0] = 1
216
+
217
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
218
+
219
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
220
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
221
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
222
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
223
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
224
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
225
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
226
+
227
+ def training_setup(self, training_args):
228
+ self.percent_dense = training_args.percent_dense
229
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
230
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
231
+
232
+ # self.spatial_lr_scale /= 3
233
+
234
+ l = [
235
+ {'params': [self._xyz], 'lr': training_args.position_lr_init*self.spatial_lr_scale, "name": "xyz"},
236
+ {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
237
+ {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
238
+ {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
239
+ {'params': [self._scaling], 'lr': training_args.scaling_lr*self.spatial_lr_scale, "name": "scaling"},
240
+ {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
241
+ ]
242
+
243
+ if self.affine:
244
+ l.append({'params': [*self.appearance_model.parameters()], 'lr': 1e-3, "name": "appearance_model"})
245
+
246
+ if self.feat_mutable:
247
+ l.append({'params': [self._feats3D], 'lr': 1e-2, "name": "feats3D"})
248
+
249
+ if self.ground_model is not None:
250
+ self.ground_optimizer = self.ground_model.optimizer
251
+ else:
252
+ self.ground_optimizer = None
253
+
254
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
255
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
256
+ lr_final=training_args.position_lr_final*self.spatial_lr_scale,
257
+ lr_delay_mult=training_args.position_lr_delay_mult,
258
+ max_steps=training_args.position_lr_max_steps)
259
+
260
+ def update_learning_rate(self, iteration):
261
+ ''' Learning rate scheduling per step '''
262
+ for param_group in self.optimizer.param_groups:
263
+ if param_group["name"] == "xyz":
264
+ lr = self.xyz_scheduler_args(iteration)
265
+ param_group['lr'] = lr
266
+ return lr
267
+
268
+ def construct_list_of_attributes(self):
269
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
270
+ # All channels except the 3 DC
271
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
272
+ l.append('f_dc_{}'.format(i))
273
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
274
+ l.append('f_rest_{}'.format(i))
275
+ for i in range(self._feats3D.shape[1]):
276
+ l.append('semantic_{}'.format(i))
277
+ l.append('opacity')
278
+ for i in range(self._scaling.shape[1]):
279
+ l.append('scale_{}'.format(i))
280
+ for i in range(self._rotation.shape[1]):
281
+ l.append('rot_{}'.format(i))
282
+ return l
283
+
284
+ def save_ply(self, path=None):
285
+ mkdir_p(os.path.dirname(path))
286
+
287
+ if self.ground_model is not None:
288
+ xyz = self.get_full_xyz.detach().cpu().numpy()
289
+ normals = np.zeros_like(xyz)
290
+ f_dc = torch.cat([self._features_dc, self.ground_model._features_dc]).detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
291
+ f_rest = torch.cat([self._features_rest, self.ground_model._features_rest]).detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
292
+ feats3D = torch.cat([self._feats3D, self.ground_model._feats3D]).detach().cpu().numpy()
293
+ opacities = torch.cat([self._opacity, self.ground_model._opacity]).detach().cpu().numpy()
294
+ scale = self.scaling_inverse_activation(self.get_full_scaling).detach().cpu().numpy()
295
+ rotation = torch.cat([self._rotation, self.ground_model._rotation]).detach().cpu().numpy()
296
+ else:
297
+ xyz = self.get_xyz.detach().cpu().numpy()
298
+ normals = np.zeros_like(xyz)
299
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
300
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
301
+ feats3D = self._feats3D.detach().cpu().numpy()
302
+ opacities = self._opacity.detach().cpu().numpy()
303
+ scale = self.scaling_inverse_activation(self.get_scaling).detach().cpu().numpy()
304
+ rotation = self._rotation.detach().cpu().numpy()
305
+
306
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
307
+
308
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
309
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, feats3D, opacities, scale, rotation), axis=1)
310
+ elements[:] = list(map(tuple, attributes))
311
+ el = PlyElement.describe(elements, 'vertex')
312
+ plydata = PlyData([el])
313
+ if path is not None:
314
+ plydata.write(path)
315
+ return plydata
316
+
317
+ def save_splat(self, ply_path, splat_path):
318
+ plydata = self.save_ply(ply_path)
319
+ vert = plydata["vertex"]
320
+ sorted_indices = np.argsort(
321
+ -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
322
+ / (1 + np.exp(-vert["opacity"]))
323
+ )
324
+ buffer = BytesIO()
325
+ for idx in sorted_indices:
326
+ v = plydata["vertex"][idx]
327
+ position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
328
+ scales = np.exp(
329
+ np.array(
330
+ [v["scale_0"], v["scale_1"], v["scale_2"]],
331
+ dtype=np.float32,
332
+ )
333
+ )
334
+ rot = np.array(
335
+ [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
336
+ dtype=np.float32,
337
+ )
338
+ SH_C0 = 0.28209479177387814
339
+ color = np.array(
340
+ [
341
+ 0.5 + SH_C0 * v["f_dc_0"],
342
+ 0.5 + SH_C0 * v["f_dc_1"],
343
+ 0.5 + SH_C0 * v["f_dc_2"],
344
+ 1 / (1 + np.exp(-v["opacity"])),
345
+ ]
346
+ )
347
+ buffer.write(position.tobytes())
348
+ buffer.write(scales.tobytes())
349
+ buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
350
+ buffer.write(
351
+ ((rot / np.linalg.norm(rot)) * 128 + 128)
352
+ .clip(0, 255)
353
+ .astype(np.uint8)
354
+ .tobytes()
355
+ )
356
+ with open(splat_path, "wb") as f:
357
+ f.write(buffer.getvalue())
358
+
359
+ def save_semantic_pcd(self, path):
360
+ color_dict = {
361
+ 0: np.array([128, 64, 128]), # Road
362
+ 1: np.array([244, 35, 232]), # Sidewalk
363
+ 2: np.array([70, 70, 70]), # Building
364
+ 3: np.array([102, 102, 156]), # Wall
365
+ 4: np.array([190, 153, 153]), # Fence
366
+ 5: np.array([153, 153, 153]), # Pole
367
+ 6: np.array([250, 170, 30]), # Traffic Light
368
+ 7: np.array([220, 220, 0]), # Traffic Sign
369
+ 8: np.array([107, 142, 35]), # Vegetation
370
+ 9: np.array([152, 251, 152]), # Terrain
371
+ 10: np.array([0, 0, 0]), # Black (trainId 10)
372
+ 11: np.array([70, 130, 180]), # Sky
373
+ 12: np.array([220, 20, 60]), # Person
374
+ 13: np.array([255, 0, 0]), # Rider
375
+ 14: np.array([0, 0, 142]), # Car
376
+ 15: np.array([0, 0, 70]), # Truck
377
+ 16: np.array([0, 60, 100]), # Bus
378
+ 17: np.array([0, 80, 100]), # Train
379
+ 18: np.array([0, 0, 230]), # Motorcycle
380
+ 19: np.array([119, 11, 32]) # Bicycle
381
+ }
382
+ semantic_idx = torch.argmax(self.get_full_3D_features, dim=-1, keepdim=True)
383
+ opacities = self.get_full_opacity[:, 0]
384
+ mask = ((semantic_idx != 10)[:, 0]) & ((semantic_idx != 8)[:, 0]) & (opacities > 0.2)
385
+
386
+ semantic_idx = semantic_idx[mask]
387
+ semantic_rgb = torch.zeros_like(semantic_idx).repeat(1, 3)
388
+ for idx in range(20):
389
+ rgb = torch.from_numpy(color_dict[idx]).to(semantic_rgb.device)[None, :]
390
+ semantic_rgb[(semantic_idx == idx)[:, 0], :] = rgb
391
+ semantic_rgb = semantic_rgb.float() / 255.0
392
+ pcd_xyz = self.get_full_xyz[mask]
393
+ smt_pcd = o3d.geometry.PointCloud()
394
+ smt_pcd.points = o3d.utility.Vector3dVector(pcd_xyz.detach().cpu().numpy())
395
+ smt_pcd.colors = o3d.utility.Vector3dVector(semantic_rgb.detach().cpu().numpy())
396
+ o3d.io.write_point_cloud(path, smt_pcd)
397
+
398
+ def save_vis_ply(self, path):
399
+ mkdir_p(os.path.dirname(path))
400
+ xyz = self.get_xyz.detach().cpu().numpy()
401
+ if self.ground_model:
402
+ xyz = np.concatenate([xyz, self.ground_model.get_xyz.detach().cpu().numpy()])
403
+ pcd = o3d.geometry.PointCloud()
404
+ pcd.points = o3d.utility.Vector3dVector(xyz)
405
+ colors = SH2RGB(self._features_dc[:, 0, :].detach().cpu().numpy()).clip(0, 1)
406
+ if self.ground_model:
407
+ ground_colors = SH2RGB(self.ground_model._features_dc[:, 0, :].detach().cpu().numpy()).clip(0, 1)
408
+ colors = np.concatenate([colors, ground_colors])
409
+ pcd.colors = o3d.utility.Vector3dVector(colors)
410
+ o3d.io.write_point_cloud(path, pcd)
411
+
412
+ def reset_opacity(self):
413
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
414
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
415
+ self._opacity = optimizable_tensors["opacity"]
416
+
417
+ def load_ply(self, path):
418
+ plydata = PlyData.read(path)
419
+
420
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
421
+ np.asarray(plydata.elements[0]["y"]),
422
+ np.asarray(plydata.elements[0]["z"])), axis=1)
423
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
424
+
425
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
426
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
427
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
428
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
429
+
430
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
431
+ assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
432
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
433
+ for idx, attr_name in enumerate(extra_f_names):
434
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
435
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
436
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
437
+
438
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
439
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
440
+ for idx, attr_name in enumerate(scale_names):
441
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
442
+
443
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
444
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
445
+ for idx, attr_name in enumerate(rot_names):
446
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
447
+
448
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
449
+ self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
450
+ self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
451
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
452
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
453
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
454
+
455
+ self.active_sh_degree = self.max_sh_degree
456
+
457
+ def replace_tensor_to_optimizer(self, tensor, name):
458
+ optimizable_tensors = {}
459
+ for group in self.optimizer.param_groups:
460
+ if group["name"] == name:
461
+ stored_state = self.optimizer.state.get(group['params'][0], None)
462
+ if stored_state is not None:
463
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
464
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
465
+ del self.optimizer.state[group['params'][0]]
466
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
467
+ self.optimizer.state[group['params'][0]] = stored_state
468
+ optimizable_tensors[group["name"]] = group["params"][0]
469
+ else:
470
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
471
+ optimizable_tensors[group["name"]] = group["params"][0]
472
+ return optimizable_tensors
473
+
474
+ def _prune_optimizer(self, mask):
475
+ optimizable_tensors = {}
476
+ for group in self.optimizer.param_groups:
477
+ if group['name'] == 'appearance_model':
478
+ continue
479
+ stored_state = self.optimizer.state.get(group['params'][0], None)
480
+ if stored_state is not None:
481
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
482
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
483
+
484
+ del self.optimizer.state[group['params'][0]]
485
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
486
+ self.optimizer.state[group['params'][0]] = stored_state
487
+
488
+ optimizable_tensors[group["name"]] = group["params"][0]
489
+ else:
490
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
491
+ optimizable_tensors[group["name"]] = group["params"][0]
492
+ return optimizable_tensors
493
+
494
+ def prune_points(self, mask):
495
+ valid_points_mask = ~mask
496
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
497
+
498
+ self._xyz = optimizable_tensors["xyz"]
499
+ self._features_dc = optimizable_tensors["f_dc"]
500
+ self._features_rest = optimizable_tensors["f_rest"]
501
+ if self.feat_mutable:
502
+ self._feats3D = optimizable_tensors["feats3D"]
503
+ else:
504
+ self._feats3D = self._feats3D[1, :].repeat((self._xyz.shape[0], 1))
505
+ self._opacity = optimizable_tensors["opacity"]
506
+ self._scaling = optimizable_tensors["scaling"]
507
+ self._rotation = optimizable_tensors["rotation"]
508
+
509
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
510
+
511
+ self.denom = self.denom[valid_points_mask]
512
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
513
+
514
+ def cat_tensors_to_optimizer(self, tensors_dict):
515
+ optimizable_tensors = {}
516
+ for group in self.optimizer.param_groups:
517
+ if group['name'] not in tensors_dict:
518
+ continue
519
+ assert len(group["params"]) == 1
520
+ extension_tensor = tensors_dict[group["name"]]
521
+ stored_state = self.optimizer.state.get(group["params"][0], None)
522
+ if stored_state is not None:
523
+
524
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
525
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
526
+
527
+ del self.optimizer.state[group["params"][0]]
528
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
529
+ self.optimizer.state[group["params"][0]] = stored_state
530
+
531
+ optimizable_tensors[group["name"]] = group["params"][0]
532
+ else:
533
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
534
+ optimizable_tensors[group["name"]] = group["params"][0]
535
+
536
+ return optimizable_tensors
537
+
538
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacities, new_scaling, new_rotation):
539
+ d = {"xyz": new_xyz,
540
+ "f_dc": new_features_dc,
541
+ "f_rest": new_features_rest,
542
+ "feats3D": new_feats3D,
543
+ "opacity": new_opacities,
544
+ "scaling" : new_scaling,
545
+ "rotation" : new_rotation}
546
+
547
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
548
+ self._xyz = optimizable_tensors["xyz"]
549
+ self._features_dc = optimizable_tensors["f_dc"]
550
+ if self.feat_mutable:
551
+ self._feats3D = optimizable_tensors["feats3D"]
552
+ else:
553
+ self._feats3D = self._feats3D[1, :].repeat((self._xyz.shape[0], 1))
554
+ self._features_rest = optimizable_tensors["f_rest"]
555
+ self._opacity = optimizable_tensors["opacity"]
556
+ self._scaling = optimizable_tensors["scaling"]
557
+ self._rotation = optimizable_tensors["rotation"]
558
+
559
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
560
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
561
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
562
+
563
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
564
+ n_init_points = self.get_xyz.shape[0]
565
+ # Extract points that satisfy the gradient condition
566
+ padded_grad = torch.zeros((n_init_points), device="cuda")
567
+ padded_grad[:grads.shape[0]] = grads.squeeze()
568
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
569
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
570
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
571
+
572
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
573
+ means =torch.zeros((stds.size(0), 3),device="cuda")
574
+ samples = torch.normal(mean=means, std=stds)
575
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
576
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
577
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
578
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
579
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
580
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
581
+ new_feats3D = self._feats3D[selected_pts_mask].repeat(N,1)
582
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
583
+
584
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacity, new_scaling, new_rotation)
585
+
586
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
587
+ self.prune_points(prune_filter)
588
+
589
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
590
+ # Extract points that satisfy the gradient condition
591
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
592
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
593
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
594
+
595
+ new_xyz = self._xyz[selected_pts_mask]
596
+ new_features_dc = self._features_dc[selected_pts_mask]
597
+ new_features_rest = self._features_rest[selected_pts_mask]
598
+ new_feats3D = self._feats3D[selected_pts_mask]
599
+ new_opacities = self._opacity[selected_pts_mask]
600
+ new_scaling = self._scaling[selected_pts_mask]
601
+ new_rotation = self._rotation[selected_pts_mask]
602
+
603
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacities, new_scaling, new_rotation)
604
+
605
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, cam_pos=None):
606
+ grads = self.xyz_gradient_accum / self.denom
607
+ grads[grads.isnan()] = 0.0
608
+
609
+ self.densify_and_clone(grads, max_grad, extent)
610
+ self.densify_and_split(grads, max_grad, extent)
611
+
612
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
613
+ if max_screen_size:
614
+ big_points_vs = self.max_radii2D > max_screen_size
615
+ if cam_pos is not None:
616
+ # points_cam_dist = torch.abs(self.get_xyz[:, None, :] - cam_pos[None, ...])
617
+ # points_cam_nearest_idx = torch.argmin(torch.norm(points_cam_dist, dim=-1), dim=1)
618
+ # points_cam_dist = points_cam_dist[torch.arange(points_cam_dist.shape[0]), points_cam_nearest_idx, :]
619
+ # near_mask1 = (points_cam_dist[:, 1] < 5) & (points_cam_dist[:, 0] < 10) & (points_cam_dist[:, 2] < 10)
620
+ # big_points_ws1 = near_mask1 & (self.get_scaling.max(dim=1).values > 1.0)
621
+ # near_mask2 = (points_cam_dist[:, 1] < 10) & (points_cam_dist[:, 0] < 20) & (points_cam_dist[:, 2] < 20)
622
+ # big_points_ws2 = near_mask2 & (self.get_scaling.max(dim=1).values > 3.0)
623
+ # big_points_ws = (self.get_scaling.max(dim=1).values > 10.0) | big_points_ws1 | big_points_ws2
624
+ big_points_ws = self.get_scaling.max(dim=1).values > 10
625
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
626
+ else:
627
+ big_points_ws = self.get_scaling.max(dim=1).values > 5
628
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
629
+ self.prune_points(prune_mask)
630
+
631
+ torch.cuda.empty_cache()
632
+
633
+ def add_densification_stats_grad(self, tensor_grad, update_filter):
634
+ self.xyz_gradient_accum[update_filter] += torch.norm(tensor_grad[update_filter,:2], dim=-1, keepdim=True)
635
+ self.denom[update_filter] += 1
636
+
code/scene/ground_model.py ADDED
@@ -0,0 +1,360 @@
1
+ import torch
2
+ import numpy as np
3
+ from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
4
+ from torch import nn
5
+ import os
6
+ from utils.system_utils import mkdir_p
7
+ from plyfile import PlyData, PlyElement
8
+ from utils.sh_utils import RGB2SH, SH2RGB
9
+ from simple_knn._C import distCUDA2
10
+ from utils.graphics_utils import BasicPointCloud
11
+ from utils.general_utils import strip_symmetric, build_scaling_rotation
12
+ import open3d as o3d
13
+ import math
16
+
17
+ class GroundModel:
18
+
19
+ def setup_functions(self):
20
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
21
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
22
+ actual_covariance = L @ L.transpose(1, 2)
23
+ symm = strip_symmetric(actual_covariance)
24
+ return symm
25
+
26
+ self.scaling_activation = torch.exp
27
+ self.scaling_inverse_activation = torch.log
28
+
29
+ self.covariance_activation = build_covariance_from_scaling_rotation
30
+
31
+ self.opacity_activation = torch.sigmoid
32
+ self.inverse_opacity_activation = torch.logit
33
+
34
+ self.rotation_activation = torch.nn.functional.normalize
35
+
36
+
37
+ def __init__(self, sh_degree: int, ground_pcd: BasicPointCloud=None, model_args=None, finetune=False):
38
+ assert not ((ground_pcd is None) and (model_args is None)), "Need at least one way of initialization"
39
+ self.active_sh_degree = 0
40
+ self.max_sh_degree = sh_degree
41
+
42
+ self.scale = 0.1
43
+
44
+ if ground_pcd is not None:
45
+ self._xyz = nn.Parameter(torch.from_numpy(ground_pcd.points).float().cuda())
46
+ fused_color = RGB2SH(torch.tensor(np.asarray(ground_pcd.colors)).float().cuda())
47
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
48
+ features[:, :3, 0 ] = fused_color
49
+ features[:, 3:, 1:] = 0.0
50
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
51
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
52
+
53
+ self._feats3D = torch.zeros((self._xyz.shape[0], 20)).cuda()
54
+ self._feats3D[:, 1] = 1
55
+ self._feats3D = nn.Parameter(self._feats3D)
56
+ self._rotation = torch.zeros((self._xyz.shape[0], 4)).cuda()
57
+ self._rotation[:, 0] = 1
58
+ self._opacity = inverse_sigmoid(torch.ones((self._xyz.shape[0], 1)).cuda() * 0.99)
59
+ self._scaling = nn.Parameter(torch.ones((self._xyz.shape[0], 2)).float().cuda() * math.log(self.scale))
60
+
61
+ self.max_radii2D = torch.zeros((self._xyz.shape[0]), device="cuda")
62
+ self.percent_dense = 0.01
63
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
64
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
65
+ else:
66
+ self.restore(model_args)
67
+
68
+ if finetune:
69
+ self.param_groups = [
70
+ {'params': [self._features_dc], 'lr': 2.5e-3, "name": "f_dc"},
71
+ {'params': [self._features_rest], 'lr': 2.5e-3 / 20.0, "name": "f_rest"},
72
+ {'params': [self._feats3D], 'lr': 1e-3, "name": "feats3D"},
73
+ ]
74
+ else:
75
+ self.param_groups = [
76
+ {'params': [self._xyz], 'lr': 1.6e-4, "name": "xyz"},
77
+ {'params': [self._features_dc], 'lr': 2.5e-3, "name": "f_dc"},
78
+ {'params': [self._features_rest], 'lr': 2.5e-3 / 20.0, "name": "f_rest"},
79
+ {'params': [self._feats3D], 'lr': 1e-2, "name": "feats3D"},
80
+ {'params': [self._opacity], 'lr': 0.05, "name": "opacity"},
81
+ {'params': [self._scaling], 'lr': 1e-3, "name": "scaling"},
82
+ ]
83
+ self.optimizer = torch.optim.Adam(self.param_groups, lr=0.0, eps=1e-15)
84
+ self.setup_functions()
85
+
86
+ def capture(self):
87
+ return (
88
+ self.active_sh_degree,
89
+ self._xyz,
90
+ # self._y,
91
+ # self._z,
92
+ self._features_dc,
93
+ self._features_rest,
94
+ self._feats3D,
95
+ self._scaling,
96
+ self._rotation,
97
+ self._opacity,
98
+ )
99
+
100
+ def restore(self, model_args):
101
+ (self.active_sh_degree,
102
+ self._xyz,
103
+ # self._y,
104
+ # self._z,
105
+ self._features_dc,
106
+ self._features_rest,
107
+ self._feats3D,
108
+ self._scaling,
109
+ self._rotation,
110
+ self._opacity) = model_args
111
+
112
+ @property
113
+ def get_scaling(self):
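+ # Only the (x, z) scales are learnable for the ground; y is pinned to ~1e-3
+ # below, so every ground splat stays a thin disc lying on the road plane.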
114
+ scale_y = torch.ones_like(self._xyz[:, 0]) * math.log(0.001)
115
+ scaling = torch.stack((self._scaling[:, 0], scale_y, self._scaling[:, 1]), dim=1).cuda()
116
+ # scaling = torch.stack((self._scaling, scale_y, self._scaling), dim=1).cuda()
117
+ return self.scaling_activation(scaling)
118
+
119
+ @property
120
+ def get_rotation(self):
121
+ return self.rotation_activation(self._rotation)
122
+
123
+ @property
124
+ def get_xyz(self):
125
+ return self._xyz
126
+
127
+ @property
128
+ def get_features(self):
129
+ features_dc = self._features_dc
130
+ features_rest = self._features_rest
131
+ return torch.cat((features_dc, features_rest), dim=1)
132
+
133
+ @property
134
+ def get_3D_features(self):
135
+ return torch.softmax(self._feats3D, dim=-1)
136
+
137
+ @property
138
+ def get_opacity(self):
139
+ return self.opacity_activation(self._opacity)
140
+
141
+ def get_covariance(self, scaling_modifier = 1):
142
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
143
+
144
+ def oneupSHdegree(self):
145
+ if self.active_sh_degree < self.max_sh_degree:
146
+ self.active_sh_degree += 1
147
+
148
+ def construct_list_of_attributes(self):
149
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
150
+ # All channels except the 3 DC
151
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
152
+ l.append('f_dc_{}'.format(i))
153
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
154
+ l.append('f_rest_{}'.format(i))
155
+ for i in range(self._feats3D.shape[1]):
156
+ l.append('semantic_{}'.format(i))
157
+ l.append('opacity')
158
+ for i in range(self._scaling.shape[1]):
159
+ l.append('scale_{}'.format(i))
160
+ for i in range(self._rotation.shape[1]):
161
+ l.append('rot_{}'.format(i))
162
+ return l
163
+
164
+ def save_ply(self, path):
165
+ mkdir_p(os.path.dirname(path))
166
+
167
+ xyz = self.get_xyz.detach().cpu().numpy()
168
+ normals = np.zeros_like(xyz)
169
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
170
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
171
+ feats3D = self._feats3D.detach().cpu().numpy()
172
+ opacities = self._opacity.detach().cpu().numpy()
173
+ scale = self._scaling.detach().cpu().numpy()
174
+ rotation = self._rotation.detach().cpu().numpy()
175
+
176
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
177
+
178
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
179
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, feats3D, opacities, scale, rotation), axis=1)
180
+ elements[:] = list(map(tuple, attributes))
181
+ el = PlyElement.describe(elements, 'vertex')
182
+ PlyData([el]).write(path)
183
+
184
+ def save_vis_ply(self, path):
185
+ mkdir_p(os.path.dirname(path))
186
+ xyz = self.get_xyz.detach().cpu().numpy()
187
+ pcd = o3d.geometry.PointCloud()
188
+ pcd.points = o3d.utility.Vector3dVector(xyz)
189
+ colors = SH2RGB(self._features_dc[:, 0, :].detach().cpu().numpy()).clip(0, 1)
190
+ pcd.colors = o3d.utility.Vector3dVector(colors)
191
+ o3d.io.write_point_cloud(path, pcd)
192
+
193
+ def reset_opacity(self):
194
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
195
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
196
+ self._opacity = optimizable_tensors["opacity"]
197
+
198
+ def replace_tensor_to_optimizer(self, tensor, name):
199
+ optimizable_tensors = {}
200
+ for group in self.optimizer.param_groups:
201
+ if group["name"] == name:
202
+ stored_state = self.optimizer.state.get(group['params'][0], None)
203
+ if stored_state is not None:
204
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
205
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
206
+ del self.optimizer.state[group['params'][0]]
207
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
208
+ self.optimizer.state[group['params'][0]] = stored_state
209
+ optimizable_tensors[group["name"]] = group["params"][0]
210
+ else:
211
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
212
+ optimizable_tensors[group["name"]] = group["params"][0]
213
+ return optimizable_tensors
214
+
215
+ def _prune_optimizer(self, mask):
216
+ optimizable_tensors = {}
217
+ for group in self.optimizer.param_groups:
218
+ if group['name'] == 'appearance_model':
219
+ continue
220
+ stored_state = self.optimizer.state.get(group['params'][0], None)
221
+ if stored_state is not None:
222
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
223
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
224
+
225
+ del self.optimizer.state[group['params'][0]]
226
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
227
+ self.optimizer.state[group['params'][0]] = stored_state
228
+
229
+ optimizable_tensors[group["name"]] = group["params"][0]
230
+ else:
231
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
232
+ optimizable_tensors[group["name"]] = group["params"][0]
233
+ return optimizable_tensors
234
+
235
+ def prune_points(self, mask):
236
+ valid_points_mask = ~mask
237
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
238
+
239
+ self._xyz = optimizable_tensors["xyz"]
240
+ self._features_dc = optimizable_tensors["f_dc"]
241
+ self._features_rest = optimizable_tensors["f_rest"]
242
+ self._feats3D = optimizable_tensors["feats3D"]
243
+ self._opacity = optimizable_tensors["opacity"]
244
+ self._scaling = optimizable_tensors["scaling"]
245
+ self._rotation = self._rotation[0, :].repeat((self._xyz.shape[0], 1))
246
+
247
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
248
+
249
+ self.denom = self.denom[valid_points_mask]
250
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
251
+
252
+ def cat_tensors_to_optimizer(self, tensors_dict):
253
+ optimizable_tensors = {}
254
+ for group in self.optimizer.param_groups:
255
+ if group['name'] not in tensors_dict:
256
+ continue
257
+ assert len(group["params"]) == 1
258
+ extension_tensor = tensors_dict[group["name"]]
259
+ stored_state = self.optimizer.state.get(group["params"][0], None)
260
+ if stored_state is not None:
261
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
262
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
263
+
264
+ del self.optimizer.state[group["params"][0]]
265
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
266
+ self.optimizer.state[group["params"][0]] = stored_state
267
+
268
+ optimizable_tensors[group["name"]] = group["params"][0]
269
+ else:
270
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
271
+ optimizable_tensors[group["name"]] = group["params"][0]
272
+
273
+ return optimizable_tensors
274
+
275
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacities, new_scaling, new_rotation):
276
+ d = {"xyz": new_xyz,
277
+ "f_dc": new_features_dc,
278
+ "f_rest": new_features_rest,
279
+ "feats3D": new_feats3D,
280
+ "opacity": new_opacities,
281
+ "scaling" : new_scaling}
282
+
283
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
284
+ self._xyz = optimizable_tensors["xyz"]
285
+ self._features_dc = optimizable_tensors["f_dc"]
286
+ self._feats3D = optimizable_tensors["feats3D"]
287
+ self._features_rest = optimizable_tensors["f_rest"]
288
+ self._opacity = optimizable_tensors["opacity"]
289
+ self._scaling = optimizable_tensors["scaling"]
290
+ self._rotation = self._rotation[0, :].repeat((self._xyz.shape[0], 1))
291
+
292
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
293
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
294
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
295
+
296
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
297
+ n_init_points = self.get_xyz.shape[0]
298
+ # Extract points that satisfy the gradient condition
299
+ padded_grad = torch.zeros((n_init_points), device="cuda")
300
+ padded_grad[:grads.shape[0]] = grads.squeeze()
301
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
302
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
303
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
304
+
305
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
306
+ means = torch.zeros((stds.size(0), 3), device="cuda")
307
+ samples = torch.normal(mean=means, std=stds)
308
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
309
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
310
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))[:, [0,2]]
311
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
312
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
313
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
314
+ new_feats3D = self._feats3D[selected_pts_mask].repeat(N,1)
315
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
316
+
317
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacity, new_scaling, new_rotation)
318
+
319
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
320
+ self.prune_points(prune_filter)
321
+
322
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
323
+ # Extract points that satisfy the gradient condition
324
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
325
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
326
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
327
+
328
+ new_xyz = self._xyz[selected_pts_mask]
329
+ new_features_dc = self._features_dc[selected_pts_mask]
330
+ new_features_rest = self._features_rest[selected_pts_mask]
331
+ new_feats3D = self._feats3D[selected_pts_mask]
332
+ new_opacities = self._opacity[selected_pts_mask]
333
+ new_scaling = self._scaling[selected_pts_mask]
334
+ new_rotation = self._rotation[selected_pts_mask]
335
+
336
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacities, new_scaling, new_rotation)
337
+
338
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
339
+ grads = self.xyz_gradient_accum / self.denom
340
+ grads[grads.isnan()] = 0.0
341
+
342
+ self.densify_and_clone(grads, max_grad, extent)
343
+ self.densify_and_split(grads, max_grad, extent)
344
+
345
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
346
+ if max_screen_size:
347
+ big_points_vs = self.max_radii2D > max_screen_size
348
+ big_points_ws = self.get_scaling.max(dim=1).values > 1.0
349
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
350
+ self.prune_points(prune_mask)
351
+
352
+ torch.cuda.empty_cache()
353
+
354
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
355
+ self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
356
+ self.denom[update_filter] += 1
357
+
358
+ def add_densification_stats_grad(self, tensor_grad, update_filter):
359
+ self.xyz_gradient_accum[update_filter] += torch.norm(tensor_grad[update_filter,:2], dim=-1, keepdim=True)
360
+ self.denom[update_filter] += 1
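
A minimal construction sketch for GroundModel. The point cloud is synthetic placeholder data, and BasicPointCloud is assumed to be the usual (points, colors, normals) container from utils.graphics_utils:

import numpy as np
from utils.graphics_utils import BasicPointCloud
from scene.ground_model import GroundModel

# synthetic flat ground patch (placeholder data)
xz = np.stack(np.meshgrid(np.linspace(-10, 10, 50),
                          np.linspace(0, 40, 100)), -1).reshape(-1, 2)
pts = np.stack([xz[:, 0], np.zeros(len(xz)), xz[:, 1]], axis=1)  # height y = 0
pcd = BasicPointCloud(points=pts, colors=np.full_like(pts, 0.5),
                      normals=np.zeros_like(pts))

ground = GroundModel(sh_degree=3, ground_pcd=pcd)
print(ground.get_scaling.shape)   # (N, 3); the y scale is pinned near 1e-3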
code/scene/obj_model.py ADDED
@@ -0,0 +1,567 @@
1
+ import torch
2
+ import numpy as np
3
+ from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
4
+ from torch import nn
5
+ import os
6
+ from utils.system_utils import mkdir_p
7
+ from plyfile import PlyData, PlyElement
8
+ from utils.sh_utils import RGB2SH, SH2RGB
9
+ from simple_knn._C import distCUDA2
10
+ from utils.graphics_utils import BasicPointCloud
11
+ from utils.general_utils import strip_symmetric, build_scaling_rotation
12
+ import open3d as o3d
13
+ import tinycudann as tcnn
14
+ from io import BytesIO
15
+
16
+
17
+ class ObjModel:
18
+
19
+ def setup_functions(self):
20
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
21
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
22
+ actual_covariance = L @ L.transpose(1, 2)
23
+ symm = strip_symmetric(actual_covariance)
24
+ return symm
25
+
26
+ self.scaling_activation = torch.exp
27
+ self.scaling_inverse_activation = torch.log
28
+
29
+ self.covariance_activation = build_covariance_from_scaling_rotation
30
+
31
+ self.opacity_activation = torch.sigmoid
32
+ self.inverse_opacity_activation = torch.logit
33
+
34
+ self.rotation_activation = torch.nn.functional.normalize
35
+
36
+
37
+ def __init__(self, sh_degree : int, feat_mutable=True, affine=False):
38
+ self.active_sh_degree = 0
39
+ self.max_sh_degree = sh_degree
40
+ self._xyz = torch.empty(0)
41
+ self._features_dc = torch.empty(0)
42
+ self._features_rest = torch.empty(0)
43
+ self._feats3D = torch.empty(0)
44
+ self._scaling = torch.empty(0)
45
+ self._rotation = torch.empty(0)
46
+ self._opacity = torch.empty(0)
47
+ self.max_radii2D = torch.empty(0)
48
+ self.xyz_gradient_accum = torch.empty(0)
49
+ self.denom = torch.empty(0)
50
+ self.optimizer = None
51
+ self.percent_dense = 0
52
+ self.spatial_lr_scale = 0
53
+ self.feat_mutable = feat_mutable
54
+ self.setup_functions()
55
+
56
+ self.pos_enc = tcnn.Encoding(
57
+ n_input_dims=3,
58
+ encoding_config={"otype": "Frequency", "n_frequencies": 2},
59
+ )
60
+ self.dir_enc = tcnn.Encoding(
61
+ n_input_dims=3,
62
+ encoding_config={
63
+ "otype": "SphericalHarmonics",
64
+ "degree": 3,
65
+ },
66
+ )
67
+
68
+ self.affine = affine
69
+ if affine:
70
+ self.appearance_model = tcnn.Network(
71
+ n_input_dims=self.pos_enc.n_output_dims + self.dir_enc.n_output_dims,
72
+ n_output_dims=12,
73
+ network_config={
74
+ "otype": "FullyFusedMLP",
75
+ "activation": "ReLU",
76
+ "output_activation": "None",
77
+ "n_neurons": 32,
78
+ "n_hidden_layers": 2,
79
+ }
80
+ )
81
+ else:
82
+ self.appearance_model = None
83
+
84
+ def capture(self):
85
+ return (
86
+ self.active_sh_degree,
87
+ self._xyz,
88
+ self._features_dc,
89
+ self._features_rest,
90
+ self._feats3D,
91
+ self._scaling,
92
+ self._rotation,
93
+ self._opacity,
94
+ self.spatial_lr_scale,
95
+ )
96
+
97
+ def restore(self, model_args, training_args):
98
+ (self.active_sh_degree,
99
+ self._xyz,
100
+ self._features_dc,
101
+ self._features_rest,
102
+ self._feats3D,
103
+ self._scaling,
104
+ self._rotation,
105
+ self._opacity,
106
+ self.spatial_lr_scale,
107
+ ) = model_args
108
+ if training_args is not None:
109
+ self.training_setup(training_args)
110
+
111
+ @property
112
+ def get_scaling(self):
113
+ return self.scaling_activation(self._scaling)
114
+
115
+ @property
116
+ def get_rotation(self):
117
+ return self.rotation_activation(self._rotation)
118
+
119
+ @property
120
+ def get_xyz(self):
121
+ return self._xyz
122
+
123
+ @property
124
+ def get_features(self):
125
+ features_dc = self._features_dc
126
+ features_rest = self._features_rest
127
+ return torch.cat((features_dc, features_rest), dim=1)
128
+
129
+ @property
130
+ def get_3D_features(self):
131
+ return torch.softmax(self._feats3D, dim=-1)
132
+
133
+ @property
134
+ def get_opacity(self):
135
+ return self.opacity_activation(self._opacity)
136
+
137
+ # def get_covariance(self, scaling_modifier = 1):
138
+ # return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
139
+
140
+ def oneupSHdegree(self):
141
+ if self.active_sh_degree < self.max_sh_degree:
142
+ self.active_sh_degree += 1
143
+
144
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
145
+ # self.spatial_lr_scale = 1
146
+ self.spatial_lr_scale = spatial_lr_scale
147
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
148
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
149
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
150
+ features[:, :3, 0 ] = fused_color
151
+ features[:, 3:, 1:] = 0.0
152
+
153
+ if self.feat_mutable:
154
+ feats3D = torch.rand(fused_color.shape[0], 20).float().cuda()
155
+ self._feats3D = nn.Parameter(feats3D.requires_grad_(True))
156
+ else:
157
+ feats3D = torch.zeros(fused_color.shape[0], 20).float().cuda()
158
+ feats3D[:, 13] = 1
159
+ self._feats3D = feats3D
160
+
161
+ print("Number of points at initialization : ", fused_point_cloud.shape[0])
162
+
163
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
164
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
165
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
166
+ rots[:, 0] = 1
167
+
168
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
169
+
170
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
171
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
172
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
173
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
174
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
175
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
176
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
177
+
178
+ def training_setup(self, training_args):
179
+ self.percent_dense = training_args.percent_dense
180
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
181
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
182
+
183
+ # self.spatial_lr_scale /= 3
184
+
185
+ l = [
186
+ {'params': [self._xyz], 'lr': training_args.position_lr_init * 0.5, "name": "xyz"},
187
+ {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
188
+ {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
189
+ {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
190
+ {'params': [self._scaling], 'lr': training_args.scaling_lr * 0.5, "name": "scaling"},
191
+ {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
192
+ ]
193
+
194
+ if self.affine:
195
+ l.append({'params': [*self.appearance_model.parameters()], 'lr': 1e-3, "name": "appearance_model"})
196
+
197
+ if self.feat_mutable:
198
+ l.append({'params': [self._feats3D], 'lr': 1e-2, "name": "feats3D"})
199
+
200
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
201
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
202
+ lr_final=training_args.position_lr_final*self.spatial_lr_scale,
203
+ lr_delay_mult=training_args.position_lr_delay_mult,
204
+ max_steps=training_args.position_lr_max_steps)
205
+
206
+ def update_learning_rate(self, iteration):
207
+ ''' Learning rate scheduling per step '''
208
+ for param_group in self.optimizer.param_groups:
209
+ if param_group["name"] == "xyz":
210
+ lr = self.xyz_scheduler_args(iteration)
211
+ param_group['lr'] = lr
212
+ return lr
213
+
214
+ def construct_list_of_attributes(self):
215
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
216
+ # All channels except the 3 DC
217
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
218
+ l.append('f_dc_{}'.format(i))
219
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
220
+ l.append('f_rest_{}'.format(i))
221
+ for i in range(self._feats3D.shape[1]):
222
+ l.append('semantic_{}'.format(i))
223
+ l.append('opacity')
224
+ for i in range(self._scaling.shape[1]):
225
+ l.append('scale_{}'.format(i))
226
+ for i in range(self._rotation.shape[1]):
227
+ l.append('rot_{}'.format(i))
228
+ return l
229
+
230
+ def save_ply(self, path=None):
231
+ if path is not None: mkdir_p(os.path.dirname(path))
232
+
233
+ xyz = self.get_xyz.detach().cpu().numpy()
234
+ normals = np.zeros_like(xyz)
235
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
236
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
237
+ feats3D = self._feats3D.detach().cpu().numpy()
238
+ opacities = self._opacity.detach().cpu().numpy()
239
+ scale = self.scaling_inverse_activation(self.get_scaling).detach().cpu().numpy()
240
+ rotation = self._rotation.detach().cpu().numpy()
241
+
242
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
243
+
244
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
245
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, feats3D, opacities, scale, rotation), axis=1)
246
+ elements[:] = list(map(tuple, attributes))
247
+ el = PlyElement.describe(elements, 'vertex')
248
+ plydata = PlyData([el])
249
+ if path is not None:
250
+ plydata.write(path)
251
+ return plydata
252
+
253
+ def save_splat(self, ply_path, splat_path):
254
+ plydata = self.save_ply(ply_path)
255
+ vert = plydata["vertex"]
256
+ sorted_indices = np.argsort(
257
+ -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
258
+ / (1 + np.exp(-vert["opacity"]))
259
+ )
260
+ buffer = BytesIO()
261
+ for idx in sorted_indices:
262
+ v = plydata["vertex"][idx]
263
+ position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
264
+ scales = np.exp(
265
+ np.array(
266
+ [v["scale_0"], v["scale_1"], v["scale_2"]],
267
+ dtype=np.float32,
268
+ )
269
+ )
270
+ rot = np.array(
271
+ [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
272
+ dtype=np.float32,
273
+ )
274
+ SH_C0 = 0.28209479177387814
275
+ color = np.array(
276
+ [
277
+ 0.5 + SH_C0 * v["f_dc_0"],
278
+ 0.5 + SH_C0 * v["f_dc_1"],
279
+ 0.5 + SH_C0 * v["f_dc_2"],
280
+ 1 / (1 + np.exp(-v["opacity"])),
281
+ ]
282
+ )
283
+ buffer.write(position.tobytes())
284
+ buffer.write(scales.tobytes())
285
+ buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
286
+ buffer.write(
287
+ ((rot / np.linalg.norm(rot)) * 128 + 128)
288
+ .clip(0, 255)
289
+ .astype(np.uint8)
290
+ .tobytes()
291
+ )
292
+ with open(splat_path, "wb") as f:
293
+ f.write(buffer.getvalue())
294
+
295
+ def save_semantic_pcd(self, path):
296
+ color_dict = {
297
+ 0: np.array([128, 64, 128]), # Road
298
+ 1: np.array([244, 35, 232]), # Sidewalk
299
+ 2: np.array([70, 70, 70]), # Building
300
+ 3: np.array([102, 102, 156]), # Wall
301
+ 4: np.array([190, 153, 153]), # Fence
302
+ 5: np.array([153, 153, 153]), # Pole
303
+ 6: np.array([250, 170, 30]), # Traffic Light
304
+ 7: np.array([220, 220, 0]), # Traffic Sign
305
+ 8: np.array([107, 142, 35]), # Vegetation
306
+ 9: np.array([152, 251, 152]), # Terrain
307
+ 10: np.array([0, 0, 0]), # Black (trainId 10)
308
+ 11: np.array([70, 130, 180]), # Sky
309
+ 12: np.array([220, 20, 60]), # Person
310
+ 13: np.array([255, 0, 0]), # Rider
311
+ 14: np.array([0, 0, 142]), # Car
312
+ 15: np.array([0, 0, 70]), # Truck
313
+ 16: np.array([0, 60, 100]), # Bus
314
+ 17: np.array([0, 80, 100]), # Train
315
+ 18: np.array([0, 0, 230]), # Motorcycle
316
+ 19: np.array([119, 11, 32]) # Bicycle
317
+ }
318
+ semantic_idx = torch.argmax(self.get_3D_features, dim=-1, keepdim=True)
319
+ opacities = self.get_opacity[:, 0]
320
+ mask = ((semantic_idx != 10)[:, 0]) & ((semantic_idx != 8)[:, 0]) & (opacities > 0.2)
321
+
322
+ semantic_idx = semantic_idx[mask]
323
+ semantic_rgb = torch.zeros_like(semantic_idx).repeat(1, 3)
324
+ for idx in range(20):
325
+ rgb = torch.from_numpy(color_dict[idx]).to(semantic_rgb.device)[None, :]
326
+ semantic_rgb[(semantic_idx == idx)[:, 0], :] = rgb
327
+ semantic_rgb = semantic_rgb.float() / 255.0
328
+ pcd_xyz = self.get_xyz[mask]
329
+ smt_pcd = o3d.geometry.PointCloud()
330
+ smt_pcd.points = o3d.utility.Vector3dVector(pcd_xyz.detach().cpu().numpy())
331
+ smt_pcd.colors = o3d.utility.Vector3dVector(semantic_rgb.detach().cpu().numpy())
332
+ o3d.io.write_point_cloud(path, smt_pcd)
333
+
334
+ def save_vis_ply(self, path):
335
+ mkdir_p(os.path.dirname(path))
336
+ xyz = self.get_xyz.detach().cpu().numpy()
337
+ pcd = o3d.geometry.PointCloud()
338
+ pcd.points = o3d.utility.Vector3dVector(xyz)
339
+ colors = SH2RGB(self._features_dc[:, 0, :].detach().cpu().numpy()).clip(0, 1)
340
+ pcd.colors = o3d.utility.Vector3dVector(colors)
341
+ o3d.io.write_point_cloud(path, pcd)
342
+
343
+ def reset_opacity(self):
344
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
345
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
346
+ self._opacity = optimizable_tensors["opacity"]
347
+
348
+ def load_ply(self, path):
349
+ plydata = PlyData.read(path)
350
+
351
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
352
+ np.asarray(plydata.elements[0]["y"]),
353
+ np.asarray(plydata.elements[0]["z"])), axis=1)
354
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
355
+
356
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
357
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
358
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
359
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
360
+
361
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
362
+ assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
363
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
364
+ for idx, attr_name in enumerate(extra_f_names):
365
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
366
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
367
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
368
+
369
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
370
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
371
+ for idx, attr_name in enumerate(scale_names):
372
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
373
+
374
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
375
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
376
+ for idx, attr_name in enumerate(rot_names):
377
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
378
+
379
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
380
+ self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
381
+ self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
382
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
383
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
384
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
385
+
386
+ self.active_sh_degree = self.max_sh_degree
387
+
388
+ def replace_tensor_to_optimizer(self, tensor, name):
389
+ optimizable_tensors = {}
390
+ for group in self.optimizer.param_groups:
391
+ if group["name"] == name:
392
+ stored_state = self.optimizer.state.get(group['params'][0], None)
393
+ if stored_state is not None:
394
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
395
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
396
+ del self.optimizer.state[group['params'][0]]
397
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
398
+ self.optimizer.state[group['params'][0]] = stored_state
399
+ optimizable_tensors[group["name"]] = group["params"][0]
400
+ else:
401
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
402
+ optimizable_tensors[group["name"]] = group["params"][0]
403
+ return optimizable_tensors
404
+
405
+ def _prune_optimizer(self, mask):
406
+ optimizable_tensors = {}
407
+ for group in self.optimizer.param_groups:
408
+ if group['name'] == 'appearance_model':
409
+ continue
410
+ stored_state = self.optimizer.state.get(group['params'][0], None)
411
+ if stored_state is not None:
412
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
413
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
414
+
415
+ del self.optimizer.state[group['params'][0]]
416
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
417
+ self.optimizer.state[group['params'][0]] = stored_state
418
+
419
+ optimizable_tensors[group["name"]] = group["params"][0]
420
+ else:
421
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
422
+ optimizable_tensors[group["name"]] = group["params"][0]
423
+ return optimizable_tensors
424
+
425
+ def prune_points(self, mask):
426
+ valid_points_mask = ~mask
427
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
428
+
429
+ self._xyz = optimizable_tensors["xyz"]
430
+ self._features_dc = optimizable_tensors["f_dc"]
431
+ self._features_rest = optimizable_tensors["f_rest"]
432
+ if self.feat_mutable:
433
+ self._feats3D = optimizable_tensors["feats3D"]
434
+ else:
435
+ self._feats3D = self._feats3D[1, :].repeat((self._xyz.shape[0], 1))
436
+ self._opacity = optimizable_tensors["opacity"]
437
+ self._scaling = optimizable_tensors["scaling"]
438
+ self._rotation = optimizable_tensors["rotation"]
439
+
440
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
441
+
442
+ self.denom = self.denom[valid_points_mask]
443
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
444
+
445
+ def cat_tensors_to_optimizer(self, tensors_dict):
446
+ optimizable_tensors = {}
447
+ for group in self.optimizer.param_groups:
448
+ if group['name'] not in tensors_dict:
449
+ continue
450
+ assert len(group["params"]) == 1
451
+ extension_tensor = tensors_dict[group["name"]]
452
+ stored_state = self.optimizer.state.get(group["params"][0], None)
453
+ if stored_state is not None:
454
+
455
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
456
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
457
+
458
+ del self.optimizer.state[group["params"][0]]
459
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
460
+ self.optimizer.state[group["params"][0]] = stored_state
461
+
462
+ optimizable_tensors[group["name"]] = group["params"][0]
463
+ else:
464
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
465
+ optimizable_tensors[group["name"]] = group["params"][0]
466
+
467
+ return optimizable_tensors
468
+
469
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacities, new_scaling, new_rotation):
470
+ d = {"xyz": new_xyz,
471
+ "f_dc": new_features_dc,
472
+ "f_rest": new_features_rest,
473
+ "feats3D": new_feats3D,
474
+ "opacity": new_opacities,
475
+ "scaling" : new_scaling,
476
+ "rotation" : new_rotation}
477
+
478
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
479
+ self._xyz = optimizable_tensors["xyz"]
480
+ self._features_dc = optimizable_tensors["f_dc"]
481
+ if self.feat_mutable:
482
+ self._feats3D = optimizable_tensors["feats3D"]
483
+ else:
484
+ self._feats3D = self._feats3D[1, :].repeat((self._xyz.shape[0], 1))
485
+ self._features_rest = optimizable_tensors["f_rest"]
486
+ self._opacity = optimizable_tensors["opacity"]
487
+ self._scaling = optimizable_tensors["scaling"]
488
+ self._rotation = optimizable_tensors["rotation"]
489
+
490
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
491
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
492
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
493
+
494
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
495
+ n_init_points = self.get_xyz.shape[0]
496
+ # Extract points that satisfy the gradient condition
497
+ padded_grad = torch.zeros((n_init_points), device="cuda")
498
+ padded_grad[:grads.shape[0]] = grads.squeeze()
499
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
500
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
501
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
502
+
503
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
504
+ means = torch.zeros((stds.size(0), 3), device="cuda")
505
+ samples = torch.normal(mean=means, std=stds)
506
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
507
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
508
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
509
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
510
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
511
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
512
+ new_feats3D = self._feats3D[selected_pts_mask].repeat(N,1)
513
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
514
+
515
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacity, new_scaling, new_rotation)
516
+
517
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
518
+ self.prune_points(prune_filter)
519
+
520
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
521
+ # Extract points that satisfy the gradient condition
522
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
523
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
524
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
525
+
526
+ new_xyz = self._xyz[selected_pts_mask]
527
+ new_features_dc = self._features_dc[selected_pts_mask]
528
+ new_features_rest = self._features_rest[selected_pts_mask]
529
+ new_feats3D = self._feats3D[selected_pts_mask]
530
+ new_opacities = self._opacity[selected_pts_mask]
531
+ new_scaling = self._scaling[selected_pts_mask]
532
+ new_rotation = self._rotation[selected_pts_mask]
533
+
534
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_feats3D, new_opacities, new_scaling, new_rotation)
535
+
536
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, cam_pos=None):
537
+ grads = self.xyz_gradient_accum / self.denom
538
+ grads[grads.isnan()] = 0.0
539
+
540
+ self.densify_and_clone(grads, max_grad, extent)
541
+ self.densify_and_split(grads, max_grad, extent)
542
+
543
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
544
+ if max_screen_size:
545
+ big_points_vs = self.max_radii2D > max_screen_size
546
+ if cam_pos is not None:
547
+ # points_cam_dist = torch.abs(self.get_xyz[:, None, :] - cam_pos[None, ...])
548
+ # points_cam_nearest_idx = torch.argmin(torch.norm(points_cam_dist, dim=-1), dim=1)
549
+ # points_cam_dist = points_cam_dist[torch.arange(points_cam_dist.shape[0]), points_cam_nearest_idx, :]
550
+ # near_mask1 = (points_cam_dist[:, 1] < 5) & (points_cam_dist[:, 0] < 10) & (points_cam_dist[:, 2] < 10)
551
+ # big_points_ws1 = near_mask1 & (self.get_scaling.max(dim=1).values > 1.0)
552
+ # near_mask2 = (points_cam_dist[:, 1] < 10) & (points_cam_dist[:, 0] < 20) & (points_cam_dist[:, 2] < 20)
553
+ # big_points_ws2 = near_mask2 & (self.get_scaling.max(dim=1).values > 3.0)
554
+ # big_points_ws = (self.get_scaling.max(dim=1).values > 10.0) | big_points_ws1 | big_points_ws2
555
+ big_points_ws = self.get_scaling.max(dim=1).values > 10
556
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
557
+ else:
558
+ big_points_ws = self.get_scaling.max(dim=1).values > 5
559
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
560
+ self.prune_points(prune_mask)
561
+
562
+ torch.cuda.empty_cache()
563
+
564
+ def add_densification_stats_grad(self, tensor_grad, update_filter):
565
+ self.xyz_gradient_accum[update_filter] += torch.norm(tensor_grad[update_filter,:2], dim=-1, keepdim=True)
566
+ self.denom[update_filter] += 1
567
+
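
For reference, save_splat above emits a fixed 32-byte record per gaussian: position (3x float32), scale (3x float32), RGBA color (4x uint8) and the normalized quaternion remapped to bytes (4x uint8). A minimal reader sketch under that assumption (little-endian, as produced by numpy's tobytes on common platforms):

import numpy as np

def load_splat(path):
    # Mirrors the record layout written by ObjModel.save_splat.
    rec = np.dtype([("pos", "<f4", 3), ("scale", "<f4", 3),
                    ("rgba", "u1", 4), ("rot", "u1", 4)])
    data = np.fromfile(path, dtype=rec)
    pos = data["pos"]
    scale = data["scale"]
    color = data["rgba"].astype(np.float32) / 255.0
    rot = (data["rot"].astype(np.float32) - 128.0) / 128.0  # back to [-1, 1]
    return pos, scale, color, rot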
code/sim/hugsim_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,4 @@
1
+ Metadata-Version: 2.4
2
+ Name: hugsim-env
3
+ Version: 0.0.1
4
+ Requires-Dist: gymnasium
code/sim/hugsim_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,10 @@
1
+ pyproject.toml
2
+ setup.py
3
+ hugsim_env/__init__.py
4
+ hugsim_env.egg-info/PKG-INFO
5
+ hugsim_env.egg-info/SOURCES.txt
6
+ hugsim_env.egg-info/dependency_links.txt
7
+ hugsim_env.egg-info/requires.txt
8
+ hugsim_env.egg-info/top_level.txt
9
+ hugsim_env/envs/__init__.py
10
+ hugsim_env/envs/hug_sim.py
code/sim/hugsim_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
code/sim/hugsim_env.egg-info/requires.txt ADDED
@@ -0,0 +1 @@
1
+ gymnasium
code/sim/hugsim_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ hugsim_env
code/sim/hugsim_env/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from gymnasium.envs.registration import register
2
+
3
+
4
+ register(
5
+ id="hugsim_env/HUGSim-v0",
6
+ entry_point="hugsim_env.envs:HUGSimEnv",
7
+ max_episode_steps=400,
8
+ )
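
With the registration above, the simulator can be constructed through gymnasium; extra keyword arguments are forwarded to HUGSimEnv.__init__, and the environment is wrapped in a 400-step TimeLimit. The config path below is a hypothetical placeholder:

import gymnasium
import hugsim_env                          # import side effect runs register()
from omegaconf import OmegaConf

cfg = OmegaConf.load("my_scenario.yaml")   # placeholder config file
env = gymnasium.make("hugsim_env/HUGSim-v0", cfg=cfg, output="./out")
obs, info = env.reset()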
code/sim/hugsim_env/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (379 Bytes).
code/sim/hugsim_env/envs/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from hugsim_env.envs.hug_sim import HUGSimEnv
code/sim/hugsim_env/envs/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (221 Bytes).
code/sim/hugsim_env/envs/__pycache__/hug_sim.cpython-311.pyc ADDED
Binary file (22.2 kB).
code/sim/hugsim_env/envs/hug_sim.py ADDED
@@ -0,0 +1,333 @@
1
+ import torch
2
+ import numpy as np
3
+ from copy import deepcopy
4
+ import gymnasium
5
+ from gymnasium import spaces
7
+ from sim.utils.sim_utils import create_cam, rt2pose, pose2rt, load_camera_cfg, dense_cam_poses
8
+ from scipy.spatial.transform import Rotation as SCR
9
+ from sim.utils.score_calculator import create_rectangle, bg_collision_det
10
+ import os
11
+ import pickle
12
+ from sim.utils.plan import planner, UnifiedMap
13
+ from omegaconf import OmegaConf
14
+ import math
15
+ from gaussian_renderer import GaussianModel
16
+ from scene.obj_model import ObjModel
17
+ from gaussian_renderer import render
18
+ import open3d as o3d
19
+
20
+
21
+ def fg_collision_det(ego_box, objs):
22
+ ego_x, ego_y, _, ego_w, ego_l, ego_h, ego_yaw = ego_box
23
+ ego_poly = create_rectangle(ego_x, ego_y, ego_w, ego_l, ego_yaw)
24
+ for obs in objs:
25
+ obs_x, obs_y, _, obs_w, obs_l, _, obs_yaw = obs
26
+ obs_poly = create_rectangle(
27
+ obs_x, obs_y, obs_w, obs_l, obs_yaw)
28
+ if ego_poly.intersects(obs_poly):
29
+ return True
30
+ return False
31
+
32
+ class HUGSimEnv(gymnasium.Env):
33
+ def __init__(self, cfg, output):
34
+ super().__init__()
35
+
36
+ plan_list = cfg.scenario.plan_list
37
+ for control_param in plan_list:
38
+ control_param[5] = os.path.join(cfg.base.realcar_path, control_param[5])
39
+
40
+ # read ground infos
41
+ with open(os.path.join(cfg.model_path, 'ground_param.pkl'), 'rb') as f:
42
+ # ground_param.pkl stores (cam_poses: numpy.ndarray, cam_heights: float, commands: list)
43
+ cam_poses, cam_heights, commands = pickle.load(f)
44
+ cam_poses, commands = dense_cam_poses(cam_poses, commands)
45
+ self.ground_model = (cam_poses, cam_heights, commands)
46
+
47
+ if cfg.scenario.load_HD_map:
48
+ unified_map = UnifiedMap(cfg.base.HD_map.path, cfg.base.HD_map.version, cfg.scenario.scene_name)
49
+ else:
50
+ unified_map = None
51
+
52
+ self.kinematic = OmegaConf.to_container(cfg.kinematic)
53
+ self.kinematic['min_steer'] = -math.radians(cfg.kinematic.min_steer)
54
+ self.kinematic['max_steer'] = math.radians(cfg.kinematic.max_steer)
55
+ self.kinematic['start_vr']= np.array(cfg.scenario.start_euler) / 180 * np.pi
56
+ self.kinematic['start_vab'] = np.array(cfg.scenario.start_ab)
57
+ self.kinematic['start_velo'] = cfg.scenario.start_velo
58
+ self.kinematic['start_steer'] = cfg.scenario.start_steer
59
+
60
+ self.gaussians = GaussianModel(cfg.model.sh_degree, affine=cfg.affine)
61
+
62
+ """
63
+ plan_list: a, b, height, yaw, v, model_path, controller, params
64
+ Yaw is based on ego car's orientation. 0 means same direction as ego.
65
+ Right is positive and left is negative.
66
+ """
67
+
68
+ (model_params, iteration) = torch.load(os.path.join(cfg.model_path, "scene.pth"), weights_only=False)
69
+ self.gaussians.restore(model_params, None)
70
+
71
+ dynamic_gaussians = {}
72
+ if len(plan_list) == 0:
73
+ self.planner = None
74
+ else:
75
+ self.planner = planner(plan_list, scene_path=cfg.model_path, unified_map=unified_map, ground=self.ground_model, dt=cfg.kinematic.dt)
76
+ for plan_id in self.planner.ckpts.keys():
77
+ dynamic_gaussians[plan_id] = ObjModel(cfg.model.sh_degree, feat_mutable=False)
78
+ (model_params, iteration) = torch.load(self.planner.ckpts[plan_id], weights_only=False)
79
+ model_params = list(model_params)
80
+ dynamic_gaussians[plan_id].restore(model_params, None)
81
+
82
+ semantic_idx = torch.argmax(self.gaussians.get_full_3D_features, dim=-1, keepdim=True)
83
+ ground_xyz = self.gaussians.get_full_xyz[(semantic_idx == 0)[:, 0]].detach().cpu().numpy()
84
+ scene_xyz = self.gaussians.get_full_xyz[((semantic_idx > 1) & (semantic_idx != 10))[:, 0]].detach().cpu().numpy()
85
+ ground_pcd = o3d.geometry.PointCloud()
86
+ ground_pcd.points = o3d.utility.Vector3dVector(ground_xyz.astype(float))
87
+ o3d.io.write_point_cloud(os.path.join(output, 'ground.ply'), ground_pcd)
88
+ scene_pcd = o3d.geometry.PointCloud()
89
+ scene_pcd.points = o3d.utility.Vector3dVector(scene_xyz.astype(float))
90
+ o3d.io.write_point_cloud(os.path.join(output, 'scene.ply'), scene_pcd)
91
+
92
+ unicycles = {}
93
+
94
+ if cfg.scenario.load_HD_map and self.planner is not None:
95
+ self.planner.update_agent_route()
96
+
97
+ self.cam_params, cam_align, self.cam_rect = load_camera_cfg(cfg.camera)
98
+
99
+ self.ego_verts = np.array([[0.5, 0, 0.5], [0.5, 0, -0.5], [0.5, 1.0, 0.5], [0.5, 1.0, -0.5],
100
+ [-0.5, 0, -0.5], [-0.5, 0, 0.5], [-0.5, 1.0, -0.5], [-0.5, 1.0, 0.5]])
101
+ self.whl = np.array([1.6, 1.5, 3.0])
102
+ self.ego_verts *= self.whl
103
+ self.data_type = cfg.data_type
104
+
105
+ self.action_space = spaces.Dict(
106
+ {
107
+ "steer_rate": spaces.Box(self.kinematic['min_steer'], self.kinematic['max_steer'], dtype=float),
108
+ "acc": spaces.Box(self.kinematic['min_acc'], self.kinematic['max_acc'], dtype=float)
109
+ }
110
+ )
111
+ self.observation_space = spaces.Dict(
112
+ {
113
+ 'rgb': spaces.Dict({
114
+ cam_name: spaces.Box(
115
+ low=0, high=255,
116
+ shape=(params['intrinsic']['H'], params['intrinsic']['W'], 3), dtype=np.uint8
117
+ ) for cam_name, params in self.cam_params.items()
118
+ }),
119
+ # 'semantic': spaces.Dict({
120
+ # cam_name: spaces.Box(
121
+ # low=0, high=50,
122
+ # shape=(params['intrinsic']['H'], params['intrinsic']['W']), dtype=np.uint8
123
+ # ) for cam_name, params in self.cam_params.items()
124
+ # }),
125
+ # 'depth': spaces.Dict({
126
+ # cam_name: spaces.Box(
127
+ # low=0, high=1000,
128
+ # shape=(params['intrinsic']['H'], params['intrinsic']['W']), dtype=np.float32
129
+ # ) for cam_name, params in self.cam_params.items()
130
+ # }),
131
+ }
132
+ )
133
+ self.fric = self.kinematic['fric']
134
+
135
+ self.start_vr = self.kinematic['start_vr']
136
+ self.start_vab = self.kinematic['start_vab']
137
+ self.start_velo = self.kinematic['start_velo']
138
+ self.vr = deepcopy(self.kinematic['start_vr'])
139
+ self.vab = deepcopy(self.kinematic['start_vab'])
140
+ self.velo = deepcopy(self.kinematic['start_velo'])
141
+ self.steer = deepcopy(self.kinematic['start_steer'])
142
+ self.dt = self.kinematic['dt']
143
+
144
+ bg_color = [1, 1, 1] if cfg.model.white_background else [0, 0, 0]
145
+ self.render_fn = render
146
+ self.render_kwargs = {
147
+ "pc": self.gaussians,
148
+ "bg_color": torch.tensor(bg_color, dtype=torch.float32, device="cuda"),
149
+ "dynamic_gaussians": dynamic_gaussians,
150
+ "unicycles": unicycles
151
+ }
152
+ gaussians = self.gaussians
153
+ semantic_idx = torch.argmax(gaussians.get_3D_features, dim=-1, keepdim=True)
154
+ opacities = gaussians.get_opacity[:, 0]
155
+ mask = ((semantic_idx > 1) & (semantic_idx != 10))[:, 0] & (opacities > 0.8)
156
+ self.points = gaussians.get_xyz[mask]
157
+
158
+ self.last_accel = 0
159
+ self.last_steer_rate = 0
160
+
161
+ self.timestamp = 0
162
+
163
+ def ground_height(self, u, v):
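+ # Height lookup: find the recorded camera pose nearest to (u, v) in the
+ # x-z plane, express the query point in that camera's frame, zero its local
+ # height (project onto the camera's y = 0 plane), and map back to world.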
164
+ cam_poses, cam_height, _ = self.ground_model
165
+ cam_dist = np.sqrt(
166
+ (cam_poses[:, 0, 3] - u)**2 + (cam_poses[:, 2, 3] - v)**2
167
+ )
168
+ nearest_cam_idx = np.argmin(cam_dist, axis=0)
169
+ nearest_c2w = cam_poses[nearest_cam_idx]
170
+
171
+ nearest_w2c = np.linalg.inv(nearest_c2w)
172
+ uhv_local = nearest_w2c[:3, :3] @ np.array([u, 0, v]) + nearest_w2c[:3, 3]
173
+ uhv_local[1] = 0
174
+ uhv_world = nearest_c2w[:3, :3] @ uhv_local + nearest_c2w[:3, 3]
175
+
176
+ return uhv_world[1]
177
+
178
+ @property
179
+ def route_completion(self):
180
+ cam_poses, _, _ = self.ground_model
181
+ cam_dist = np.sqrt(
182
+ (cam_poses[:, 0, 3] - self.vab[0])**2 + (cam_poses[:, 2, 3] - self.vab[1])**2
183
+ )
184
+ nearest_cam_idx = np.argmin(cam_dist, axis=0)
185
+ return (nearest_cam_idx + 1) / (cam_poses.shape[0] * 0.9), cam_dist[nearest_cam_idx]
186
+
187
+
188
+ @property
189
+ def vt(self):
190
+ vt = np.zeros(3)
191
+ vt[[0, 2]] = self.vab
192
+ vt[1] = self.ground_height(self.vab[0], self.vab[1])
193
+ return vt
194
+
195
+ @property
196
+ def ego(self):
197
+ return rt2pose(self.vr, self.vt)
198
+
199
+ @property
200
+ def ego_state(self):
201
+ return torch.tensor([self.vab[0], self.vab[1], self.vr[1], self.velo])
202
+
203
+ @property
204
+ def ego_box(self):
205
+ return [self.vt[2], -self.vt[0], -self.vt[1], self.whl[0], self.whl[2], self.whl[1], -self.vr[1]]
206
+
207
+ @property
208
+ def objs_list(self):
209
+ obj_boxes = []
210
+ objs = self.render_kwargs['planning'][0]
211
+ for obj_id, obj_b2w in objs.items():
212
+ yaw = SCR.from_matrix(obj_b2w[:3, :3].detach().cpu().numpy()).as_euler('YXZ')[0]
213
+ # X, Y, Z in IMU, w, l, h
214
+ wlh = self.planner.wlhs[obj_id]
215
+ obj_boxes.append([obj_b2w[2, 3].item(), -obj_b2w[0, 3].item(), -obj_b2w[1, 3].item(), wlh[0], wlh[1], wlh[2], -yaw-0.5*np.pi])
216
+ return obj_boxes
217
+
218
+ def _get_obs(self):
219
+ rgbs, semantics, depths = {}, {}, {}
220
+ v2front = self.cam_params['CAM_FRONT']["v2c"]
221
+ for cam_name, params in self.cam_params.items():
222
+ intrinsic, v2c = params['intrinsic'], params['v2c']
223
+ c2front = v2front @ np.linalg.inv(v2c) @ self.cam_rect
224
+ c2w = self.ego @ c2front
225
+ viewpoint = create_cam(intrinsic, c2w)
226
+ with torch.no_grad():
227
+ render_pkg = self.render_fn(viewpoint=viewpoint, prev_viewpoint=None, **self.render_kwargs)
228
+ rgb = (torch.permute(render_pkg['render'].clamp(0, 1), (1,2,0)).detach().cpu().numpy() * 255).astype(np.uint8)
229
+ smt = torch.argmax(render_pkg['feats'], dim=0).detach().cpu().numpy().astype(np.uint8)
230
+ depth = render_pkg['depth'][0].detach().cpu().numpy()
231
+ if (self.data_type == 'waymo' or self.data_type == 'kitti360') and 'BACK' in cam_name:
232
+ rgbs[cam_name] = np.zeros_like(rgb)
233
+ semantics[cam_name] = np.zeros_like(smt)
234
+ depths[cam_name] = np.zeros_like(depth)
235
+ else:
236
+ rgbs[cam_name] = rgb
237
+ semantics[cam_name] = smt
238
+ depths[cam_name] = depth
239
+
240
+ return {
241
+ 'rgb': rgbs,
242
+ # 'semantic': semantics,
243
+ # 'depth': depths,
244
+ }
245
+
246
+ def _get_info(self):
247
+ wego_r, wego_t = pose2rt(self.ego)
248
+ cam_poses, _, commands = self.ground_model
249
+ dist = np.sum((cam_poses[:, :3, 3] - self.vt) ** 2, axis=-1)
250
+ nearest_cam_idx = np.argmin(dist)
251
+ command = commands[nearest_cam_idx]
252
+ return {
253
+ 'ego_pos' : wego_t.tolist(),
254
+ 'ego_rot' : wego_r.tolist(),
255
+ 'ego_velo' : self.velo,
256
+ 'ego_steer': self.steer,
257
+ 'accelerate': self.last_accel,
258
+ 'steer_rate': self.last_steer_rate,
259
+ 'timestamp': self.timestamp,
260
+ 'command': command,
261
+ 'ego_box': self.ego_box,
262
+ 'obj_boxes': self.objs_list,
263
+ 'cam_params': self.cam_params,
264
+ # 'ego_verts': verts,
265
+ }
266
+
267
+ def reset(self, seed=None, options=None):
268
+ self.vr = deepcopy(self.start_vr)
269
+ self.vab = deepcopy(self.start_vab)
270
+ self.velo = deepcopy(self.start_velo)
271
+ self.timestamp = 0
272
+
273
+ if self.planner is not None:
274
+ self.render_kwargs['planning'] = self.planner.plan_traj(self.timestamp, self.ego_state)
275
+ else:
276
+ self.render_kwargs['planning'] = [{}, {}]
277
+
278
+ observation = self._get_obs()
279
+ info = self._get_info()
280
+
281
+ return observation, info
282
+
283
+ def step(self, action):
284
+ self.timestamp += self.dt
285
+ if self.planner is not None:
286
+ self.render_kwargs['planning'] = self.planner.plan_traj(self.timestamp, self.ego_state)
287
+ else:
288
+ self.render_kwargs['planning'] = [{}, {}]
289
+ steer_rate, acc = action['steer_rate'], action['acc']
290
+ self.last_steer_rate, self.last_accel = steer_rate, acc
291
+ L = self.kinematic['Lr'] + self.kinematic['Lf']
292
+ self.velo += acc * self.dt
293
+ self.steer += steer_rate * self.dt
294
+ theta = self.vr[1]
295
+ # print(theta / np.pi * 180, self.steer / np.pi * 180)
296
+ self.vab[0] = self.vab[0] + self.velo * np.sin(theta) * self.dt
297
+ self.vab[1] = self.vab[1] + self.velo * np.cos(theta) * self.dt
298
+ self.vr[1] = theta + self.velo * np.tan(self.steer) / L * self.dt
299
+
300
+ terminated = False
301
+ reward = 0
302
+ verts = (self.ego[:3, :3] @ self.ego_verts.T).T + self.ego[:3, 3]
303
+ verts = torch.from_numpy(verts.astype(np.float32)).cuda()
304
+
305
+ bg_collision = bg_collision_det(self.points, verts)
306
+ if bg_collision:
307
+ terminated = True
308
+ print('Collision with background')
309
+ reward = -100
310
+
311
+ fg_collision = fg_collision_det(self.ego_box, self.objs_list)
312
+ if fg_collision:
313
+ terminated = True
314
+ print('Collision with foreground')
315
+ reward = -100
316
+
317
+ rc, dist = self.route_completion
318
+ if dist > 10:
319
+ terminated=True
320
+ print('Far from preset trajectory')
321
+ reward = -50
322
+
323
+ if rc >= 1:
324
+ terminated = True
325
+ print('Complete')
326
+ reward = 1000
327
+
328
+ observation = self._get_obs()
329
+ info = self._get_info()
330
+ info['rc'] = rc
331
+ info['collision'] = bg_collision or fg_collision
332
+
333
+ return observation, reward, terminated, False, info
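The ego update in `step` above is a forward-Euler integration of a kinematic bicycle model (note the axis convention: `vab[0]` advances with sin(theta) and `vab[1]` with cos(theta)). As a minimal standalone sketch of just that integration step — the `dt=0.1` and `L=2.7` defaults here are illustrative, not values taken from this repo:

import numpy as np

def bicycle_step(a, b, theta, v, delta, acc, steer_rate, dt=0.1, L=2.7):
    """One forward-Euler step of a kinematic bicycle model, mirroring step()."""
    v = v + acc * dt                 # integrate acceleration into velocity
    delta = delta + steer_rate * dt  # integrate steering rate into steering angle
    a = a + v * np.sin(theta) * dt   # planar position update (same axis convention as step())
    b = b + v * np.cos(theta) * dt
    theta = theta + v * np.tan(delta) / L * dt  # heading update from steering geometry
    return a, b, theta, v, delta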
code/sim/ilqr/__pycache__/lqr.cpython-311.pyc ADDED
Binary file (2.56 kB).

code/sim/ilqr/__pycache__/lqr_solver.cpython-311.pyc ADDED
Binary file (33.4 kB).

code/sim/ilqr/__pycache__/utils.cpython-311.pyc ADDED
Binary file (17.2 kB).
code/sim/ilqr/lqr.py ADDED
@@ -0,0 +1,55 @@
+ from sim.ilqr.lqr_solver import ILQRSolverParameters, ILQRWarmStartParameters, ILQRSolver
+ import numpy as np
+
+ solver_params = ILQRSolverParameters(
+     discretization_time=0.5,
+     state_cost_diagonal_entries=[1.0, 1.0, 10.0, 0.0, 0.0],
+     input_cost_diagonal_entries=[1.0, 10.0],
+     state_trust_region_entries=[1.0] * 5,
+     input_trust_region_entries=[1.0] * 2,
+     max_ilqr_iterations=100,
+     convergence_threshold=1e-6,
+     max_solve_time=0.05,
+     max_acceleration=3.0,
+     max_steering_angle=np.pi / 3.0,
+     max_steering_angle_rate=0.4,
+     min_velocity_linearization=0.01,
+     wheelbase=2.7
+ )
+
+ warm_start_params = ILQRWarmStartParameters(
+     k_velocity_error_feedback=0.5,
+     k_steering_angle_error_feedback=0.05,
+     lookahead_distance_lateral_error=15.0,
+     k_lateral_error=0.1,
+     jerk_penalty_warm_start_fit=1e-4,
+     curvature_rate_penalty_warm_start_fit=1e-2,
+ )
+
+ lqr = ILQRSolver(solver_params=solver_params, warm_start_params=warm_start_params)
+
+ def plan2control(plan_traj, init_state):
+     current_state = init_state
+     solutions = lqr.solve(current_state, plan_traj)
+     optimal_inputs = solutions[-1].input_trajectory
+     accel_cmd = optimal_inputs[0, 0]
+     steering_rate_cmd = optimal_inputs[0, 1]
+     return accel_cmd, steering_rate_cmd
+
+ if __name__ == '__main__':
+     # plan_traj = np.zeros((6, 5))
+     # plan_traj[:, 0] = 1
+     # plan_traj[:, 1] = np.ones(6)
+     # plan_traj = np.cumsum(plan_traj, axis=0)
+     # print(plan_traj)
+     plan_traj = np.array([[-0.18724936, 2.29100776, 0., 0., 0.],
+                           [-0.29260731, 2.2971828 , 0., 0., 0.],
+                           [-0.46831554, 2.55596018, 0., 0., 0.],
+                           [-0.5859955 , 2.73183298, 0., 0., 0.],
+                           [-0.62684   , 2.84659386, 0., 0., 0.],
+                           [-0.67761713, 2.80647802, 0., 0., 0.]])
+     plan_traj = plan_traj[:, [1, 0, 2, 3, 4]]
+     init_state = np.array([0.00000000e+00, 3.46944695e-17, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00])
+     print(plan_traj.shape, init_state.shape)
+     acc, steer = plan2control(plan_traj, init_state)
+     print(acc, steer)
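For reference, a minimal end-to-end call of `plan2control` on a toy straight-line reference (a sketch: the trajectory values are illustrative, the 0.5 s spacing matches `discretization_time` above, and state columns are `[x, y, heading, velocity, steering_angle]`):

import numpy as np
from sim.ilqr.lqr import plan2control

# Toy reference: drive straight along +x at 2 m/s, sampled every 0.5 s.
plan_traj = np.zeros((6, 5))
plan_traj[:, 0] = 2.0 * 0.5 * np.arange(6)  # x positions
plan_traj[:, 3] = 2.0                       # reference velocity [m/s]

init_state = np.zeros(5)  # ego at rest at the origin
acc_cmd, steer_rate_cmd = plan2control(plan_traj, init_state)
print(acc_cmd, steer_rate_cmd)  # expect positive acceleration, near-zero steering rate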
code/sim/ilqr/lqr_solver.py ADDED
@@ -0,0 +1,689 @@
+ """
+ This provides an implementation of the iterative linear quadratic regulator (iLQR) algorithm for trajectory tracking.
+ It is specialized to the case with a discrete-time kinematic bicycle model and a quadratic trajectory tracking cost.
+
+ Original (Nonlinear) Discrete Time System:
+     z_k = [x_k, y_k, theta_k, v_k, delta_k]
+     u_k = [a_k, phi_k]
+
+     x_{k+1} = x_k + v_k * cos(theta_k) * dt
+     y_{k+1} = y_k + v_k * sin(theta_k) * dt
+     theta_{k+1} = theta_k + v_k * tan(delta_k) / L * dt
+     v_{k+1} = v_k + a_k * dt
+     delta_{k+1} = delta_k + phi_k * dt
+
+     where (x_k, y_k, theta_k) is the pose at timestep k with time discretization dt,
+     v_k and a_k are velocity and acceleration,
+     delta_k and phi_k are steering angle and steering angle rate,
+     and L is the vehicle wheelbase.
+
+ Quadratic Tracking Cost:
+     J = sum_{k=0}^{N-1} ||u_k||_2^{R_k} + sum_{k=0}^{N} ||z_k - z_{ref,k}||_2^{Q_k}
+ For simplicity, we opt to use constant input cost matrices R_k = R and constant state cost matrices Q_k = Q.
+
+ There are multiple improvements that could be made to this implementation, but they are omitted to keep the code simple.
+ Some of these include:
+   * Handling constraints directly in the optimization (e.g. log-barrier / penalty method with quadratic cost estimate).
+   * Line search in the input policy update (feedforward term) to determine a good gradient step size.
+
+ References Used: https://people.eecs.berkeley.edu/~pabbeel/cs287-fa19/slides/Lec5-LQR.pdf and
+ https://www.cs.cmu.edu/~rsalakhu/10703/Lectures/Lecture_trajectoryoptimization.pdf
+ """
+
+ import time
+ from dataclasses import dataclass, fields
+ from typing import List, Optional, Tuple
+
+ import numpy as np
+ import numpy.typing as npt
+
+ # Adapted from nuplan; the original imports below are replaced by sim.ilqr.utils.
+ # from nuplan.common.actor_state.vehicle_parameters import get_pacifica_parameters
+ # from nuplan.common.geometry.compute import principal_value
+ # from nuplan.planning.simulation.controller.tracker.tracker_utils import (
+ #     complete_kinematic_state_and_inputs_from_poses,
+ #     compute_steering_angle_feedback,
+ # )
+ from sim.ilqr.utils import principal_value, complete_kinematic_state_and_inputs_from_poses, compute_steering_angle_feedback
+
+ DoubleMatrix = npt.NDArray[np.float64]
+
+
+ @dataclass(frozen=True)
+ class ILQRSolverParameters:
+     """Parameters related to the solver implementation."""
+
+     discretization_time: float  # [s] Time discretization used for integration.
+
+     # Cost weights for state [x, y, heading, velocity, steering angle] and input variables [acceleration, steering rate].
+     state_cost_diagonal_entries: List[float]
+     input_cost_diagonal_entries: List[float]
+
+     # Trust region cost weights for state and input variables. Helps keep linearization error per update step bounded.
+     state_trust_region_entries: List[float]
+     input_trust_region_entries: List[float]
+
+     # Parameters related to solver runtime / solution sub-optimality.
+     max_ilqr_iterations: int  # Maximum number of iterations to run iLQR before timeout.
+     convergence_threshold: float  # Threshold for delta inputs below which we can terminate iLQR early.
+     max_solve_time: Optional[float]  # [s] If defined, sets a maximum time to run a solve call of iLQR before terminating.
+
+     # Constraints for underlying dynamics model.
+     max_acceleration: float  # [m/s^2] Absolute value threshold on acceleration input.
+     max_steering_angle: float  # [rad] Absolute value threshold on steering angle state.
+     max_steering_angle_rate: float  # [rad/s] Absolute value threshold on steering rate input.
+
+     # Parameters for dynamics / linearization.
+     min_velocity_linearization: float  # [m/s] Absolute value threshold below which linearization velocity is modified.
+     wheelbase: float  # [m] Wheelbase length parameter for the vehicle.
+
+     def __post_init__(self) -> None:
+         """Ensure entries lie in expected bounds."""
+         for entry in [
+             "discretization_time",
+             "max_ilqr_iterations",
+             "convergence_threshold",
+             "max_acceleration",
+             "max_steering_angle",
+             "max_steering_angle_rate",
+             "min_velocity_linearization",
+             "wheelbase",
+         ]:
+             assert getattr(self, entry) > 0.0, f"Field {entry} should be positive."
+
+         assert self.max_steering_angle < np.pi / 2.0, "Max steering angle should be less than 90 degrees."
+
+         if isinstance(self.max_solve_time, float):
+             assert self.max_solve_time > 0.0, "The specified max solve time should be positive."
+
+         assert np.all([x >= 0 for x in self.state_cost_diagonal_entries]), "Q matrix must be positive semidefinite."
+         assert np.all([x > 0 for x in self.input_cost_diagonal_entries]), "R matrix must be positive definite."
+
+         assert np.all(
+             [x > 0 for x in self.state_trust_region_entries]
+         ), "State trust region cost matrix must be positive definite."
+         assert np.all(
+             [x > 0 for x in self.input_trust_region_entries]
+         ), "Input trust region cost matrix must be positive definite."
+
+
+ @dataclass(frozen=True)
+ class ILQRWarmStartParameters:
+     """Parameters related to generating a warm start trajectory for iLQR."""
+
+     k_velocity_error_feedback: float  # Gain for initial velocity error for warm start acceleration.
+     k_steering_angle_error_feedback: float  # Gain for initial steering angle error for warm start steering rate.
+     lookahead_distance_lateral_error: float  # [m] Distance ahead for which we estimate lateral error.
+     k_lateral_error: float  # Gain for lateral error to compute steering angle feedback.
+     jerk_penalty_warm_start_fit: float  # Penalty for jerk in velocity profile estimation.
+     curvature_rate_penalty_warm_start_fit: float  # Penalty for curvature rate in curvature profile estimation.
+
+     def __post_init__(self) -> None:
+         """Ensure entries lie in expected bounds."""
+         for entry in [
+             "k_velocity_error_feedback",
+             "k_steering_angle_error_feedback",
+             "lookahead_distance_lateral_error",
+             "k_lateral_error",
+             "jerk_penalty_warm_start_fit",
+             "curvature_rate_penalty_warm_start_fit",
+         ]:
+             assert getattr(self, entry) > 0.0, f"Field {entry} should be positive."
+
+
+ @dataclass(frozen=True)
+ class ILQRIterate:
+     """Contains state, input, and associated Jacobian trajectories needed to perform an update step of iLQR."""
+
+     state_trajectory: DoubleMatrix
+     input_trajectory: DoubleMatrix
+     state_jacobian_trajectory: DoubleMatrix
+     input_jacobian_trajectory: DoubleMatrix
+
+     def __post_init__(self) -> None:
+         """Check consistency of dimension across trajectory elements."""
+         assert len(self.state_trajectory.shape) == 2, "Expect state trajectory to be a 2D matrix."
+         state_trajectory_length, state_dim = self.state_trajectory.shape
+
+         assert len(self.input_trajectory.shape) == 2, "Expect input trajectory to be a 2D matrix."
+         input_trajectory_length, input_dim = self.input_trajectory.shape
+
+         assert (
+             input_trajectory_length == state_trajectory_length - 1
+         ), "State trajectory should be 1 longer than the input trajectory."
+         assert self.state_jacobian_trajectory.shape == (input_trajectory_length, state_dim, state_dim)
+         assert self.input_jacobian_trajectory.shape == (input_trajectory_length, state_dim, input_dim)
+
+         for field in fields(self):
+             # Make sure that we have no nan entries in our trajectory rollout prior to operating on this.
+             assert ~np.any(np.isnan(getattr(self, field.name))), f"{field.name} has unexpected nan values."
+
+
+ @dataclass(frozen=True)
+ class ILQRInputPolicy:
+     """Contains parameters for the perturbation input policy computed after performing LQR."""
+
+     state_feedback_matrices: DoubleMatrix
+     feedforward_inputs: DoubleMatrix
+
+     def __post_init__(self) -> None:
+         """Check shape of policy parameters."""
+         assert (
+             len(self.state_feedback_matrices.shape) == 3
+         ), "Expected state_feedback_matrices to have shape (n_horizon, n_inputs, n_states)."
+
+         assert (
+             len(self.feedforward_inputs.shape) == 2
+         ), "Expected feedforward inputs to have shape (n_horizon, n_inputs)."
+
+         assert (
+             self.feedforward_inputs.shape == self.state_feedback_matrices.shape[:2]
+         ), "Inconsistent horizon or input dimension between feedforward inputs and state feedback matrices."
+
+         for field in fields(self):
+             # Make sure that we have no nan entries in our policy parameters prior to using them.
+             assert ~np.any(np.isnan(getattr(self, field.name))), f"{field.name} has unexpected nan values."
+
+
+ @dataclass(frozen=True)
+ class ILQRSolution:
+     """Contains the iLQR solution with associated cost for consumption by the solver's client."""
+
+     state_trajectory: DoubleMatrix
+     input_trajectory: DoubleMatrix
+
+     tracking_cost: float
+
+     def __post_init__(self) -> None:
+         """Check consistency of dimension across trajectory elements and nonnegative cost."""
+         assert len(self.state_trajectory.shape) == 2, "Expect state trajectory to be a 2D matrix."
+         state_trajectory_length, _ = self.state_trajectory.shape
+
+         assert len(self.input_trajectory.shape) == 2, "Expect input trajectory to be a 2D matrix."
+         input_trajectory_length, _ = self.input_trajectory.shape
+
+         assert (
+             input_trajectory_length == state_trajectory_length - 1
+         ), "State trajectory should be 1 longer than the input trajectory."
+
+         assert self.tracking_cost >= 0.0, "Expect the tracking cost to be nonnegative."
+
+
+ class ILQRSolver:
+     """iLQR solver implementation, see module docstring for details."""
+
+     def __init__(
+         self,
+         solver_params: ILQRSolverParameters,
+         warm_start_params: ILQRWarmStartParameters,
+     ) -> None:
+         """
+         Initialize solver parameters.
+         :param solver_params: Contains solver parameters for iLQR.
+         :param warm_start_params: Contains warm start parameters for iLQR.
+         """
+         self._solver_params = solver_params
+         self._warm_start_params = warm_start_params
+
+         self._n_states = 5  # state dimension
+         self._n_inputs = 2  # input dimension
+
+         state_cost_diagonal_entries = self._solver_params.state_cost_diagonal_entries
+         assert (
+             len(state_cost_diagonal_entries) == self._n_states
+         ), f"State cost matrix should have diagonal length {self._n_states}."
+         self._state_cost_matrix: DoubleMatrix = np.diag(state_cost_diagonal_entries)
+
+         input_cost_diagonal_entries = self._solver_params.input_cost_diagonal_entries
+         assert (
+             len(input_cost_diagonal_entries) == self._n_inputs
+         ), f"Input cost matrix should have diagonal length {self._n_inputs}."
+         self._input_cost_matrix: DoubleMatrix = np.diag(input_cost_diagonal_entries)
+
+         state_trust_region_entries = self._solver_params.state_trust_region_entries
+         assert (
+             len(state_trust_region_entries) == self._n_states
+         ), f"State trust region cost matrix should have diagonal length {self._n_states}."
+         self._state_trust_region_cost_matrix: DoubleMatrix = np.diag(state_trust_region_entries)
+
+         input_trust_region_entries = self._solver_params.input_trust_region_entries
+         assert (
+             len(input_trust_region_entries) == self._n_inputs
+         ), f"Input trust region cost matrix should have diagonal length {self._n_inputs}."
+         self._input_trust_region_cost_matrix: DoubleMatrix = np.diag(input_trust_region_entries)
+
+         max_acceleration = self._solver_params.max_acceleration
+         max_steering_angle_rate = self._solver_params.max_steering_angle_rate
+
+         # Define input clip limits once to avoid recomputation in _clip_inputs.
+         self._input_clip_min = (-max_acceleration, -max_steering_angle_rate)
+         self._input_clip_max = (max_acceleration, max_steering_angle_rate)
+
+     def solve(self, current_state: DoubleMatrix, reference_trajectory: DoubleMatrix) -> List[ILQRSolution]:
+         """
+         Run the main iLQR loop used to try to find (locally) optimal inputs to track the reference trajectory.
+         :param current_state: The initial state from which we apply inputs, z_0.
+         :param reference_trajectory: The state reference we'd like to track, inclusive of the initial timestep,
+                                      z_{r,k} for k in {0, ..., N}.
+         :return: A list of solution iterates after running the iLQR algorithm where the index is the iteration number.
+         """
+         # Check that state parameter has the right shape.
+         assert current_state.shape == (self._n_states,), "Incorrect state shape."
+
+         # Check that reference trajectory parameter has the right shape.
+         assert len(reference_trajectory.shape) == 2, "Reference trajectory should be a 2D matrix."
+         reference_trajectory_length, reference_trajectory_state_dimension = reference_trajectory.shape
+         assert reference_trajectory_length > 1, "The reference trajectory should be at least two timesteps long."
+         assert (
+             reference_trajectory_state_dimension == self._n_states
+         ), "The reference trajectory should have a matching state dimension."
+
+         # List of ILQRSolution results where the index corresponds to the iteration of iLQR.
+         solution_list: List[ILQRSolution] = []
+
+         # Get warm start input and state trajectory, as well as associated Jacobians.
+         current_iterate = self._input_warm_start(current_state, reference_trajectory)
+
+         # Main iLQR Loop.
+         solve_start_time = time.perf_counter()
+         for _ in range(self._solver_params.max_ilqr_iterations):
+             # Determine the cost and store the associated solution object.
+             tracking_cost = self._compute_tracking_cost(
+                 iterate=current_iterate,
+                 reference_trajectory=reference_trajectory,
+             )
+             solution_list.append(
+                 ILQRSolution(
+                     input_trajectory=current_iterate.input_trajectory,
+                     state_trajectory=current_iterate.state_trajectory,
+                     tracking_cost=tracking_cost,
+                 )
+             )
+
+             # Determine the LQR optimal perturbations to apply.
+             lqr_input_policy = self._run_lqr_backward_recursion(
+                 current_iterate=current_iterate,
+                 reference_trajectory=reference_trajectory,
+             )
+
+             # Apply the optimal perturbations to generate the next input trajectory iterate.
+             input_trajectory_next = self._update_inputs_with_policy(
+                 current_iterate=current_iterate,
+                 lqr_input_policy=lqr_input_policy,
+             )
+
+             # Check for convergence/timeout and terminate early if so.
+             # Else update the input_trajectory iterate and continue.
+             input_trajectory_norm_difference = np.linalg.norm(input_trajectory_next - current_iterate.input_trajectory)
+
+             current_iterate = self._run_forward_dynamics(current_state, input_trajectory_next)
+
+             if input_trajectory_norm_difference < self._solver_params.convergence_threshold:
+                 break
+
+             elapsed_time = time.perf_counter() - solve_start_time
+             if (
+                 isinstance(self._solver_params.max_solve_time, float)
+                 and elapsed_time >= self._solver_params.max_solve_time
+             ):
+                 break
+
+         # Store the final iterate in the solution_list.
+         tracking_cost = self._compute_tracking_cost(
+             iterate=current_iterate,
+             reference_trajectory=reference_trajectory,
+         )
+         solution_list.append(
+             ILQRSolution(
+                 input_trajectory=current_iterate.input_trajectory,
+                 state_trajectory=current_iterate.state_trajectory,
+                 tracking_cost=tracking_cost,
+             )
+         )
+
+         return solution_list
+
+     ####################################################################################################################
+     # Helper methods.
+     ####################################################################################################################
+
+     def _compute_tracking_cost(self, iterate: ILQRIterate, reference_trajectory: DoubleMatrix) -> float:
+         """
+         Compute the trajectory tracking cost given a candidate solution.
+         :param iterate: Contains the candidate state and input trajectory to evaluate.
+         :param reference_trajectory: The desired state reference trajectory with same length as state_trajectory.
+         :return: The tracking cost of the candidate state/input trajectory.
+         """
+         input_trajectory = iterate.input_trajectory
+         state_trajectory = iterate.state_trajectory
+
+         assert len(state_trajectory) == len(
+             reference_trajectory
+         ), "The state and reference trajectory should have the same length."
+
+         error_state_trajectory = state_trajectory - reference_trajectory
+         error_state_trajectory[:, 2] = principal_value(error_state_trajectory[:, 2])
+
+         cost = np.sum([u.T @ self._input_cost_matrix @ u for u in input_trajectory]) + np.sum(
+             [e.T @ self._state_cost_matrix @ e for e in error_state_trajectory]
+         )
+
+         return float(cost)
+
+     def _clip_inputs(self, inputs: DoubleMatrix) -> DoubleMatrix:
+         """
+         Used to clip control inputs within constraints.
+         :param inputs: The control inputs with shape (self._n_inputs,) to clip.
+         :return: Clipped version of the control inputs, unmodified if already within constraints.
+         """
+         assert inputs.shape == (self._n_inputs,), f"The inputs should be a 1D vector with {self._n_inputs} elements."
+
+         return np.clip(inputs, self._input_clip_min, self._input_clip_max)  # type: ignore
+
+     def _clip_steering_angle(self, steering_angle: float) -> float:
+         """
+         Used to clip the steering angle state within bounds.
+         :param steering_angle: [rad] A steering angle (scalar) to clip.
+         :return: [rad] The clipped steering angle.
+         """
+         steering_angle_sign = 1.0 if steering_angle >= 0 else -1.0
+         steering_angle = steering_angle_sign * min(abs(steering_angle), self._solver_params.max_steering_angle)
+         return steering_angle
+
+     def _input_warm_start(self, current_state: DoubleMatrix, reference_trajectory: DoubleMatrix) -> ILQRIterate:
+         """
+         Given a reference trajectory, we generate the warm start (initial guess) by inferring the inputs applied based
+         on poses in the reference trajectory.
+         :param current_state: The initial state from which we apply inputs.
+         :param reference_trajectory: The reference trajectory we are trying to follow.
+         :return: The warm start iterate from which to start iLQR.
+         """
+         reference_states_completed, reference_inputs_completed = complete_kinematic_state_and_inputs_from_poses(
+             discretization_time=self._solver_params.discretization_time,
+             wheel_base=self._solver_params.wheelbase,
+             poses=reference_trajectory[:, :3],
+             jerk_penalty=self._warm_start_params.jerk_penalty_warm_start_fit,
+             curvature_rate_penalty=self._warm_start_params.curvature_rate_penalty_warm_start_fit,
+         )
+
+         # We could just stop here and apply reference_inputs_completed (assuming it satisfies constraints).
+         # This could work if current_state = reference_states_completed[0,:] - i.e. no initial tracking error.
+         # We add feedback input terms for the first control input only to account for nonzero initial tracking error.
+         _, _, _, velocity_current, steering_angle_current = current_state
+         _, _, _, velocity_reference, steering_angle_reference = reference_states_completed[0, :]
+
+         acceleration_feedback = -self._warm_start_params.k_velocity_error_feedback * (
+             velocity_current - velocity_reference
+         )
+
+         steering_angle_feedback = compute_steering_angle_feedback(
+             pose_reference=current_state[:3],
+             pose_current=reference_states_completed[0, :3],
+             lookahead_distance=self._warm_start_params.lookahead_distance_lateral_error,
+             k_lateral_error=self._warm_start_params.k_lateral_error,
+         )
+         steering_angle_desired = steering_angle_feedback + steering_angle_reference
+         steering_rate_feedback = -self._warm_start_params.k_steering_angle_error_feedback * (
+             steering_angle_current - steering_angle_desired
+         )
+
+         reference_inputs_completed[0, 0] += acceleration_feedback
+         reference_inputs_completed[0, 1] += steering_rate_feedback
+
+         # We rerun dynamics with constraints applied to make sure we have a feasible warm start for iLQR.
+         return self._run_forward_dynamics(current_state, reference_inputs_completed)
+
+     ####################################################################################################################
+     # Dynamics and Jacobian.
+     ####################################################################################################################
+
+     def _run_forward_dynamics(self, current_state: DoubleMatrix, input_trajectory: DoubleMatrix) -> ILQRIterate:
+         """
+         Compute states and corresponding state/input Jacobian matrices using forward dynamics.
+         We additionally return the input since the dynamics may modify the input to ensure constraint satisfaction.
+         :param current_state: The initial state from which we apply inputs. Must be feasible given constraints.
+         :param input_trajectory: The input trajectory applied to the model. May be modified to ensure feasibility.
+         :return: A feasible iterate after applying dynamics with state/input trajectories and Jacobian matrices.
+         """
+         # Store rollout as a set of numpy arrays, initialized as np.nan to ensure we correctly fill them in.
+         # The state trajectory includes the current_state, z_0, and is 1 element longer than the other arrays.
+         # The final_input_trajectory captures the applied input for the dynamics model satisfying constraints.
+         N = len(input_trajectory)
+         state_trajectory = np.nan * np.ones((N + 1, self._n_states), dtype=np.float64)
+         final_input_trajectory = np.nan * np.ones_like(input_trajectory, dtype=np.float64)
+         state_jacobian_trajectory = np.nan * np.ones((N, self._n_states, self._n_states), dtype=np.float64)
+         final_input_jacobian_trajectory = np.nan * np.ones((N, self._n_states, self._n_inputs), dtype=np.float64)
+
+         state_trajectory[0] = current_state
+
+         for idx_u, u in enumerate(input_trajectory):
+             state_next, final_input, state_jacobian, final_input_jacobian = self._dynamics_and_jacobian(
+                 state_trajectory[idx_u], u
+             )
+
+             state_trajectory[idx_u + 1] = state_next
+             final_input_trajectory[idx_u] = final_input
+             state_jacobian_trajectory[idx_u] = state_jacobian
+             final_input_jacobian_trajectory[idx_u] = final_input_jacobian
+
+         iterate = ILQRIterate(
+             state_trajectory=state_trajectory,  # type: ignore
+             input_trajectory=final_input_trajectory,  # type: ignore
+             state_jacobian_trajectory=state_jacobian_trajectory,  # type: ignore
+             input_jacobian_trajectory=final_input_jacobian_trajectory,  # type: ignore
+         )
+
+         return iterate
+
+     def _dynamics_and_jacobian(
+         self, current_state: DoubleMatrix, current_input: DoubleMatrix
+     ) -> Tuple[DoubleMatrix, DoubleMatrix, DoubleMatrix, DoubleMatrix]:
+         """
+         Propagates the state forward by one step and computes the corresponding state and input Jacobian matrices.
+         We also impose all constraints here to ensure the current input and next state are always feasible.
+         :param current_state: The current state z_k.
+         :param current_input: The applied input u_k.
+         :return: The next state z_{k+1}, (possibly modified) input u_k, and state (df/dz) and input (df/du) Jacobians.
+         """
+         x, y, heading, velocity, steering_angle = current_state
+
+         # Check steering angle is in expected range for valid Jacobian matrices.
+         assert (
+             np.abs(steering_angle) < np.pi / 2.0
+         ), f"The steering angle {steering_angle} is outside expected limits. There is a singularity at delta = np.pi/2."
+
+         # Input constraints: clip inputs within bounds and then use.
+         current_input = self._clip_inputs(current_input)
+         acceleration, steering_rate = current_input
+
+         # Euler integration of bicycle model.
+         discretization_time = self._solver_params.discretization_time
+         wheelbase = self._solver_params.wheelbase
+
+         next_state: DoubleMatrix = np.copy(current_state)
+         next_state[0] += velocity * np.cos(heading) * discretization_time
+         next_state[1] += velocity * np.sin(heading) * discretization_time
+         next_state[2] += velocity * np.tan(steering_angle) / wheelbase * discretization_time
+         next_state[3] += acceleration * discretization_time
+         next_state[4] += steering_rate * discretization_time
+
+         # Constrain heading angle to lie within +/- pi.
+         next_state[2] = principal_value(next_state[2])
+
+         # State constraints: clip the steering_angle within bounds and update steering_rate accordingly.
+         next_steering_angle = self._clip_steering_angle(next_state[4])
+         applied_steering_rate = (next_steering_angle - steering_angle) / discretization_time
+         next_state[4] = next_steering_angle
+         current_input[1] = applied_steering_rate
+
+         # Now we construct and populate the state and input Jacobians.
+         state_jacobian: DoubleMatrix = np.eye(self._n_states, dtype=np.float64)
+         input_jacobian: DoubleMatrix = np.zeros((self._n_states, self._n_inputs), dtype=np.float64)
+
+         # Set a nonzero velocity to handle issues when linearizing at (near) zero velocity.
+         # This helps e.g. when the vehicle is stopped with zero steering angle and needs to accelerate/turn.
+         # Without this, the A matrix will indicate steering has no impact on heading due to Euler discretization.
+         # There will be a rank drop in the controllability matrix, so the discrete-time algebraic Riccati equation
+         # may not have a solution (uncontrollable subspace) or it may not be unique.
+         min_velocity_linearization = self._solver_params.min_velocity_linearization
+         if -min_velocity_linearization <= velocity <= min_velocity_linearization:
+             sign_velocity = 1.0 if velocity >= 0.0 else -1.0
+             velocity = sign_velocity * min_velocity_linearization
+
+         state_jacobian[0, 2] = -velocity * np.sin(heading) * discretization_time
+         state_jacobian[0, 3] = np.cos(heading) * discretization_time
+
+         state_jacobian[1, 2] = velocity * np.cos(heading) * discretization_time
+         state_jacobian[1, 3] = np.sin(heading) * discretization_time
+
+         state_jacobian[2, 3] = np.tan(steering_angle) / wheelbase * discretization_time
+         state_jacobian[2, 4] = velocity * discretization_time / (wheelbase * np.cos(steering_angle) ** 2)
+
+         input_jacobian[3, 0] = discretization_time
+         input_jacobian[4, 1] = discretization_time
+
+         return next_state, current_input, state_jacobian, input_jacobian
+
+     ####################################################################################################################
+     # Core LQR implementation.
+     ####################################################################################################################
+
+     def _run_lqr_backward_recursion(
+         self,
+         current_iterate: ILQRIterate,
+         reference_trajectory: DoubleMatrix,
+     ) -> ILQRInputPolicy:
+         """
+         Computes the locally optimal affine state feedback policy by applying dynamic programming to linear perturbation
+         dynamics about a specified linearization trajectory. We include a trust region penalty as part of the cost.
+         :param current_iterate: Contains all relevant linearization information needed to compute LQR policy.
+         :param reference_trajectory: The desired state trajectory we are tracking.
+         :return: An affine state feedback policy - state feedback matrices and feedforward inputs found using LQR.
+         """
+         state_trajectory = current_iterate.state_trajectory
+         input_trajectory = current_iterate.input_trajectory
+         state_jacobian_trajectory = current_iterate.state_jacobian_trajectory
+         input_jacobian_trajectory = current_iterate.input_jacobian_trajectory
+
+         # Check reference matches the expected shape.
+         assert reference_trajectory.shape == state_trajectory.shape, "The reference trajectory has incorrect shape."
+
+         # Compute nominal error trajectory.
+         error_state_trajectory = state_trajectory - reference_trajectory
+         error_state_trajectory[:, 2] = principal_value(error_state_trajectory[:, 2])
+
+         # The value function has the form V_k(\Delta z_k) = \Delta z_k^T P_k \Delta z_k + 2 \rho_k^T \Delta z_k.
+         # So p_current = P_k is related to the Hessian of the value function at the current timestep.
+         # And rho_current = rho_k is part of the linear cost term in the value function at the current timestep.
+         p_current = self._state_cost_matrix + self._state_trust_region_cost_matrix
+         rho_current = self._state_cost_matrix @ error_state_trajectory[-1]
+
+         # The optimal LQR policy has the form \Delta u_k^* = K_k \Delta z_k + \kappa_k.
+         # We refer to K_k as state_feedback_matrix and \kappa_k as feedforward input in the code below.
+         N = len(input_trajectory)
+         state_feedback_matrices = np.nan * np.ones((N, self._n_inputs, self._n_states), dtype=np.float64)
+         feedforward_inputs = np.nan * np.ones((N, self._n_inputs), dtype=np.float64)
+
+         for i in reversed(range(N)):
+             A = state_jacobian_trajectory[i]
+             B = input_jacobian_trajectory[i]
+             u = input_trajectory[i]
+             error = error_state_trajectory[i]
+
+             # Compute the optimal input policy for this timestep.
+             inverse_matrix_term = np.linalg.inv(
+                 self._input_cost_matrix + self._input_trust_region_cost_matrix + B.T @ p_current @ B
+             )  # invertible since we checked input_cost / input_trust_region_cost are positive definite during creation.
+             state_feedback_matrix = -inverse_matrix_term @ B.T @ p_current @ A
+             feedforward_input = -inverse_matrix_term @ (self._input_cost_matrix @ u + B.T @ rho_current)
+
+             # Compute the optimal value function for this timestep.
+             a_closed_loop = A + B @ state_feedback_matrix
+
+             p_prior = (
+                 self._state_cost_matrix
+                 + self._state_trust_region_cost_matrix
+                 + state_feedback_matrix.T @ self._input_cost_matrix @ state_feedback_matrix
+                 + state_feedback_matrix.T @ self._input_trust_region_cost_matrix @ state_feedback_matrix
+                 + a_closed_loop.T @ p_current @ a_closed_loop
+             )
+
+             rho_prior = (
+                 self._state_cost_matrix @ error
+                 + state_feedback_matrix.T @ self._input_cost_matrix @ (feedforward_input + u)
+                 + state_feedback_matrix.T @ self._input_trust_region_cost_matrix @ feedforward_input
+                 + a_closed_loop.T @ p_current @ B @ feedforward_input
+                 + a_closed_loop.T @ rho_current
+             )
+
+             p_current = p_prior
+             rho_current = rho_prior
+
+             state_feedback_matrices[i] = state_feedback_matrix
+             feedforward_inputs[i] = feedforward_input
+
+         lqr_input_policy = ILQRInputPolicy(
+             state_feedback_matrices=state_feedback_matrices,  # type: ignore
+             feedforward_inputs=feedforward_inputs,  # type: ignore
+         )
+
+         return lqr_input_policy
+
+     def _update_inputs_with_policy(
+         self,
+         current_iterate: ILQRIterate,
+         lqr_input_policy: ILQRInputPolicy,
+     ) -> DoubleMatrix:
+         """
+         Used to update an iterate of iLQR by applying a perturbation input policy for local cost improvement.
+         :param current_iterate: Contains the state and input trajectory about which we linearized.
+         :param lqr_input_policy: Contains the LQR policy to apply.
+         :return: The next input trajectory found by applying the LQR policy.
+         """
+         state_trajectory = current_iterate.state_trajectory
+         input_trajectory = current_iterate.input_trajectory
+
+         # Trajectory of state perturbations while applying feedback policy.
+         # Starts with zero as the initial states match exactly; only later states might vary.
+         delta_state_trajectory = np.nan * np.ones((len(input_trajectory) + 1, self._n_states), dtype=np.float64)
+         delta_state_trajectory[0] = [0.0] * self._n_states
+
+         # This is the updated input trajectory we will return after applying the input perturbations.
+         input_next_trajectory = np.nan * np.ones_like(input_trajectory, dtype=np.float64)
+
+         zip_object = zip(
+             input_trajectory,
+             state_trajectory[:-1],
+             state_trajectory[1:],
+             lqr_input_policy.state_feedback_matrices,
+             lqr_input_policy.feedforward_inputs,
+         )
+
+         for input_idx, (input_lin, state_lin, state_lin_next, state_feedback_matrix, feedforward_input) in enumerate(
+             zip_object
+         ):
+             # Compute locally optimal input perturbation.
+             delta_state = delta_state_trajectory[input_idx]
+             delta_input = state_feedback_matrix @ delta_state + feedforward_input
+
+             # Apply state and input perturbation.
+             input_perturbed = input_lin + delta_input
+             state_perturbed = state_lin + delta_state
+             state_perturbed[2] = principal_value(state_perturbed[2])
+
+             # Run dynamics with perturbed state/inputs to get next state.
+             # We get the actually applied input since it might have been clipped/modified to satisfy constraints.
+             state_perturbed_next, input_perturbed, _, _ = self._dynamics_and_jacobian(state_perturbed, input_perturbed)
+
+             # Compute next state perturbation given next state.
+             delta_state_next = state_perturbed_next - state_lin_next
+             delta_state_next[2] = principal_value(delta_state_next[2])
+
+             delta_state_trajectory[input_idx + 1] = delta_state_next
+             input_next_trajectory[input_idx] = input_perturbed
+
+         assert ~np.any(np.isnan(input_next_trajectory)), "All next inputs should be valid float values."
+
+         return input_next_trajectory  # type: ignore
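A quick way to sanity-check the analytic Jacobians above is a finite-difference comparison. The sketch below reuses the module-level `lqr` solver instance from lqr.py and pokes at the private `_dynamics_and_jacobian` helper purely for verification; the state and input values are illustrative:

import numpy as np
from sim.ilqr.lqr import lqr as solver  # module-level ILQRSolver instance

z = np.array([0.0, 0.0, 0.1, 2.0, 0.05])  # [x, y, heading, velocity, steering_angle]
u = np.array([0.5, 0.01])                  # [acceleration, steering_rate]
z_next, _, A, _ = solver._dynamics_and_jacobian(z, u)

eps = 1e-6
A_fd = np.zeros_like(A)
for j in range(5):
    dz = np.zeros(5)
    dz[j] = eps
    z_plus, _, _, _ = solver._dynamics_and_jacobian(z + dz, u)
    A_fd[:, j] = (z_plus - z_next) / eps
print(np.max(np.abs(A - A_fd)))  # should be tiny away from the min-velocity linearization regime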
code/sim/ilqr/utils.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import numpy as np
3
+ import numpy.typing as npt
4
+
5
+ DoubleMatrix = npt.NDArray[np.float64]
6
+
7
+ def principal_value(angle, min_=-np.pi):
8
+ """
9
+ Wrap heading angle in to specified domain (multiples of 2 pi alias),
10
+ ensuring that the angle is between min_ and min_ + 2 pi. This function raises an error if the angle is infinite
11
+ :param angle: rad
12
+ :param min_: minimum domain for angle (rad)
13
+ :return angle wrapped to [min_, min_ + 2 pi).
14
+ """
15
+ assert np.all(np.isfinite(angle)), "angle is not finite"
16
+ lhs = (angle - min_) % (2 * np.pi) + min_
17
+ return lhs
18
+
19
+ def compute_steering_angle_feedback(
20
+ pose_reference, pose_current, lookahead_distance, k_lateral_error
21
+ ):
22
+ """
23
+ Given pose information, determines the steering angle feedback value to address initial tracking error.
24
+ This is based on the feedback controller developed in Section 2.2 of the following paper:
25
+ https://ddl.stanford.edu/publications/design-feedback-feedforward-steering-controller-accurate-path-tracking-and-stability
26
+ :param pose_reference: <np.ndarray: 3,> Contains the reference pose at the current timestep.
27
+ :param pose_current: <np.ndarray: 3,> Contains the actual pose at the current timestep.
28
+ :param lookahead_distance: [m] Distance ahead for which we should estimate lateral error based on a linear fit.
29
+ :param k_lateral_error: Feedback gain for lateral error used to determine steering angle feedback.
30
+ :return: [rad] The steering angle feedback to apply.
31
+ """
32
+ assert pose_reference.shape == (3,), "We expect a single reference pose."
33
+ assert pose_current.shape == (3,), "We expect a single current pose."
34
+
35
+ assert lookahead_distance > 0.0, "Lookahead distance should be positive."
36
+ assert k_lateral_error > 0.0, "Feedback gain for lateral error should be positive."
37
+
38
+ x_reference, y_reference, heading_reference = pose_reference
39
+ x_current, y_current, heading_current = pose_current
40
+
41
+ x_error = x_current - x_reference
42
+ y_error = y_current - y_reference
43
+ heading_error = principal_value(heading_current - heading_reference)
44
+
45
+ lateral_error = -x_error * np.sin(heading_reference) + y_error * np.cos(heading_reference)
46
+
47
+ return float(-k_lateral_error * (lateral_error + lookahead_distance * heading_error))
48
+
49
+ def _convert_curvature_profile_to_steering_profile(
50
+ curvature_profile: DoubleMatrix,
51
+ discretization_time: float,
52
+ wheel_base: float,
53
+ ) -> Tuple[DoubleMatrix, DoubleMatrix]:
54
+ """
55
+ Converts from a curvature profile to the corresponding steering profile.
56
+ We assume a kinematic bicycle model where curvature = tan(steering_angle) / wheel_base.
57
+ For simplicity, we just use finite differences to determine steering rate.
58
+ :param curvature_profile: [rad] Curvature trajectory to convert.
59
+ :param discretization_time: [s] Time discretization used for integration.
60
+ :param wheel_base: [m] The vehicle wheelbase parameter required for conversion.
61
+ :return: The [rad] steering angle and [rad/s] steering rate (derivative) profiles.
62
+ """
63
+ assert discretization_time > 0.0, "Discretization time must be positive."
64
+ assert wheel_base > 0.0, "The vehicle's wheelbase length must be positive."
65
+
66
+ steering_angle_profile = np.arctan(wheel_base * curvature_profile)
67
+ steering_rate_profile = np.diff(steering_angle_profile) / discretization_time
68
+
69
+ return steering_angle_profile, steering_rate_profile
70
+
71
+
72
+ def _get_xy_heading_displacements_from_poses(poses: DoubleMatrix) -> Tuple[DoubleMatrix, DoubleMatrix]:
73
+ """
74
+ Returns position and heading displacements given a pose trajectory.
75
+ :param poses: <np.ndarray: num_poses, 3> A trajectory of poses (x, y, heading).
76
+ :return: Tuple of xy displacements with shape (num_poses-1, 2) and heading displacements with shape (num_poses-1,).
77
+ """
78
+ assert len(poses.shape) == 2, "Expect a 2D matrix representing a trajectory of poses."
79
+ assert poses.shape[0] > 1, "Cannot get displacements given an empty or single element pose trajectory."
80
+ assert poses.shape[1] == 3, "Expect pose to have three elements (x, y, heading)."
81
+
82
+ # Compute displacements that are used to complete the kinematic state and input.
83
+ pose_differences = np.diff(poses, axis=0)
84
+ xy_displacements = pose_differences[:, :2]
85
+ heading_displacements = principal_value(pose_differences[:, 2])
86
+
87
+ return xy_displacements, heading_displacements
88
+
89
+
90
+ def _make_banded_difference_matrix(number_rows: int) -> DoubleMatrix:
91
+ """
92
+ Returns a banded difference matrix with specified number_rows.
93
+ When applied to a vector [x_1, ..., x_N], it returns [x_2 - x_1, ..., x_N - x_{N-1}].
94
+ :param number_rows: The row dimension of the banded difference matrix (e.g. N-1 in the example above).
95
+ :return: A banded difference matrix with shape (number_rows, number_rows+1).
96
+ """
97
+ banded_matrix: DoubleMatrix = -1.0 * np.eye(number_rows + 1, dtype=np.float64)[:-1, :]
98
+ for ind in range(len(banded_matrix)):
99
+ banded_matrix[ind, ind + 1] = 1.0
100
+
101
+ return banded_matrix
102
+
103
+
104
+
105
+ def _fit_initial_velocity_and_acceleration_profile(
106
+ xy_displacements: DoubleMatrix, heading_profile: DoubleMatrix, discretization_time: float, jerk_penalty: float
107
+ ) -> Tuple[float, DoubleMatrix]:
108
+ """
109
+ Estimates initial velocity (v_0) and acceleration ({a_0, ...}) using least squares with jerk penalty regularization.
110
+ :param xy_displacements: [m] Deviations in x and y occurring between M+1 poses, a M by 2 matrix.
111
+ :param heading_profile: [rad] Headings associated to the starting timestamp for xy_displacements, a M-length vector.
112
+ :param discretization_time: [s] Time discretization used for integration.
113
+ :param jerk_penalty: A regularization parameter used to penalize acceleration differences. Should be positive.
114
+ :return: Least squares solution for initial velocity (v_0) and acceleration profile ({a_0, ..., a_M-1})
115
+ for M displacement values.
116
+ """
117
+ assert discretization_time > 0.0, "Discretization time must be positive."
118
+ assert jerk_penalty > 0, "Should have a positive jerk_penalty."
119
+
120
+ assert len(xy_displacements.shape) == 2, "Expect xy_displacements to be a matrix."
121
+ assert xy_displacements.shape[1] == 2, "Expect xy_displacements to have 2 columns."
122
+
123
+ num_displacements = len(xy_displacements) # aka M in the docstring
124
+
125
+ assert heading_profile.shape == (
126
+ num_displacements,
127
+ ), "Expect the length of heading_profile to match that of xy_displacements."
128
+
129
+ # Core problem: minimize_x ||y-Ax||_2
130
+ y = xy_displacements.flatten() # Flatten to a vector, [delta x_0, delta y_0, ...]
131
+
132
+ A: DoubleMatrix = np.zeros((2 * num_displacements, num_displacements), dtype=np.float64)
133
+ for idx_timestep, heading in enumerate(heading_profile):
134
+ start_row = 2 * idx_timestep # Which row of A corresponds to x-coordinate information at timestep k.
135
+
136
+ # Related to v_0, initial velocity - column 0.
137
+ # We fill in rows for measurements delta x_k, delta y_k.
138
+ A[start_row : (start_row + 2), 0] = np.array(
139
+ [
140
+ np.cos(heading) * discretization_time,
141
+ np.sin(heading) * discretization_time,
142
+ ],
143
+ dtype=np.float64,
144
+ )
145
+
146
+ if idx_timestep > 0:
147
+ # Related to {a_0, ..., a_k-1}, acceleration profile - column 1 to k.
148
+ # We fill in rows for measurements delta x_k, delta y_k.
149
+ A[start_row : (start_row + 2), 1 : (1 + idx_timestep)] = np.array(
150
+ [
151
+ [np.cos(heading) * discretization_time**2],
152
+ [np.sin(heading) * discretization_time**2],
153
+ ],
154
+ dtype=np.float64,
155
+ )
156
+
157
+ # Regularization using jerk penalty, i.e. difference of acceleration values.
158
+ # If there are M displacements, then we have M - 1 acceleration values.
159
+ # That means we have M - 2 jerk values, thus we make a banded difference matrix of that size.
160
+ banded_matrix = _make_banded_difference_matrix(num_displacements - 2)
161
+ R: DoubleMatrix = np.block([np.zeros((len(banded_matrix), 1)), banded_matrix])
162
+
163
+ # Compute regularized least squares solution.
164
+ x = np.linalg.pinv(A.T @ A + jerk_penalty * R.T @ R) @ A.T @ y
165
+
166
+ # Extract profile from solution.
167
+ initial_velocity = x[0]
168
+ acceleration_profile = x[1:]
169
+
170
+ return initial_velocity, acceleration_profile
171
+
172
+
173
+ def _generate_profile_from_initial_condition_and_derivatives(
174
+ initial_condition: float, derivatives: DoubleMatrix, discretization_time: float
175
+ ) -> DoubleMatrix:
176
+ """
177
+ Returns the corresponding profile (i.e. trajectory) given an initial condition and derivatives at
178
+ multiple timesteps by integration.
179
+ :param initial_condition: The value of the variable at the initial timestep.
180
+ :param derivatives: The trajectory of time derivatives of the variable at timesteps 0,..., N-1.
181
+ :param discretization_time: [s] Time discretization used for integration.
182
+ :return: The trajectory of the variable at timesteps 0,..., N.
183
+ """
184
+ assert discretization_time > 0.0, "Discretization time must be positive."
185
+
186
+ profile = initial_condition + np.insert(np.cumsum(derivatives * discretization_time), 0, 0.0)
187
+
188
+ return profile # type: ignore
189
+
190
+
191
+ def _fit_initial_curvature_and_curvature_rate_profile(
192
+ heading_displacements: DoubleMatrix,
193
+ velocity_profile: DoubleMatrix,
194
+ discretization_time: float,
195
+ curvature_rate_penalty: float,
196
+ initial_curvature_penalty: float = 1e-10,
197
+ ) -> Tuple[float, DoubleMatrix]:
198
+ """
199
+ Estimates initial curvature (curvature_0) and curvature rate ({curvature_rate_0, ...})
200
+ using least squares with curvature rate regularization.
201
+ :param heading_displacements: [rad] Angular deviations in heading occuring between timesteps.
202
+ :param velocity_profile: [m/s] Estimated or actual velocities at the timesteps matching displacements.
203
+ :param discretization_time: [s] Time discretization used for integration.
204
+ :param curvature_rate_penalty: A regularization parameter used to penalize curvature_rate. Should be positive.
205
+ :param initial_curvature_penalty: A regularization parameter to handle zero initial speed. Should be positive and small.
206
+ :return: Least squares solution for initial curvature (curvature_0) and curvature rate profile
207
+ (curvature_rate_0, ..., curvature_rate_{M-1}) for M heading displacement values.
208
+ """
209
+ assert discretization_time > 0.0, "Discretization time must be positive."
210
+ assert curvature_rate_penalty > 0.0, "Should have a positive curvature_rate_penalty."
211
+ assert initial_curvature_penalty > 0.0, "Should have a positive initial_curvature_penalty."
212
+
213
+ # Core problem: minimize_x ||y-Ax||_2
214
+     y = heading_displacements
+     A: DoubleMatrix = np.tri(len(y), dtype=np.float64)  # lower triangular matrix
+     A[:, 0] = velocity_profile * discretization_time
+
+     for idx, velocity in enumerate(velocity_profile):
+         if idx == 0:
+             continue
+         A[idx, 1:] *= velocity * discretization_time**2
+
+     # Regularization on curvature rate. We add a small but nonzero weight on initial curvature too.
+     # This is since the corresponding row of the A matrix might be zero if initial speed is 0, leading to singularity.
+     # We guarantee that Q is positive definite such that the minimizer of the least squares problem is unique.
+     Q: DoubleMatrix = curvature_rate_penalty * np.eye(len(y))
+     Q[0, 0] = initial_curvature_penalty
+
+     # Compute regularized least squares solution.
+     x = np.linalg.pinv(A.T @ A + Q) @ A.T @ y
+
+     # Extract profile from solution.
+     initial_curvature = x[0]
+     curvature_rate_profile = x[1:]
+
+     return initial_curvature, curvature_rate_profile
+
+
+ def get_velocity_curvature_profiles_with_derivatives_from_poses(
+     discretization_time: float,
+     poses: DoubleMatrix,
+     jerk_penalty: float,
+     curvature_rate_penalty: float,
+ ) -> Tuple[DoubleMatrix, DoubleMatrix, DoubleMatrix, DoubleMatrix]:
+     """
+     Main function for joint estimation of velocity, acceleration, curvature, and curvature rate given N poses
+     sampled at discretization_time. This is done by solving two least squares problems with the given penalty weights.
+     :param discretization_time: [s] Time discretization used for integration.
+     :param poses: <np.ndarray: num_poses, 3> A trajectory of N poses (x, y, heading).
+     :param jerk_penalty: A regularization parameter used to penalize acceleration differences. Should be positive.
+     :param curvature_rate_penalty: A regularization parameter used to penalize curvature_rate. Should be positive.
+     :return: Profiles for velocity (N-1), acceleration (N-2), curvature (N-1), and curvature rate (N-2).
+     """
+     xy_displacements, heading_displacements = _get_xy_heading_displacements_from_poses(poses)
+
+     # Compute initial velocity + acceleration least squares solution and extract results.
+     # Note: If we have M displacements, we require the M associated heading values.
+     # Therefore, we exclude the last heading in the call below.
+     initial_velocity, acceleration_profile = _fit_initial_velocity_and_acceleration_profile(
+         xy_displacements=xy_displacements,
+         heading_profile=poses[:-1, 2],
+         discretization_time=discretization_time,
+         jerk_penalty=jerk_penalty,
+     )
+
+     velocity_profile = _generate_profile_from_initial_condition_and_derivatives(
+         initial_condition=initial_velocity,
+         derivatives=acceleration_profile,
+         discretization_time=discretization_time,
+     )
+
+     # Compute initial curvature + curvature rate least squares solution and extract results. It relies on velocity fit.
+     initial_curvature, curvature_rate_profile = _fit_initial_curvature_and_curvature_rate_profile(
+         heading_displacements=heading_displacements,
+         velocity_profile=velocity_profile,
+         discretization_time=discretization_time,
+         curvature_rate_penalty=curvature_rate_penalty,
+     )
+
+     curvature_profile = _generate_profile_from_initial_condition_and_derivatives(
+         initial_condition=initial_curvature,
+         derivatives=curvature_rate_profile,
+         discretization_time=discretization_time,
+     )
+
+     return velocity_profile, acceleration_profile, curvature_profile, curvature_rate_profile
+
+
+ def complete_kinematic_state_and_inputs_from_poses(
+     discretization_time: float,
+     wheel_base: float,
+     poses: DoubleMatrix,
+     jerk_penalty: float,
+     curvature_rate_penalty: float,
+ ) -> Tuple[DoubleMatrix, DoubleMatrix]:
+     """
+     Main function for joint estimation of velocity, acceleration, steering angle, and steering rate given poses
+     sampled at discretization_time and the vehicle wheelbase parameter for curvature -> steering angle conversion.
+     One caveat is that we can only determine the first N-1 kinematic states and N-2 kinematic inputs given
+     N-1 displacement/difference values, so we need to extrapolate to match the length of poses provided.
+     This is handled by repeating the last input and extrapolating the motion model for the last state.
+     :param discretization_time: [s] Time discretization used for integration.
+     :param wheel_base: [m] The wheelbase length for the kinematic bicycle model being used.
+     :param poses: <np.ndarray: num_poses, 3> A trajectory of poses (x, y, heading).
+     :param jerk_penalty: A regularization parameter used to penalize acceleration differences. Should be positive.
+     :param curvature_rate_penalty: A regularization parameter used to penalize curvature_rate. Should be positive.
+     :return: kinematic_states (x, y, heading, velocity, steering_angle) and corresponding
+              kinematic_inputs (acceleration, steering_rate).
+     """
+     (
+         velocity_profile,
+         acceleration_profile,
+         curvature_profile,
+         curvature_rate_profile,
+     ) = get_velocity_curvature_profiles_with_derivatives_from_poses(
+         discretization_time=discretization_time,
+         poses=poses,
+         jerk_penalty=jerk_penalty,
+         curvature_rate_penalty=curvature_rate_penalty,
+     )
+
+     # Convert to steering angle given the wheelbase parameter. At this point, we don't need to worry about curvature.
+     steering_angle_profile, steering_rate_profile = _convert_curvature_profile_to_steering_profile(
+         curvature_profile=curvature_profile,
+         discretization_time=discretization_time,
+         wheel_base=wheel_base,
+     )
+
+     # Extend input fits with a repeated element and extrapolate state fits to match length of poses.
+     # This is since we fit with N-1 displacements but still have N poses at the end to deal with.
+     acceleration_profile = np.append(acceleration_profile, acceleration_profile[-1])
+     steering_rate_profile = np.append(steering_rate_profile, steering_rate_profile[-1])
+
+     velocity_profile = np.append(
+         velocity_profile, velocity_profile[-1] + acceleration_profile[-1] * discretization_time
+     )
+     steering_angle_profile = np.append(
+         steering_angle_profile, steering_angle_profile[-1] + steering_rate_profile[-1] * discretization_time
+     )
+
+     # Collect completed state and input in matrices.
+     kinematic_states: DoubleMatrix = np.column_stack((poses, velocity_profile, steering_angle_profile))
+     kinematic_inputs: DoubleMatrix = np.column_stack((acceleration_profile, steering_rate_profile))
+
+     return kinematic_states, kinematic_inputs
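The three functions above chain two regularized least squares fits of the form x = pinv(A^T A + Q) A^T y and then integrate the recovered derivatives. A minimal usage sketch (not part of the commit; the import path and the straight-line poses, wheelbase, and penalty values are illustrative assumptions):

import numpy as np
from sim.ilqr.utils import complete_kinematic_state_and_inputs_from_poses  # assumed import path

# 10 poses on a straight line, 0.5 m apart, sampled at 0.1 s -> expected velocity ~5 m/s
poses = np.stack([np.arange(10) * 0.5, np.zeros(10), np.zeros(10)], axis=1)
states, inputs = complete_kinematic_state_and_inputs_from_poses(
    discretization_time=0.1,
    wheel_base=2.7,                # illustrative wheelbase [m]
    poses=poses,
    jerk_penalty=1e-4,
    curvature_rate_penalty=1e-2,
)
# states: (10, 5) columns x, y, heading, velocity, steering_angle
# inputs: (10, 2) columns acceleration, steering_rate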
code/sim/pyproject.toml ADDED
@@ -0,0 +1,9 @@
+ [tool.hatch.build.targets.wheel]
+ packages = ["hugsim-env"]
+
+ [project]
+ name = "hugsim-env"
+ version = "0.0.1"
+ dependencies = [
+     "gymnasium",
+ ]
code/sim/setup.py ADDED
@@ -0,0 +1,7 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="hugsim-env",
+     version="0.0.1",
+     packages=find_packages(),
+ )
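Together, these packaging files let the simulator's gym wrapper be installed in editable mode (e.g. pip install -e code/sim). A hedged usage sketch; the environment id below is hypothetical, so check hugsim_env/__init__.py for the name actually registered:

import gymnasium as gym
import hugsim_env  # importing the package registers the environment

env = gym.make("hugsim_env/HUGSim-v0")  # hypothetical env id
obs, info = env.reset()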
code/sim/utils/__pycache__/agent_controller.cpython-311.pyc ADDED
Binary file (16.1 kB).
code/sim/utils/__pycache__/plan.cpython-311.pyc ADDED
Binary file (17.1 kB).
code/sim/utils/__pycache__/score_calculator.cpython-311.pyc ADDED
Binary file (32.5 kB).
code/sim/utils/__pycache__/sim_utils.cpython-311.pyc ADDED
Binary file (8.37 kB).
code/sim/utils/agent_controller.py ADDED
@@ -0,0 +1,323 @@
+ import math
+ import random
+ import numpy as np
+ from trajdata.maps import VectorMap
+ from submodules.Pplan.Sampling.spline_planner import SplinePlanner
+ import torch
+ import time
+ from copy import deepcopy
+ from utils.dynamic_utils import unicycle
+
+
+ def constant_tracking(state, path, dt):
+     '''
+     Args:
+         state: current state of the vehicle, of size [x, y, yaw, speed]
+         path: the path to follow, of size (N, [x, y, yaw])
+         dt: time duration
+     '''
+
+     # find the nearest point in the path
+     dists = torch.norm(path[:, :2] - state[None, :2], dim=1)
+     nearest_index = torch.argmin(dists)
+
+     # find the target point
+     lookahead_distance = state[3] * dt
+     target = path[-1]
+     is_end = True
+     for i in range(nearest_index + 1, len(path)):
+         if torch.norm(path[i, :2] - state[:2]) > lookahead_distance:
+             target = path[i]
+             is_end = False
+             break
+
+     # compute the new state
+     target_distance = torch.norm(target[:2] - state[:2])
+     ratio = lookahead_distance / target_distance.clamp(min=1e-6)
+     ratio = ratio.clamp(max=1.0)
+
+     new_state = deepcopy(state)
+     new_state[:2] = state[:2] + ratio * (target[:2] - state[:2])
+     new_state[2] = torch.atan2(
+         state[2].sin() + ratio * (target[2].sin() - state[2].sin()),
+         state[2].cos() + ratio * (target[2].cos() - state[2].cos())
+     )
+     if is_end:
+         new_state[3] = 0
+
+     return new_state
+
+
+ def constant_headaway(states, num_steps, dt):
+     '''
+     Args:
+         states: current states of a batch of vehicles, of size (num_agents, [x, y, yaw, speed])
+         num_steps: number of steps to move forward
+         dt: time duration
+     Return:
+         trajs: the trajectories of the vehicles, of size (num_agents, num_steps, [x, y, yaw, speed])
+     '''
+
+     # state: [x, y, yaw, speed]
+     x = states[:, 0]
+     y = states[:, 1]
+     yaw = states[:, 2]
+     speed = states[:, 3]
+
+     # Generate time steps
+     t_steps = torch.arange(num_steps) * dt
+
+     # Calculate dx and dy for each step
+     dx = torch.outer(speed * torch.sin(yaw), t_steps)
+     dy = torch.outer(speed * torch.cos(yaw), t_steps)
+
+     # Update x and y positions
+     x_traj = x.unsqueeze(1) + dx
+     y_traj = y.unsqueeze(1) + dy
+
+     # Replicate the yaw and speed for each time step
+     yaw_traj = yaw.unsqueeze(1).repeat(1, num_steps)
+     speed_traj = speed.unsqueeze(1).repeat(1, num_steps)
+
+     # Stack the x, y, yaw, and speed components to form the trajectory
+     trajs = torch.stack((x_traj, y_traj, yaw_traj, speed_traj), dim=-1)
+
+     return trajs
+
+
+ class IDM:
+     def __init__(
+         self, v0=30.0, s0=5.0, T=2.0, a=2.0, b=4.0, delta=4.0,
+         lookahead_path_length=100, lead_distance_threshold=1.0
+     ):
+         '''
+         Args:
+             v0: desired speed
+             s0: minimum gap
+             T: safe time headway
+             a: max acceleration
+             b: comfortable deceleration
+             delta: acceleration exponent
+             lookahead_path_length: the length of path to look ahead
+             lead_distance_threshold: the distance to consider a vehicle as a lead vehicle
+         '''
+         self.v0 = v0
+         self.s0 = s0
+         self.T = T
+         self.a = a
+         self.b = b
+         self.delta = delta
+         self.lookahead_path_length = lookahead_path_length
+         self.lead_distance_threshold = lead_distance_threshold
+
+     def update(self, state, path, dt, neighbors):
+         '''
+         Args:
+             state: current state of the vehicle, of size [x, y, yaw, speed]
+             path: the path to follow, of size (N, [x, y, yaw])
+             dt: time duration
+             neighbors: the future states of the neighbors, of size (K, T, [x, y, yaw, speed])
+         '''
+
+         if path is None:
+             return deepcopy(state)
+
+         # find the nearest point in the path
+         dists = torch.norm(path[:, :2] - state[None, :2], dim=1)
+         nearest_index = torch.argmin(dists)
+
+         # lookahead_distance = state[3] * dt
+         # lookahead_target = state[:2] + np.array([np.sin(state[2]) * lookahead_distance, np.cos(state[2]) * lookahead_distance])
+         # # target = path[-1]
+         # is_end = False
+         # target_idx = torch.argmin(torch.norm(path[:, :2] - lookahead_target, dim=-1))
+         # target = path[target_idx]
+
+         # find the target point
+         lookahead_distance = state[3] * dt
+         target = path[-1]
+         is_end = True
+         for i in range(nearest_index + 1, len(path)):
+             if torch.norm(path[i, :2] - state[:2]) > lookahead_distance:
+                 target = path[i]
+                 is_end = False
+                 break
+
+         # distance between neighbors and the path
+         lookahead_path = path[nearest_index + 1:][:self.lookahead_path_length]
+         lookahead_neighbors = neighbors[..., None, :].expand(
+             -1, -1, lookahead_path.shape[0], -1
+         )  # (K, T, n, 4)
+
+         dists_neighbors = torch.norm(
+             lookahead_neighbors[..., :2] - lookahead_path[None, None, :, :2], dim=-1
+         )  # (K, T, n)
+         indices_neighbors = torch.arange(
+             lookahead_path.shape[0]
+         )[None, None].expand_as(dists_neighbors)
+
+         # determine lead vehicles
+         is_lead = (dists_neighbors < self.lead_distance_threshold)
+         if is_lead.any():
+             # compute lead distance
+             indices_lead = indices_neighbors[is_lead]  # (num_lead)
+             lookahead_lengths = torch.cumsum(torch.norm(
+                 lookahead_path[1:, :2] - lookahead_path[:-1, :2], dim=1
+             ), dim=0)
+             lookahead_lengths = torch.cat([lookahead_lengths, lookahead_lengths[-1:]])
+             lead_distance = lookahead_lengths[indices_lead]
+
+             # compute lead speed
+             states_lead = lookahead_neighbors[is_lead]  # (num_lead, 4)
+             ori_speed_lead = states_lead[:, 3]
+             yaw_lead = states_lead[:, 2]
+             yaw_path = lookahead_path[indices_lead, 2]
+             lead_speed = ori_speed_lead * (yaw_lead - yaw_path).cos()
+
+             # compute acceleration
+             ego_speed = state[3]
+             delta_v = ego_speed - lead_speed
+             s_star = self.s0 + \
+                 (ego_speed * self.T + ego_speed * delta_v / (2 * math.sqrt(self.a * self.b))).clamp(min=0)
+             acceleration = self.a * (1 - (ego_speed / self.v0) ** self.delta - (s_star / lead_distance) ** 2)
+             acceleration = acceleration.min()
+         else:
+             acceleration = self.a * (1 - (state[3] / self.v0) ** self.delta)
+
+         # compute the new state
+         target_distance = torch.norm(target[:2] - state[:2])
+         ratio = lookahead_distance / target_distance.clamp(min=1e-6)
+         ratio = ratio.clamp(max=1.0)
+
+         new_state = deepcopy(state)
+         new_state[:2] = state[:2] + ratio * (target[:2] - state[:2])
+         new_state[2] = torch.atan2(
+             state[2].sin() + ratio * (target[2].sin() - state[2].sin()),
+             state[2].cos() + ratio * (target[2].cos() - state[2].cos())
+         )
+         if is_end:
+             new_state[3] = 0
+         else:
+             new_state[3] = (state[3] + acceleration * dt).clamp(min=0)
+
+         return new_state
+
+
+ class AttackPlanner:
+     def __init__(self, pred_steps=20, ATTACK_FREQ=3, best_k=1, device='cpu'):
+         self.device = device
+         self.predict_steps = pred_steps
+         self.best_k = best_k
+
+         self.planner = SplinePlanner(
+             device,
+             N_seg=self.predict_steps,
+             acce_grid=torch.linspace(-2, 5, 10).to(self.device),
+             acce_bound=[-6, 5],
+             vbound=[-2, 50]
+         )
+         self.planner.psi_bound = [-math.pi * 2, math.pi * 2]
+
+         self.exec_traj = None
+         self.exec_pointer = 1
+
+     def update(
+         self, state, unified_map, dt,
+         neighbors, attacked_states,
+         new_plan=True
+     ):
+         '''
+         Args:
+             state: current state of the vehicle, of size [x, y, yaw, speed]
+             unified_map: the unified map
+             attacked_states: future states of the attacked agent, of size (T, [x, y, yaw, speed])
+             neighbors: future states of the neighbors, of size (K, T, [x, y, yaw, speed])
+             new_plan: whether to generate a new plan
+         '''
+         assert self.exec_pointer > 0
+
+         # directly execute the current plan
+         if not new_plan:
+             if self.exec_traj is not None and \
+                     self.exec_pointer < self.exec_traj.shape[0]:
+                 next_state = self.exec_traj[self.exec_pointer]
+                 self.exec_pointer += 1
+                 return next_state
+             else:
+                 new_plan = True
+
+         assert attacked_states.shape[0] == self.predict_steps
+
+         # state: [x, y, yaw, speed]
+         x, y, yaw, speed = state
+
+         # query vector map to get lanes
+         query_xyzr = np.array([x, y, 0, yaw + np.pi / 2])
+         # query_xyzr = unified_map.xyzr_local2world(np.array([x, y, 0, yaw]))
+         # lanes = unified_map.vector_map.get_lanes_within(query_xyzr[:3], dist=30)
+         # lanes = [unified_map.batch_xyzr_world2local(l.center.xyzh)[:, [0,1,3]] for l in lanes]
+         # lanes = [l.center.xyzh[:, [0,1,3]] for l in lanes]
+         lanes = None
+
+         # for lane in lanes:
+         #     plt.plot(lane[:, 0], lane[:, 1], 'k--', linewidth=0.5, alpha=0.5)
+
+         # generate spline trajectories
+         x0 = torch.tensor([query_xyzr[0], query_xyzr[1], speed, query_xyzr[3]], device=self.device)
+         possible_trajs, xf_set = self.planner.gen_trajectories(
+             x0, self.predict_steps * dt, lanes, dyn_filter=True
+         )  # (num_trajs, T-1, [x, y, v, a, yaw, r, t])
+         if possible_trajs.shape[0] == 0:
+             trajs = constant_headaway(state[None], self.predict_steps, dt)  # (1, T, [x, y, yaw, speed])
+         else:
+             trajs = torch.cat([
+                 state[None, None].expand(possible_trajs.shape[0], -1, -1),
+                 possible_trajs[..., [0, 1, 4, 2]]
+             ], dim=1)
+
+         # select the best trajectory
+         attack_distance = torch.norm(attacked_states[None, :, :2] - trajs[..., :2], dim=-1)
+         cost_attack = attack_distance.min(dim=1).values
+         cost_collision = (
+             torch.norm(neighbors[None, ..., :2] - trajs[:, None, :, :2], dim=-1).min(dim=-1).values < 2.0
+         ).sum(dim=-1)
+         cost = cost_attack + 0.1 * cost_collision
+         values, indices = torch.topk(cost, self.best_k, largest=False)
+         random_index = torch.randint(0, self.best_k, (1,)).item()
+         selected_index = indices[random_index]
+         traj_best = trajs[selected_index]
+
+         # produce next state
+         self.exec_traj = traj_best
+         self.exec_traj[:, 2] -= np.pi / 2
+         self.exec_pointer = 1
+         next_state = self.exec_traj[self.exec_pointer]
+         # next_state[0] = -next_state[0]
+         self.exec_pointer += 1
+
+         return next_state
+
+
+ class ConstantPlanner:
+     def __init__(self):
+         return
+
+     def update(self, state, dt):
+         a, b, yaw, v = state
+         a = a - v * np.sin(yaw) * dt
+         b = b + v * np.cos(yaw) * dt
+         return torch.tensor([a, b, yaw, v])
+
+
+ class UnicyclePlanner:
+     def __init__(self, uc_path, speed=1.0):
+         self.uc_model = unicycle.restore(torch.load(uc_path, weights_only=False))
+         self.t = 0
+         self.speed = speed
+
+     def update(self, dt):
+         self.t += dt * self.speed
+         a, b, v, pitchroll, yaw, h = self.uc_model.forward(self.t)
+         # return torch.tensor([a, b, yaw, v]), pitchroll.detach().cpu(), h.item()
+         return torch.tensor([a, b, yaw, v])
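For reference, the IDM update above accelerates at a * (1 - (v/v0)^delta) on a free road and additionally subtracts (s_star/s)^2 against the nearest lead vehicle projected onto the path. A minimal sketch of stepping the controller standalone (not part of the commit; the import path follows this repo's layout and all numbers are illustrative):

import torch
from sim.utils.agent_controller import IDM  # assumed import path

idm = IDM(v0=10.0)
state = torch.tensor([0.0, 0.0, 0.0, 5.0])          # x, y, yaw, speed
path = torch.zeros(50, 3)
path[:, 1] = torch.arange(50, dtype=torch.float32)  # straight path along +y (yaw=0: dx ~ sin, dy ~ cos)
neighbors = torch.full((3, 20, 4), 1e6)             # all neighbors far away -> no lead vehicle
for _ in range(10):
    state = idm.update(state, path, dt=0.2, neighbors=neighbors)
print(state[3])  # speed climbs from 5 toward v0 = 10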
code/sim/utils/launch_ad.py ADDED
@@ -0,0 +1,30 @@
+ import subprocess
+ import time
+ import os
+
+
+ def launch(shell_path, cuda_id, output):
+     os.makedirs(output, exist_ok=True)
+     print(os.path.join(output, 'output.txt'))
+     print(shell_path, cuda_id, output)
+     with open(os.path.join(output, 'output.txt'), 'w') as f:
+         process = subprocess.Popen(
+             ["zsh", shell_path, cuda_id, output], stdout=f, stderr=f
+         )
+     return process
+
+
+ def check_alive(process, tolerant=100):
+     i = 0
+     while i < tolerant:
+         return_code = process.poll()
+         if return_code is not None:
+             print(f"The AD algorithm completed with return code {return_code}.")
+             process.kill()
+             return
+         elif i % 5 == 0:
+             print(f"The AD algorithm is still running, remaining tolerance {tolerant - i}.")
+         time.sleep(1)
+         i += 1
+     process.kill()
+     print("The AD algorithm process is killed.")
code/sim/utils/plan.py ADDED
@@ -0,0 +1,238 @@
+ import numpy as np
+ import torch
+ from scipy.spatial.transform import Rotation as SCR
+ import roma
+ from collections import namedtuple
+ from sim.utils.agent_controller import constant_headaway
+ from sim.utils import agent_controller
+ from collections import defaultdict
+ from trajdata import AgentType, UnifiedDataset
+ from trajdata.maps import MapAPI
+ from trajdata.simulation import SimulationScene
+ from sim.utils.sim_utils import rt2pose, pose2rt
+ from sim.utils.agent_controller import IDM, AttackPlanner, ConstantPlanner, UnicyclePlanner
+ import os
+ import json
+
+ Model = namedtuple('Model', ['model_path', 'controller', 'controller_args'])
+
+
+ class planner:
+     def __init__(self, plan_list, scene_path=None, dt=0.2, unified_map=None, ground=None):
+         self.unified_map = unified_map
+         self.ground = ground
+         self.PREDICT_STEPS = 20
+         self.NUM_NEIGHBORS = 3
+
+         self.rectify_angle = 0
+         if self.unified_map is not None:
+             self.rectify_angle = self.unified_map.rectify_angle
+
+         # plan_list: a, b, height, yaw, v, model_path, controller, controller_args: dict
+         self.stats, self.route, self.controller, self.ckpts, self.wlhs = {}, {}, {}, {}, {}
+         self.dt = dt
+         self.ATTACK_FREQ = 3
+         for iid, args in enumerate(plan_list):
+             if args[6] == "UnicyclePlanner":
+                 # self.ckpts[f"agent_{iid}"] = os.path.join(scene_path, "ckpts", f"dynamic_{args[7]}_chkpnt30000.pth")
+                 # self.wlhs[f'agent_{iid}'] = [2.0, 4.0, 1.5]
+                 self.ckpts[f'agent_{iid}'] = os.path.join(args[5], 'gs.pth')
+                 with open(os.path.join(args[5], 'wlh.json')) as f:
+                     self.wlhs[f'agent_{iid}'] = json.load(f)
+                 uc_configs = args[7]
+                 self.controller[f"agent_{iid}"] = UnicyclePlanner(
+                     os.path.join(scene_path, f"unicycle_{uc_configs['uc_id']}.pth"), speed=uc_configs['speed']
+                 )
+                 a, b, v, pitchroll, yaw, h = self.controller[f"agent_{iid}"].uc_model.forward(0.0)
+                 self.stats[f'agent_{iid}'] = torch.tensor([a, b, args[2], yaw, v])
+                 self.route[f'agent_{iid}'] = None
+             else:
+                 model = Model(*args[5:])
+                 self.stats[f'agent_{iid}'] = torch.tensor(args[:5])  # a, b, height, yaw, v
+                 self.stats[f'agent_{iid}'][3] += self.rectify_angle
+                 self.route[f'agent_{iid}'] = None
+                 self.ckpts[f'agent_{iid}'] = os.path.join(model.model_path, 'gs.pth')
+                 with open(os.path.join(model.model_path, 'wlh.json')) as f:
+                     self.wlhs[f'agent_{iid}'] = json.load(f)
+                 self.controller[f'agent_{iid}'] = getattr(agent_controller, model.controller)(**model.controller_args)
+                 if model.controller == "AttackPlanner":
+                     self.ATTACK_FREQ = model.controller_args["ATTACK_FREQ"]
+
+     def update_ground(self, ground):
+         self.ground = ground
+
+     def update_agent_route(self):
+         assert self.unified_map is not None, "Map shouldn't be None to forecast agent path"
+         for iid, stat in self.stats.items():
+             path = self.unified_map.get_route(stat)
+             if path is None:
+                 print("path not found at ", self.stats)
+             else:
+                 self.route[iid] = torch.from_numpy(np.hstack([path[:, :2], path[:, -1:]]))
+
+     def ground_height(self, u, v):
+         cam_poses, cam_height, _ = self.ground
+         cam_poses = torch.from_numpy(cam_poses)
+         cam_dist = np.sqrt(
+             (cam_poses[:-1, 0, 3] - u) ** 2 + (cam_poses[:-1, 2, 3] - v) ** 2
+         )
+         nearest_cam_idx = np.argmin(cam_dist, axis=0)
+         nearest_c2w = cam_poses[nearest_cam_idx]
+
+         nearest_w2c = np.linalg.inv(nearest_c2w)
+         uv_local = nearest_w2c[:3, :3] @ np.array([u, 0, v]) + nearest_w2c[:3, 3]
+         uv_local[1] = 0
+         uv_world = nearest_c2w[:3, :3] @ uv_local + nearest_c2w[:3, 3]
+
+         return uv_world[1] + cam_height
+
+     def plan_traj(self, t, ego_stats):
+         all_stats = [ego_stats]
+         for iid, stat in self.stats.items():
+             all_stats.append(stat[[0, 1, 3, 4]])  # a, b, yaw, v
+         all_stats = torch.stack(all_stats, dim=0)
+         future_states = constant_headaway(all_stats, num_steps=self.PREDICT_STEPS, dt=self.dt)
+
+         b2ws = {}
+         for iid, stat in self.stats.items():
+             # find the closest neighbors
+             curr_xy_agents = all_stats[:, :2]
+             distance_agents = torch.norm(curr_xy_agents - stat[:2], dim=-1)
+             neighbor_idx = torch.argsort(distance_agents)[1:self.NUM_NEIGHBORS + 1]
+             neighbors = future_states[neighbor_idx]
+
+             controller = self.controller[iid]
+             if type(controller) is IDM:
+                 next_xyrv = controller.update(state=stat[[0, 1, 3, 4]], path=self.route[iid], dt=self.dt,
+                                               neighbors=neighbors)
+             elif type(controller) is AttackPlanner:
+                 safe_neighbors = neighbors[1:, ...]
+                 next_xyrv = controller.update(state=stat[[0, 1, 3, 4]], unified_map=self.unified_map, dt=0.1,
+                                               neighbors=safe_neighbors, attacked_states=future_states[0],
+                                               new_plan=((t // self.dt) % self.ATTACK_FREQ == 0))
+             elif type(controller) is ConstantPlanner:
+                 next_xyrv = controller.update(state=stat[[0, 1, 3, 4]], dt=self.dt)
+             elif type(controller) is UnicyclePlanner:
+                 next_xyrv = controller.update(dt=self.dt)
+             else:
+                 raise NotImplementedError
+             next_stat = torch.zeros_like(stat)
+             next_stat[[0, 1, 3, 4]] = next_xyrv.float()
+             next_stat[2] = stat[2]
+             self.stats[iid] = next_stat
+
+             b2w = np.eye(4)
+             h = self.ground_height(next_xyrv[0].numpy(), next_xyrv[1].numpy())
+             if type(controller) is UnicyclePlanner:
+                 # b2w[:3, :3] = SCR.from_euler('xzy', [pitch_roll[0], pitch_roll[1], stat[3]]).as_matrix()
+                 b2w[:3, :3] = SCR.from_euler('y', [-stat[3]]).as_matrix()
+                 b2w[:3, 3] = np.array([next_stat[0], h + stat[2], next_stat[1]])
+             else:
+                 b2w[:3, :3] = SCR.from_euler('y', [-stat[3] - np.pi / 2 - self.rectify_angle]).as_matrix()
+                 b2w[:3, 3] = np.array([next_stat[0], h + stat[2], next_stat[1]])
+             b2ws[iid] = torch.tensor(b2w).float().cuda()
+
+         return [b2ws, {}]
+
+
+ class UnifiedMap:
+     def __init__(self, datapath, version, scene_name):
+         self.datapath = datapath
+         self.version = version
+
+         self.dataset = UnifiedDataset(
+             desired_data=[self.version],
+             data_dirs={
+                 self.version: self.datapath,
+             },
+             cache_location="/app/app_datas/nusc_map_cache",
+             only_types=[AgentType.VEHICLE],
+             agent_interaction_distances=defaultdict(lambda: 50.0),
+             desired_dt=0.1,
+             num_workers=4,
+             verbose=True,
+         )
+
+         self.map_api = MapAPI(self.dataset.cache_path)
+
+         self.scene = None
+         for scene in list(self.dataset.scenes()):
+             if scene.name == scene_name:
+                 self.scene = scene
+         assert self.scene is not None, f"Can't find scene {scene_name}"
+         self.vector_map = self.map_api.get_map(
+             f"{self.version}:{self.scene.location}"
+         )
+         self.ego_start_pos, self.ego_start_yaw = self.get_start_pose()
+         self.rectify_angle = 0
+         if self.ego_start_yaw < 0:
+             self.ego_start_yaw += np.pi
+             self.rectify_angle = np.pi
+         self.PATH_LENGTH = 100
+
+     def get_start_pose(self):
+         sim_scene: SimulationScene = SimulationScene(
+             env_name=self.version,
+             scene_name="sim_scene",
+             scene=self.scene,
+             dataset=self.dataset,
+             init_timestep=0,
+             freeze_agents=True,
+         )
+         obs = sim_scene.reset()
+         assert obs.agent_name[0] == 'ego', 'The first agent is not ego'
+         # We take the position of the first ego frame as the origin.
+         # This assumption holds when the first-frame front camera pose is set as the origin.
+         ego_start_pos = obs.curr_agent_state.position[0]
+         ego_start_yaw = obs.curr_agent_state.heading[0]
+         return ego_start_pos.numpy(), ego_start_yaw.item()
+
+     def xyzr_local2world(self, stat):
+         alpha = np.arctan(stat[0] / stat[1])
+         beta = self.ego_start_yaw - alpha
+         dist = np.linalg.norm(stat[:2])
+         delta_x = dist * np.cos(beta)
+         delta_y = dist * np.sin(beta)
+
+         world_stat = np.zeros(4)
+         world_stat[0] = delta_x + self.ego_start_pos[0]
+         world_stat[1] = delta_y + self.ego_start_pos[1]
+         world_stat[3] = stat[3] + self.ego_start_yaw
+
+         return world_stat
+
+     def batch_xyzr_world2local(self, stat):
+         beta = np.arctan((stat[:, 1] - self.ego_start_pos[1]) / (stat[:, 0] - self.ego_start_pos[0]))
+         alpha = self.ego_start_yaw - beta
+         dist = np.linalg.norm(stat[:, :2] - self.ego_start_pos, axis=1)
+         delta_x = dist * np.sin(alpha)
+         delta_y = dist * np.cos(alpha)
+
+         local_stat = np.zeros_like(stat)
+         local_stat[:, 0] = delta_x
+         local_stat[:, 1] = delta_y
+         local_stat[:, 3] = stat[:, 3] - self.ego_start_yaw
+
+         return local_stat
+
+     def get_route(self, stat):
+         # stat: a, b, height, yaw, v
+         curr_xyzr = self.xyzr_local2world(stat[:4].numpy())
+
+         # lanes = self.vector_map.get_current_lane(curr_xyzr, max_dist=5, max_heading_error=np.pi/3)
+         lanes = self.vector_map.get_current_lane(curr_xyzr)
+
+         if len(lanes) > 0:
+             curr_lane = lanes[0]
+             path = self.batch_xyzr_world2local(curr_lane.center.xyzh)
+             total_path_length = np.linalg.norm(curr_lane.center.xy[1:] - curr_lane.center.xy[:-1], axis=1).sum()
+             # randomly select next lanes until reaching PATH_LENGTH
+             while total_path_length < self.PATH_LENGTH:
+                 next_lanes = list(curr_lane.next_lanes)
+                 if len(next_lanes) == 0:
+                     break
+                 next_lane = self.vector_map.get_road_lane(next_lanes[np.random.randint(len(next_lanes))])
+                 path = np.vstack([path, self.batch_xyzr_world2local(next_lane.center.xyzh)])
+                 total_path_length += np.linalg.norm(next_lane.center.xy[1:] - next_lane.center.xy[:-1], axis=1).sum()
+                 curr_lane = next_lane
+         else:
+             path = None
+         return path
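For reference, each plan_list entry packs (a, b, height, yaw, v, model_path, controller, controller_args), and model_path is expected to contain gs.pth and wlh.json. A hedged construction sketch (paths are placeholders; the map and ground can be attached afterwards):

from sim.utils.plan import planner  # assumed import path

plan_list = [
    # a, b, height, yaw, v, model_path (placeholder), controller class name, controller kwargs
    [10.0, 2.0, 0.0, 0.0, 5.0, "/path/to/agent_model", "IDM", {"v0": 10.0}],
]
agents = planner(plan_list, dt=0.2)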
code/sim/utils/score_calculator.py ADDED
@@ -0,0 +1,562 @@
+ import pickle
+ from matplotlib import pyplot as plt
+ from matplotlib.patches import Rectangle, Polygon
+ from shapely.geometry import LineString, Point
+ import numpy as np
+ from shapely.geometry import Polygon as ShapelyPolygon
+ from collections import defaultdict
+ from concurrent.futures import ThreadPoolExecutor
+ import threading
+ import argparse
+ import os
+ import open3d as o3d
+ import torch
+ from scipy.spatial.transform import Rotation as SCR
+
+ ego_verts_canonic = np.array([[0.5, 0.5, 0], [0.5, -0.5, 0], [0.5, 0.5, 1.0], [0.5, -0.5, 1.0],
+                               [-0.5, -0.5, 0], [-0.5, 0.5, 0], [-0.5, -0.5, 1.0], [-0.5, 0.5, 1.0]])
+
+ # Define boundaries
+ boundaries = {
+     'max_abs_lat_accel': 4.89,   # [m/s^2]
+     'max_lon_accel': 2.40,       # [m/s^2]
+     'min_lon_accel': -4.05,      # [m/s^2]
+     'max_abs_yaw_accel': 1.93,   # [rad/s^2]
+     'max_abs_lon_jerk': 8.37,    # [m/s^3]
+     'max_abs_yaw_rate': 0.95,    # [rad/s]
+ }
+
+ score_weight = {
+     'ttc': 5,
+     'c': 2,
+     'ep': 5,
+ }
+
+
+ def create_rectangle(center_x, center_y, width, length, yaw):
+     """Create a rectangle polygon."""
+     cos_yaw = np.cos(yaw)
+     sin_yaw = np.sin(yaw)
+
+     x_offs = [length/2, length/2, -length/2, -length/2]
+     y_offs = [width/2, -width/2, -width/2, width/2]
+
+     x_pts = [center_x + x_off*cos_yaw - y_off*sin_yaw for x_off, y_off in zip(x_offs, y_offs)]
+     y_pts = [center_y + x_off*sin_yaw + y_off*cos_yaw for x_off, y_off in zip(x_offs, y_offs)]
+
+     return ShapelyPolygon(zip(x_pts, y_pts))
+
+
+ def bg_collision_det(points, box):
+     O, A, B, C = box[0], box[1], box[2], box[5]
+     OA = A - O
+     OB = B - O
+     OC = C - O
+     POA, POB, POC = (points @ OA[..., None])[:, 0], (points @ OB[..., None])[:, 0], (points @ OC[..., None])[:, 0]
+     mask = (torch.dot(O, OA) < POA) & (POA < torch.dot(A, OA)) & \
+            (torch.dot(O, OB) < POB) & (POB < torch.dot(B, OB)) & \
+            (torch.dot(O, OC) < POC) & (POC < torch.dot(C, OC))
+     return True if torch.sum(mask) > 100 else False
+
+
+ class ScoreCalculator:
+     def __init__(self, data):
+         self.data = data
+
+         self.pdms = 0.0
+         self.driving_score = None
+
+     def transform_to_ego_frame(self, traj, ego_box):
+         """
+         Transform trajectory from global frame to ego-centric frame.
+
+         :param traj: List of tuples (x, y, yaw) in global frame
+         :param ego_box: Tuple (x, y, z, w, l, h, yaw) of ego vehicle in global frame
+         :return: Numpy array of transformed trajectory
+         """
+         ego_x, ego_y, _, _, _, _, ego_yaw = ego_box
+
+         # Create rotation matrix
+         c, s = np.cos(-ego_yaw), np.sin(-ego_yaw)
+         R = np.array([[c, -s], [s, c]])
+
+         # Transform each point
+         transformed_traj = []
+         for x, y, yaw in traj:
+             # Translate
+             x_translated, y_translated = x - ego_x, y - ego_y
+
+             # Rotate
+             x_rotated, y_rotated = R @ np.array([x_translated, y_translated])
+
+             # Adjust yaw
+             yaw_adjusted = yaw - ego_yaw
+
+             transformed_traj.append((x_rotated, y_rotated, yaw_adjusted))
+
+         return np.array(transformed_traj)
+
+     def get_vehicle_corners(self, x, y, yaw, length, width):
+         """
+         Calculate the corner points of the vehicle given its position, orientation, and dimensions.
+
+         :param x: x-coordinate of the vehicle's center
+         :param y: y-coordinate of the vehicle's center
+         :param yaw: orientation of the vehicle in radians
+         :param length: length of the vehicle
+         :param width: width of the vehicle
+         :return: numpy array of corner coordinates (4x2)
+         """
+         c, s = np.cos(yaw), np.sin(yaw)
+         front_left = np.array([x + c * length / 2 - s * width / 2,
+                                y + s * length / 2 + c * width / 2])
+         front_right = np.array([x + c * length / 2 + s * width / 2,
+                                 y + s * length / 2 - c * width / 2])
+         rear_left = np.array([x - c * length / 2 - s * width / 2,
+                               y - s * length / 2 + c * width / 2])
+         rear_right = np.array([x - c * length / 2 + s * width / 2,
+                                y - s * length / 2 - c * width / 2])
+         return np.array([front_left, front_right, rear_right, rear_left])
+
+     def plot_trajectory_on_drivable_mask(self, drivable_mask, transformed_traj, vehicle_width, vehicle_length):
+         """
+         Plot the transformed trajectory and vehicle bounding boxes on the drivable mask.
+
+         :param drivable_mask: 2D numpy array representing the drivable area (200x200)
+         :param transformed_traj: Numpy array of transformed trajectory points
+         :param vehicle_width: Width of the vehicle in meters
+         :param vehicle_length: Length of the vehicle in meters
+         """
+         plt.figure(figsize=(10, 10))
+         plt.imshow(drivable_mask, cmap='gray', extent=[-50, 50, -50, 50])
+
+         # Scale factor (200 pixels represent 100 meters)
+         scale_factor = 200 / 100  # pixels per meter
+
+         # Plot trajectory
+         x_coords, y_coords, yaws = transformed_traj.T
+         plt.plot(x_coords, y_coords, 'r-', linewidth=2)
+
+         # Plot vehicle bounding boxes
+         for x, y, yaw in transformed_traj:
+             corners = self.get_vehicle_corners(
+                 x, y, yaw, vehicle_length, vehicle_width)
+             plt.gca().add_patch(Polygon(corners, fill=False, edgecolor='blue'))
+
+         # Plot start and end points
+         plt.plot(x_coords[0], y_coords[0], 'go', markersize=10, label='Start')
+         plt.plot(x_coords[-1], y_coords[-1], 'bo', markersize=10, label='End')
+
+         plt.title('Trajectory and Vehicle Bounding Boxes on Drivable Mask')
+         plt.legend()
+         plt.xlabel('x (meters)')
+         plt.ylabel('y (meters)')
+         plt.grid(True)
+         plt.tight_layout()
+         plt.show()
+
+     def _calculate_drivable_area_compliance(self, ground, traj, vehicle_width, vehicle_length):
+         m, n = 2, 2
+         dac = 1.0
+         for traj_i, (x, y, yaw) in enumerate(traj):
+             cnt = 0
+             c, s = np.cos(yaw), np.sin(yaw)
+             R = np.array([[c, -s], [s, c]])
+             ground_in_ego = (np.linalg.inv(R) @ (ground + np.array([-x, -y])).T).T
+             x_bins = np.linspace(-vehicle_length/2, vehicle_length/2, m+1)
+             y_bins = np.linspace(-vehicle_width/2, vehicle_width/2, n+1)
+             for xi in range(m):
+                 for yi in range(n):
+                     min_x, max_x = x_bins[xi], x_bins[xi+1]
+                     min_y, max_y = y_bins[yi], y_bins[yi+1]
+                     ground_mask = (min_x < ground_in_ego[:, 0]) & (ground_in_ego[:, 0] < max_x) & \
+                                   (min_y < ground_in_ego[:, 1]) & (ground_in_ego[:, 1] < max_y)
+                     if ground_mask.sum() > 0:
+                         cnt += 1
+             drivable_ratio = cnt / (m*n)
+             if drivable_ratio < 0.3:
+                 return 0
+             elif drivable_ratio < 0.5:
+                 dac = 0.5
+         return dac
+
+     def _calculate_progress(self, planned_traj, ref_traj):
+         def calculate_curve_length(points):
+             """Calculate the total length of a curve given by a set of points."""
+             curve = LineString(points)
+             return curve.length
+
+         def project_curve_onto_curve(curve_a, curve_b):
+             """Project curve_b onto curve_a and calculate the projected length."""
+             projected_points = []
+             for point in curve_b.coords:
+                 projected_point = curve_a.interpolate(
+                     curve_a.project(Point(point)))
+                 projected_points.append(projected_point)
+             projected_curve = LineString(projected_points)
+             return projected_curve.length
+
+         # Create Shapely LineString objects
+         plan_curve = LineString([(x, y) for x, y, _ in planned_traj])
+         ref_curve = LineString([(x, y) for x, y, _ in ref_traj])
+
+         # Calculate lengths
+         plan_curve_length = calculate_curve_length(plan_curve)
+         ref_curve_length = calculate_curve_length(ref_curve)
+         projected_length = project_curve_onto_curve(ref_curve, plan_curve)
+         # print(f"plan_curve_length: {plan_curve_length}, ref_curve_length: {ref_curve_length}, project plan to ref_length: {projected_length}")
+
+         ep = 0.0
+         if max(plan_curve_length, ref_curve_length) < 5.0 or ref_curve_length < 1e-6:
+             ep = 1.0
+         else:
+             ep = projected_length / ref_curve_length
+         return ep
+
+     def _calculate_is_comfortable(self, traj, timestep):
+         """
+         Check if all kinematic parameters of a trajectory are within specified boundaries.
+
+         :param traj: List of tuples (x, y, yaw) representing the trajectory, in ego's local frame
+         :param timestep: Time interval between trajectory points in seconds
+         :return: 1.0 if all parameters are within boundaries, 0.0 otherwise
+         """
+
+         def calculate_trajectory_kinematics(traj, timestep):
+             """
+             Calculate kinematic parameters for a given trajectory.
+
+             :param traj: List of tuples (x, y, yaw) for each point in the trajectory
+             :param timestep: Time interval between each point in the trajectory
+             :return: Dictionary containing lists of calculated parameters
+             """
+             # Convert trajectory to numpy array for easier calculations
+             x, y, yaw = zip(*traj)
+             x, y, yaw = np.array(x), np.array(y), np.array(yaw)
+
+             # Calculate velocities
+             dx = np.diff(x) / timestep
+             dy = np.diff(y) / timestep
+
+             # Calculate yaw rate
+             dyaw = np.diff(yaw)
+             dyaw = np.where(dyaw > np.pi, dyaw - 2*np.pi, dyaw)
+             dyaw = np.where(dyaw < -np.pi, dyaw + 2*np.pi, dyaw)
+             dyaw = dyaw / timestep
+             ddyaw = np.diff(dyaw) / timestep
+
+             # Calculate speed
+             speed = np.sqrt(dx**2 + dy**2)
+
+             # Calculate accelerations
+             accel = np.diff(speed) / timestep
+             jerk = np.diff(accel) / timestep
+
+             # Calculate yaw rate (already calculated as dyaw)
+             yaw_rate = dyaw
+             # Calculate yaw acceleration
+             yaw_accel = ddyaw
+
+             lon_accel = accel
+             lat_accel = np.zeros_like(lon_accel)
+             lon_jerk = jerk
+
+             # Pad arrays to match the original trajectory length
+             yaw_rate = np.pad(yaw_rate, (0, 1), 'edge')
+             yaw_accel = np.pad(yaw_accel, (0, 2), 'edge')
+             lon_accel = np.pad(lon_accel, (0, 2), 'edge')
+             lat_accel = np.pad(lat_accel, (0, 2), 'edge')
+             lon_jerk = np.pad(lon_jerk, (0, 3), 'edge')
+
+             return {
+                 'speed': speed,
+                 'yaw_rate': yaw_rate,
+                 'yaw_accel': yaw_accel,
+                 'lon_accel': lon_accel,
+                 'lat_accel': lat_accel,
+                 'lon_jerk': lon_jerk,
+             }
+
+         # Calculate kinematic parameters
+         if len(traj) < 4:
+             return 1.0
+
+         kinematics = calculate_trajectory_kinematics(traj, timestep)
+
+         # Check each parameter against its boundary
+         checks = [
+             np.all(np.abs(kinematics['lat_accel']) <= boundaries['max_abs_lat_accel']),
+             np.all(kinematics['lon_accel'] <= boundaries['max_lon_accel']),
+             np.all(kinematics['lon_accel'] >= boundaries['min_lon_accel']),
+             np.all(np.abs(kinematics['lon_jerk']) <= boundaries['max_abs_lon_jerk']),
+             np.all(np.abs(kinematics['yaw_accel']) <= boundaries['max_abs_yaw_accel']),
+             np.all(np.abs(kinematics['yaw_rate']) <= boundaries['max_abs_yaw_rate'])
+         ]
+
+         # if not all(checks):
+         #     print(traj)
+         #     print(kinematics)
+         print(f"comfortable: {checks}")
+
+         # Return 1.0 if all checks pass, 0.0 otherwise
+         return 1.0 if all(checks) else 0.0
+
+     def _calculate_no_collision(self, ego_box, planned_traj, obs_lists, scene_xyz):
+         ego_x, ego_y, z, ego_w, ego_l, ego_h, ego_yaw = ego_box
+         ego_verts_local = ego_verts_canonic * np.array([ego_l, ego_w, ego_h])
+         for idx in range(planned_traj.shape[0]):
+             ego_x, ego_y, ego_yaw = planned_traj[idx]  # ego_state = (x, y, yaw)
+             ego_trans_mat = np.eye(4)
+             ego_trans_mat[:3, :3] = SCR.from_euler('z', ego_yaw).as_matrix()
+             ego_trans_mat[:3, 3] = np.array([ego_x, ego_y, z])
+             ego_verts_global = (ego_trans_mat[:3, :3] @ ego_verts_local.T).T + ego_trans_mat[:3, 3]
+             ego_verts_global = torch.from_numpy(ego_verts_global).float().cuda()
+             bk_collision = bg_collision_det(scene_xyz, ego_verts_global)
+             # scene_local = scene_xyz - np.array([ego_x, ego_y, z])
+             # bk_collision = np.sum(
+             #     (-ego_l/2 < scene_local[:, 0]) & (scene_local[:, 0] < ego_l/2) & \
+             #     (-ego_w/2 < scene_local[:, 1]) & (scene_local[:, 1] < ego_w/2) & \
+             #     (-ego_h/2 < scene_local[:, 2]) & (scene_local[:, 2] < ego_h/2)
+             # ) > 100
+             if bk_collision:
+                 print(f"collision with background detected! @ timestep{idx}")
+                 return 0.0
+             ego_poly = create_rectangle(ego_x, ego_y, ego_w, ego_l, ego_yaw)
+             obs_list = obs_lists[idx if idx < len(obs_lists) else -1]
+             for obs in obs_list:
+                 # obs = (x, y, z, w, l, h, yaw)
+                 obs_x, obs_y, _, obs_w, obs_l, _, obs_yaw = obs
+                 obs_poly = create_rectangle(
+                     obs_x, obs_y, obs_w, obs_l, obs_yaw)
+                 if ego_poly.intersects(obs_poly):
+                     print(f"collision with obstacle detected! @ timestep{idx}")
+                     print(
+                         f"ego_poly: {(ego_x, ego_y, ego_yaw, obs_w, obs_l)}, obs_poly: {(obs_x, obs_y, obs_yaw, obs_w, obs_l)}")
+                     return 0.0  # Collision detected
+         return 1.0
+
+     def _calculate_time_to_collision(self, ego_box, planned_traj, obs_lists, scene_xyz, timestep):
+         # breakpoint()
+         t_list = [0.5, 1]  # ttc time
+
+         for t in t_list:
+             # Calculate velocities
+             velocities = np.diff(planned_traj[:, :2], axis=0) / timestep
+
+             # Use the velocity of the second point for the first point
+             velocities = np.vstack([velocities[0], velocities])
+
+             # Calculate the displacement
+             displacement = velocities * t
+
+             # Create the new trajectory
+             new_traj = planned_traj.copy()
+             new_traj[:, :2] += displacement
+
+             is_collide_score = self._calculate_no_collision(
+                 ego_box, new_traj, obs_lists, scene_xyz)
+             if is_collide_score == 0.0:
+                 print(f"failed to pass ttc collision check, t={t}")
+                 # breakpoint()
+                 return 0.0
+
+         return 1.0
+
+     def calculate(self):
+         print(f"current exp has {len(self.data['frames'])} frames")
+         if len(self.data['frames']) == 0:
+             return None
+         # todo: time_step needs modification
+         score_list = {}
+         for i in range(0, len(self.data['frames']), 1):
+             frame = self.data['frames'][i]
+             if not frame['is_key_frame']:
+                 continue
+
+             print(f"frame {i} / {len(self.data['frames'])}")
+             timestamp = frame['time_stamp']
+             planned_last_timestamp = timestamp + \
+                 len(frame['planned_traj']['traj']) * \
+                 frame['planned_traj']['timestep']
+             ego_x, ego_y, _, ego_w, ego_l, _, ego_yaw = frame['ego_box']
+             # frame['planned_traj']['traj']
+             if len(frame['planned_traj']['traj']) < 2:
+                 continue
+             traj = frame['planned_traj']['traj']
+
+             planned_traj = np.concatenate(([np.array([ego_x, ego_y, ego_yaw])], traj), axis=0)
+             # print(planned_traj)
+
+             # if the car is stopped, there may be errors in the yaw of the planned trajectory
+             traj_distance = np.linalg.norm(planned_traj[-1, :2] - planned_traj[0, :2])
+             if traj_distance < 1:
+                 planned_traj[:, 2] = planned_traj[0, 2]  # set all yaws to the first yaw
+
+             current_timestamp = timestamp
+             current_frame_idx = i
+             obs_lists = []
+             while current_timestamp <= planned_last_timestamp + 1e-5:
+                 if abs(current_timestamp - self.data['frames'][current_frame_idx]['time_stamp']) < 1e-5:
+                     obs_list = []
+                     for idx, obj in enumerate(self.data['frames'][current_frame_idx]['obj_boxes']):
+                         # obs_list.append(obj)
+                         if self.data['frames'][current_frame_idx]['obj_names'][idx] == 'car':
+                             obs_list.append(obj)
+                     obs_lists.append(obs_list)
+                     current_timestamp += frame['planned_traj']['timestep']
+
+                 current_frame_idx += 1
+                 if current_frame_idx >= len(self.data['frames']):
+                     break
+
+             # breakpoint()
+             # plt.imshow(frame['drivable_mask'].astype(np.uint8))
+             # plt.show()
+             # transformed_traj = self.transform_to_ego_frame(frame['planned_traj']['traj'], frame['ego_box'])
+             transformed_traj = self.transform_to_ego_frame(
+                 planned_traj, frame['ego_box'])
+             # breakpoint()
+
+             score_nc = self._calculate_no_collision(
+                 frame['ego_box'], planned_traj, obs_lists, self.data['scene_xyz'])
+             # score_nc = 0.0 if frame['collision'] else 1.0
+             score_dac = self._calculate_drivable_area_compliance(
+                 self.data['ground_xy'], planned_traj, ego_w, ego_l)
+             score_ttc = self._calculate_time_to_collision(
+                 frame['ego_box'], planned_traj, obs_lists, self.data['scene_xyz'], frame['planned_traj']['timestep'])
+             score_c = self._calculate_is_comfortable(
+                 transformed_traj, frame['planned_traj']['timestep'])
+             # score_ep = self._calculate_progress(
+             #     planned_traj, ref_traj)
+
+             score_pdms = score_nc * score_dac * (score_weight['ttc'] * score_ttc + score_weight['c'] * score_c) / (
+                 score_weight['ttc'] + score_weight['c'])
+             print('nc, dac, ttc, com, pdms', [score_nc, score_dac, score_ttc, score_c, score_pdms])
+             score_list[timestamp] = {'nc': score_nc, 'dac': score_dac,
+                                      'ttc': score_ttc, 'c': score_c, 'pdms': score_pdms}
+
+         totals = {metric: 0 for metric in next(iter(score_list.values()))}
+         for scores in score_list.values():
+             for metric, value in scores.items():
+                 totals[metric] += value
+
+         # avg scores
+         num_entries = len(score_list)
+         averages = {metric: total / num_entries for metric,
+                     total in totals.items()}
+
+         # writer.writerow(averages.values())
+
+         mean_score = averages['pdms']
+         route_completion = max([f['rc'] for f in self.data['frames']])
+         route_completion = route_completion if route_completion < 1 else 1.0
+         driving_score = mean_score * route_completion
+         return mean_score, route_completion, driving_score, averages
+
+
+ def calculate(data):
+     print(f"this pkl file contains {len(data)} experiment records.")
+     # print(f"the first item metadata is {data[0]['metas']}.")
+     # breakpoint()
+
+     def process_exp_data(exp_data):
+         score_calc = ScoreCalculator(exp_data)
+         score = score_calc.calculate()
+         print(f"The score of the experiment is {score}.")
+         final_score_dict = score[3]
+         final_score_dict['rc'] = score[1]
+         final_score_dict['hdscore'] = score[2]
+         return final_score_dict
+
+     def multi_threaded_process(data, max_workers=None):
+         all_averages = []
+
+         # Use a thread lock for thread-safe append operations
+         lock = threading.Lock()
+
+         def append_result(future):
+             result = future.result()
+             with lock:
+                 all_averages.append(result)
+
+         with ThreadPoolExecutor(max_workers=1) as executor:
+             futures = [executor.submit(process_exp_data, exp_data)
+                        for exp_data in data]
+             for future in futures:
+                 future.add_done_callback(append_result)
+
+         return all_averages
+
+     all_averages = multi_threaded_process(data)
+
+     collected_values = defaultdict(list)
+     for averages in all_averages:
+         for key, value in averages.items():
+             collected_values[key].append(value)
+
+     # Calculate the mean for each metric (standard deviation kept commented out)
+     results = {}
+     for key, values in collected_values.items():
+         avg = np.mean(values)
+         # std = np.std(values)
+         results[key] = f"{avg:.4f}"
+
+     # Output results
+     print("=============================Results=============================")
+     for key, value in results.items():
+         print(f"'{key}': {value}")
+     return results
+
+
+ def parse_data(test_path):
+     data_file_name = os.path.join(test_path, "data.pkl")
+     ground_pcd_file_name = os.path.join(test_path, "ground.ply")
+     scene_pcd_file_name = os.path.join(test_path, "scene.ply")
+
+     # Open the file and load the data
+     with open(data_file_name, 'rb') as f:
+         data = pickle.load(f)
+
+     ground_pcd = o3d.io.read_point_cloud(ground_pcd_file_name)
+     ground_xyz = np.asarray(ground_pcd.points)  # in camera coordinates
+     ground_xy = np.stack([ground_xyz[:, 2], -ground_xyz[:, 0]], axis=1)  # in imu coordinates
+
+     scene_pcd = o3d.io.read_point_cloud(scene_pcd_file_name)
+     scene_xyz = np.asarray(scene_pcd.points)  # in camera coordinates
+     # in imu coordinates
+     scene_xyz = np.stack([scene_xyz[:, 2], -scene_xyz[:, 0], -scene_xyz[:, 1]], axis=1)
+
+     data[0]['ground_xy'] = ground_xy
+     data[0]['scene_xyz'] = torch.from_numpy(scene_xyz).cuda()
+     # data[0]['scene_xyz'] = scene_xyz
+     return data
+
+
+ def hugsim_evaluate(test_data, ground_xyz, scene_xyz):
+     ground_xy = np.stack([ground_xyz[:, 2], -ground_xyz[:, 0]], axis=1)  # in imu coordinates
+     scene_xyz = np.stack([scene_xyz[:, 2], -scene_xyz[:, 0], -scene_xyz[:, 1]], axis=1)
+     test_data[0]['ground_xy'] = ground_xy
+     test_data[0]['scene_xyz'] = torch.from_numpy(scene_xyz).float().cuda()
+     results = calculate(test_data)
+     return results
+
+
+ def get_opts():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--test_path', type=str, required=True)
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = get_opts()
+     data = parse_data(args.test_path)
+
+     # Call the main function with the loaded data
+     calculate(data)
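For reference, the per-frame score combines two hard gates with a weighted average: pdms = nc * dac * (w_ttc * ttc + w_c * c) / (w_ttc + w_c), with weights 5 and 2 from score_weight. A quick worked check with illustrative values:

nc, dac, ttc, c = 1.0, 0.5, 1.0, 0.0   # no collision, partially off drivable area, safe TTC, uncomfortable
pdms = nc * dac * (5 * ttc + 2 * c) / (5 + 2)
print(round(pdms, 4))  # 0.3571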
code/sim/utils/sim_utils.py ADDED
@@ -0,0 +1,122 @@
+ import numpy as np
+ from scipy.spatial.transform import Rotation as SCR
+ import math
+ from scene.cameras import Camera
+ from sim.ilqr.lqr import plan2control
+ from omegaconf import OmegaConf
+
+
+ def rt2pose(r, t, degrees=False):
+     pose = np.eye(4)
+     pose[:3, :3] = SCR.from_euler('XYZ', r, degrees=degrees).as_matrix()
+     pose[:3, 3] = t
+     return pose
+
+
+ def pose2rt(pose, degrees=False):
+     r = SCR.from_matrix(pose[:3, :3]).as_euler('XYZ', degrees=degrees)
+     t = pose[:3, 3]
+     return r, t
+
+
+ def load_camera_cfg(cfg):
+     cam_params = {}
+     cams = OmegaConf.to_container(cfg.cams, resolve=True)
+     for cam_name, cam in cams.items():
+         v2c = rt2pose(cam['extrinsics']['v2c_rot'], cam['extrinsics']['v2c_trans'], degrees=True)
+         l2c = rt2pose(cam['extrinsics']['l2c_rot'], cam['extrinsics']['l2c_trans'], degrees=True)
+         cam_intrin = cam['intrinsics']
+         cam_intrin['fovx'] = cam_intrin['fovx'] / 180.0 * np.pi
+         cam_intrin['fovy'] = cam_intrin['fovy'] / 180.0 * np.pi
+         cam_params[cam_name] = {'intrinsic': cam_intrin, 'v2c': v2c, 'l2c': l2c}
+
+     rect_mat = np.eye(4)
+     if 'cam_rect' in cfg:
+         rect_mat[:3, :3] = SCR.from_euler('XYZ', cfg.cam_rect.rot, degrees=True).as_matrix()
+         rect_mat[:3, 3] = np.array(cfg.cam_rect.trans)
+
+     return cam_params, OmegaConf.to_container(cfg.cam_align, resolve=True), rect_mat
+
+
+ def fov2focal(fov, pixels):
+     return pixels / (2 * math.tan(fov / 2))
+
+
+ def focal2fov(focal, pixels):
+     return 2 * math.atan(pixels / (2 * focal))
+
+
+ def create_cam(intrinsic, c2w):
+     fovx, fovy = intrinsic['fovx'], intrinsic['fovy']
+     h, w = intrinsic['H'], intrinsic['W']
+     K = np.eye(4)
+     K[0, 0], K[1, 1] = fov2focal(fovx, w), fov2focal(fovy, h)
+     K[0, 2], K[1, 2] = intrinsic['cx'], intrinsic['cy']
+     cam = Camera(K=K, c2w=c2w, width=w, height=h,
+                  image=np.zeros((h, w, 3)), image_name='', dynamics={})
+     return cam
+
+
+ def traj2control(plan_traj, info):
+     """
+     The input plan trajectory is under lidar coordinates:
+     x to the right, y forward and z upward.
+     """
+     plan_traj_stats = np.zeros((plan_traj.shape[0] + 1, 5))
+     plan_traj_stats[1:, :2] = plan_traj[:, [1, 0]]
+     prev_a, prev_b = 0, 0
+     for i, (a, b) in enumerate(plan_traj):
+         rot = np.arctan((b - prev_b) / (a - prev_a))
+         plan_traj_stats[i + 1, 2] = rot
+     curr_stat = np.array(
+         [0, 0, 0, info['ego_velo'], info['ego_steer']]
+     )
+     acc, steer_rate = plan2control(plan_traj_stats, curr_stat)
+     return acc, steer_rate
+
+
+ def dense_cam_poses(cam_poses, cmds):
+     for _ in range(5):
+         dense_poses = []
+         dense_cmds = []
+         for i in range(cam_poses.shape[0] - 1):
+             cam1 = cam_poses[i]
+             cam2 = cam_poses[i + 1]
+             dense_poses.append(cam1)
+             dense_cmds.append(cmds[i])
+             if np.linalg.norm(cam1[:3, 3] - cam2[:3, 3]) > 0.1:
+                 euler1 = SCR.from_matrix(cam1[:3, :3]).as_euler("XYZ")
+                 euler2 = SCR.from_matrix(cam2[:3, :3]).as_euler("XYZ")
+                 interp_euler = (euler1 + euler2) / 2
+                 interp_trans = (cam1[:3, 3] + cam2[:3, 3]) / 2
+                 interp_pose = np.eye(4)
+                 interp_pose[:3, :3] = SCR.from_euler("XYZ", interp_euler).as_matrix()
+                 interp_pose[:3, 3] = interp_trans
+                 dense_poses.append(interp_pose)
+                 dense_cmds.append(cmds[i])
+         dense_poses.append(cam_poses[-1])
+         dense_poses = np.stack(dense_poses)
+         cam_poses = dense_poses
+         cmds = dense_cmds
+
+     return cam_poses, cmds
+
+
+ def traj_transform_to_global(traj, ego_box):
+     """
+     Transform a trajectory from the ego-centric frame to the global frame.
+     """
+     ego_x, ego_y, _, _, _, _, ego_yaw = ego_box
+     global_points = [
+         (
+             ego_x
+             + px * math.cos(ego_yaw)
+             - py * math.sin(ego_yaw),
+             ego_y
+             + px * math.sin(ego_yaw)
+             + py * math.cos(ego_yaw),
+         )
+         for px, py in traj
+     ]
+     global_trajs = []
+     for i in range(1, len(global_points)):
+         x1, y1 = global_points[i - 1]
+         x2, y2 = global_points[i]
+         dx, dy = x2 - x1, y2 - y1
+         # distance = math.sqrt(dx**2 + dy**2)
+         yaw = math.atan2(dy, dx)
+         global_trajs.append((x1, y1, yaw))
+     return global_trajs
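As a quick sanity check of the fov2focal/focal2fov pair above: a 90-degree FoV across 1600 pixels gives focal length 1600 / (2 tan(45 deg)) = 800, and the inverse maps back:

import math

f = 1600 / (2 * math.tan(math.radians(90) / 2))  # fov2focal: 800.0
fov = 2 * math.atan(1600 / (2 * f))              # focal2fov: pi/2
print(f, math.degrees(fov))                      # 800.0 90.0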
code/submodules/Pplan/Policy/base.py ADDED
@@ -0,0 +1,16 @@
+ import abc
+
+
+ class Policy(abc.ABC):
+     def __init__(self, device, *args, **kwargs):
+         self.device = device
+
+     @abc.abstractmethod
+     def get_action(self, obs_dict, **kwargs):
+         """Predict an action based on the input observation"""
+         pass
+
+     @abc.abstractmethod
+     def eval(self):
+         """Set the policy to evaluation mode"""
+         pass
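A hedged sketch of a concrete Policy subclass (purely illustrative; the import path follows this repo's layout):

from submodules.Pplan.Policy.base import Policy  # assumed import path


class ZeroActionPolicy(Policy):
    """Illustrative policy that always outputs a zero action."""

    def get_action(self, obs_dict, **kwargs):
        return 0.0  # a real policy would map obs_dict to controls here

    def eval(self):
        pass  # stateless: nothing to switch


policy = ZeroActionPolicy(device="cpu")
print(policy.get_action({}))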
code/submodules/Pplan/Policy/sampling_planner.py ADDED
File without changes
code/submodules/Pplan/Sampling/__init__.py ADDED
File without changes
code/submodules/Pplan/Sampling/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (156 Bytes).
code/submodules/Pplan/Sampling/__pycache__/forward_sampler.cpython-311.pyc ADDED
Binary file (10.1 kB).
code/submodules/Pplan/Sampling/__pycache__/spline_planner.cpython-311.pyc ADDED
Binary file (32.1 kB).
code/submodules/Pplan/Sampling/forward_sampler.py ADDED
@@ -0,0 +1,141 @@
+ from logging import raiseExceptions
+ import numpy as np
+ import torch
+ import pdb
+ from ..utils import geometry_utils as GeoUtils
+ import matplotlib.pyplot as plt
+ from scipy.interpolate import interp1d
+ import random
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Final,
+     Iterable,
+     List,
+     Optional,
+     Set,
+     Tuple,
+     Union,
+ )
+
+
+ class ForwardSampler(object):
+     def __init__(self, dt: float, acce_grid: list, dhm_grid: list, dhf_grid: list, max_rvel=8,
+                  max_steer=0.5, vbound=[-5.0, 30], device="cuda" if torch.cuda.is_available() else "cpu"):
+         self.device = device
+         self.accels = torch.tensor(acce_grid, device=self.device)
+         self.dhf_grid = torch.tensor(dhf_grid, device=self.device)
+         self.dhm_grid = torch.tensor(dhm_grid, device=self.device)
+         self.max_rvel = max_rvel
+         self.vbound = vbound
+         self.max_steer = max_steer
+         self.dt = dt
+
+     def velocity_plan(self, x0: torch.Tensor, T: int, acce: Optional[torch.Tensor] = None):
+         """plan velocity profile
+
+         Args:
+             x0 (torch.Tensor): [B, 4], X, Y, v, heading
+             T (int): time horizon
+             acce (torch.Tensor): [B, N]
+         """
+         bs = x0.shape[0]
+         if acce is None:
+             acce = self.accels[None, :].repeat_interleave(bs, 0)
+         v0 = x0[..., 2]  # [B]
+         vdes = v0[:, None, None] + torch.arange(T, device=self.device)[None, None] * acce[:, :, None] * self.dt
+         vplan = torch.clip(vdes, min=self.vbound[0], max=self.vbound[1])
+         return vplan  # [B, N, T]
+
+     def lateral_plan(self, x0: torch.Tensor, vplan: torch.Tensor, dhf: torch.Tensor, dhm: torch.Tensor, T: int, bangbang=True):
+         """plan lateral profile:
+         a steering plan that ends with the desired heading change dhf and, if feasible,
+         has a mean heading change equal to dhm
+
+         Args:
+             x0 (torch.Tensor): [B, 4], X, Y, v, heading
+             vplan (torch.Tensor): [B, N, T] velocity profile
+             dhf (torch.Tensor): [B, M] desired heading change at the end of the horizon
+             dhm (torch.Tensor): [B, M] mean heading change over the horizon
+             T (int): horizon
+         """
+         # using a linear steering profile
+         bs, M = dhf.shape
+         N = vplan.shape[1]
+         vplan = vplan[:, :, None]  # [B, N, 1, T]
+         vl = torch.cat([x0[:, 2].reshape(-1, 1, 1, 1).repeat_interleave(N, 1), vplan[..., :-1]], -1)
+         acce = vplan - vl
+
+         c0 = torch.abs(vl)
+         c1 = torch.cumsum(c0 * self.dt, -1)
+         c2 = torch.cumsum(c1 * self.dt, -1)
+         c3 = torch.cumsum(c2 * self.dt, -1)
+
+         # algebraic equations: c1[T]*a0 + c2[T]*a1 = dhf, c2[T]*a0 + c3[T]*a1 = dhm
+         a0 = (c3[..., -1] * dhf.unsqueeze(1) - c2[..., -1] * dhm.unsqueeze(1)) / (c1[..., -1] * c3[..., -1] - c2[..., -1] ** 2)  # [B, N, M]
+         a1 = (dhf.unsqueeze(1) - c1[..., -1] * a0) / c2[..., -1]
+
+         yawrate = a0[..., None] * c0 + a1[..., None] * c1
+
+         if bangbang:
+             # turn into bang-bang control to reduce the peak steering value, but the mean heading value is not retained
+             pos_flag = (yawrate > 0)
+             neg_flag = ~pos_flag
+             mean_pos_steering = (yawrate * pos_flag).sum(-1) / ((c0 * pos_flag).sum(-1) + 1e-6)
+             mean_neg_steering = (yawrate * neg_flag).sum(-1) / ((c0 * neg_flag).sum(-1) + 1e-6)
+             mean_pos_steering = torch.clip(mean_pos_steering, min=-self.max_steer, max=self.max_steer)
+             mean_neg_steering = torch.clip(mean_neg_steering, min=-self.max_steer, max=self.max_steer)
+             bb_yawrate = (mean_pos_steering[..., None] * pos_flag + mean_neg_steering[..., None] * neg_flag) * c0
+             bb_yawrate = torch.clip(bb_yawrate, min=-self.max_rvel / c0, max=self.max_rvel / c0)
+             dh = torch.cumsum(bb_yawrate * self.dt, -1)
+         else:
+             yawrate = torch.clip(yawrate, min=-self.max_rvel / c0, max=self.max_rvel / c0)
+             yawrate = torch.clip(yawrate, min=-self.max_steer * c0, max=self.max_steer * c0)
+             dh = torch.cumsum(yawrate * self.dt, -1)
+         heading = x0[..., 3, None, None, None] + dh
+
+         vx = vplan * torch.cos(heading)
+         vy = vplan * torch.sin(heading)
+         traj = torch.stack([x0[:, None, None, None, 0] + vx.cumsum(-1) * self.dt,
+                             x0[:, None, None, None, 1] + vy.cumsum(-1) * self.dt,
+                             vplan.repeat_interleave(M, 2),
+                             heading], -1)
+         t = torch.arange(1, T + 1, device=self.device)[None, None, None, :, None].repeat(bs, N, M, 1, 1) * self.dt
+         xyvaqrt = torch.cat([traj[..., :3], acce[..., None].repeat_interleave(M, 2), traj[..., 3:], yawrate[..., None], t], -1)
+         return xyvaqrt.reshape(bs, N * M, T, -1)  # [B, N*M, T, 7]
+
+     def sample_trajs(self, x0, T, bangbang=True):
+         # velocity sample
+         vplan = self.velocity_plan(x0, T)
+         bs = x0.shape[0]
+         dhf = self.dhf_grid
+         dhm = self.dhm_grid
+         Mf = dhf.shape[0]
+         Mm = dhm.shape[0]
+         dhm = dhm.repeat(Mf).unsqueeze(0).repeat_interleave(bs, 0)
115
+ dhf = dhf.repeat_interleave(Mm,0).unsqueeze(0).repeat_interleave(bs,0)+dhm
116
+ return self.lateral_plan(x0,vplan,dhf,dhm,T,bangbang)
117
+
118
+
119
+
120
+ def test():
121
+ sampler = ForwardSampler(acce_grid=[-4,-2,0,2,4],dhm_grid=torch.linspace(-0.7,0.7,9),dhf_grid=[-0.4,0,0.4],dt=0.1)
122
+ x0 = torch.tensor([0,0,1.,0.],device="cuda").unsqueeze(0).repeat_interleave(3,0)
123
+ T = 10
124
+ # vel_grid = sampler.velocity_plan(x0,T)
125
+ # dhf = torch.tensor([0.5,0,-0.5]).repeat(3).unsqueeze(0)
126
+ # dhm = torch.tensor([0.2,0,-0.2]).repeat_interleave(3,0).unsqueeze(0)
127
+ traj = sampler.sample_trajs(x0,T,bangbang=False)
128
+ traj = traj[0].reshape(-1,T,7).cpu().numpy()
129
+ import matplotlib.pyplot as plt
130
+ fig,ax = plt.subplots()
131
+ for i in range(traj.shape[0]):
132
+ ax.plot(traj[i,:,0],traj[i,:,1])
133
+ plt.show()
134
+
135
+
136
+
137
+ if __name__ == "__main__":
138
+ test()
139
+
140
+
141
+
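A minimal usage sketch of ForwardSampler (grid values are illustrative, not tuned): for B initial states, N acceleration samples, and Mf*Mm heading-change combinations, sample_trajs returns a tensor of shape [B, N*Mf*Mm, T, 7] holding (x, y, v, a, heading, yawrate, t) per step.

    import torch

    sampler = ForwardSampler(dt=0.1,
                             acce_grid=[-2.0, 0.0, 2.0],   # N = 3
                             dhm_grid=[-0.2, 0.0, 0.2],    # Mm = 3
                             dhf_grid=[-0.4, 0.0, 0.4])    # Mf = 3
    x0 = torch.tensor([[0.0, 0.0, 5.0, 0.0],
                       [1.0, 2.0, 3.0, 0.5]], device=sampler.device)  # [B=2, 4]
    trajs = sampler.sample_trajs(x0, T=20)
    print(trajs.shape)  # torch.Size([2, 27, 20, 7])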