Spaces:
Starting
on
T4
Starting
on
T4
File size: 10,725 Bytes
7f3c2df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 |
import numpy as np
import torch
from scipy.spatial.transform import Rotation as SCR
import roma
from collections import namedtuple
from sim.utils.agent_controller import constant_headaway
from sim.utils import agent_controller
from collections import defaultdict
from trajdata import AgentType, UnifiedDataset
from trajdata.maps import MapAPI
from trajdata.simulation import SimulationScene
from sim.utils.sim_utils import rt2pose, pose2rt
from sim.utils.agent_controller import IDM, AttackPlanner, ConstantPlanner, UnicyclePlanner
import os
import json
# Agent asset/controller spec: directory holding gs.pth and wlh.json, the name of a
# controller class in agent_controller, and the kwargs used to construct it.
# The typename matches the bound name so instances can be pickled and repr() is accurate.
Model = namedtuple('Model', ['model_path', 'controller', 'controller_args'])
class planner:
    """Step the non-ego agents of a scene and produce their body-to-world poses.

    ``plan_list`` holds one positional entry per agent:
    ``(a, b, height, yaw, v, model_path, controller, controller_args)``.
    Every per-agent dict (``stats``, ``route``, ``controller``, ``ckpts``,
    ``wlhs``) is keyed ``"agent_<index>"``. Each ``stats`` value is a 1-D
    tensor ``[a, b, height, yaw, v]``.
    """

    def __init__(self, plan_list, scene_path=None, dt=0.2, unified_map=None, ground=None):
        self.unified_map = unified_map
        self.ground = ground
        self.PREDICT_STEPS = 20  # horizon handed to constant_headaway
        self.NUM_NEIGHBORS = 3  # nearest agents passed to each controller
        self.rectify_angle = 0
        if self.unified_map is not None:
            self.rectify_angle = self.unified_map.rectify_angle
        self.stats, self.route, self.controller, self.ckpts, self.wlhs = {}, {}, {}, {}, {}
        self.dt = dt
        self.ATTACK_FREQ = 3  # replan period (in steps) for AttackPlanner; overridable per agent
        for iid, args in enumerate(plan_list):
            key = f"agent_{iid}"
            if args[6] == "UnicyclePlanner":
                # args[5] is the asset directory; args[7] is a dict with 'uc_id' and 'speed'.
                self.ckpts[key] = os.path.join(args[5], 'gs.pth')
                with open(os.path.join(args[5], 'wlh.json')) as f:
                    self.wlhs[key] = json.load(f)
                uc_configs = args[7]
                self.controller[key] = UnicyclePlanner(
                    os.path.join(scene_path, f"unicycle_{uc_configs['uc_id']}.pth"),
                    speed=uc_configs['speed'],
                )
                # Initial state comes from the unicycle model at t=0; only height is taken from args.
                a, b, v, _pitchroll, yaw, _h = self.controller[key].uc_model.forward(0.0)
                self.stats[key] = torch.tensor([a, b, args[2], yaw, v])
                self.route[key] = None
            else:
                model = Model(*args[5:])
                self.stats[key] = torch.tensor(args[:5])  # a, b, height, yaw, v
                self.stats[key][3] += self.rectify_angle
                self.route[key] = None
                self.ckpts[key] = os.path.join(model.model_path, 'gs.pth')
                with open(os.path.join(model.model_path, 'wlh.json')) as f:
                    self.wlhs[key] = json.load(f)
                self.controller[key] = getattr(agent_controller, model.controller)(**model.controller_args)
                if model.controller == "AttackPlanner":
                    # The last AttackPlanner in plan_list determines the shared period.
                    self.ATTACK_FREQ = model.controller_args["ATTACK_FREQ"]

    def update_ground(self, ground):
        """Replace the ``(cam_poses, cam_height, _)`` tuple used by ground_height."""
        self.ground = ground

    def update_agent_route(self):
        """Ask the map for a fresh route for every agent.

        Route rows are ``[a, b, heading]``, taken from columns 0-1 and the last
        column of the map path. Agents whose lookup fails keep their previous route.

        Raises:
            AssertionError: if no map was supplied.
        """
        assert self.unified_map is not None, "Map shouldn't be None to forecast agent path"
        for iid, stat in self.stats.items():
            path = self.unified_map.get_route(stat)
            if path is None:
                print("path not found at ", iid, stat)
                continue
            self.route[iid] = torch.from_numpy(np.hstack([path[:, :2], path[:, -1:]]))

    def ground_height(self, u, v):
        """Estimate the ground height under the horizontal point ``(u, v)``.

        The nearest camera is found among ``cam_poses[:-1]`` (4x4 camera-to-world
        matrices; the last pose is never a candidate) by x/z distance. The query
        point is moved into that camera's frame, its local y is zeroed, and it is
        mapped back to the world. The world y plus ``cam_height`` is returned.

        Returns:
            float: the height value.
        """
        cam_poses, cam_height, _ = self.ground
        cam_poses = np.asarray(cam_poses)
        u, v = float(u), float(v)
        cam_dist = np.hypot(cam_poses[:-1, 0, 3] - u, cam_poses[:-1, 2, 3] - v)
        nearest_c2w = cam_poses[np.argmin(cam_dist)]
        nearest_w2c = np.linalg.inv(nearest_c2w)
        uv_local = nearest_w2c[:3, :3] @ np.array([u, 0.0, v]) + nearest_w2c[:3, 3]
        uv_local[1] = 0.0  # project onto the camera's y = 0 plane
        uv_world = nearest_c2w[:3, :3] @ uv_local + nearest_c2w[:3, 3]
        return float(uv_world[1] + cam_height)

    def _is_replan_step(self, t):
        """Return True when time ``t`` lands on an ATTACK_FREQ-th simulation step.

        The step index is ``round(t / dt)``. Float floor division (``t // dt``)
        undercounts: for example ``0.6 // 0.2 == 2.0``, which moves replans off
        their intended frames.
        """
        return round(t / self.dt) % self.ATTACK_FREQ == 0

    def plan_traj(self, t, ego_stats):
        """Advance every agent by one step and return their body-to-world poses.

        Args:
            t: current simulation time; combined with ``dt`` to schedule
                AttackPlanner replanning.
            ego_stats: ego state tensor stacked alongside each agent's
                ``stat[[0, 1, 3, 4]]`` (presumably ``[a, b, yaw, v]``).

        Returns:
            ``[b2ws, {}]``, where ``b2ws`` maps agent key to a 4x4 float32 CUDA tensor.

        Raises:
            NotImplementedError: for an unsupported controller type.
        """
        # Row 0 is the ego; rows 1.. follow self.stats insertion order.
        all_stats = torch.stack(
            [ego_stats] + [stat[[0, 1, 3, 4]] for stat in self.stats.values()], dim=0
        )
        future_states = constant_headaway(all_stats, num_steps=self.PREDICT_STEPS, dt=self.dt)
        curr_xy_agents = all_stats[:, :2]
        new_plan = self._is_replan_step(t)
        b2ws = {}
        for iid, stat in self.stats.items():
            # Nearest NUM_NEIGHBORS rows by current distance; sorted index 0 is
            # normally the agent itself (distance 0), so it is skipped.
            distance_agents = torch.norm(curr_xy_agents - stat[:2], dim=-1)
            neighbor_idx = torch.argsort(distance_agents)[1:self.NUM_NEIGHBORS + 1]
            neighbors = future_states[neighbor_idx]
            controller = self.controller[iid]
            state = stat[[0, 1, 3, 4]]
            # Exact type checks (not isinstance) so a controller subclass is never
            # routed to its parent's branch.
            if type(controller) is IDM:
                next_xyrv = controller.update(state=state, path=self.route[iid], dt=self.dt,
                                              neighbors=neighbors)
            elif type(controller) is AttackPlanner:
                # Drops the closest neighbour (presumably the attacked ego, passed
                # separately as attacked_states) — TODO confirm.
                # NOTE(review): dt is hard-coded to 0.1 here, unlike self.dt elsewhere.
                next_xyrv = controller.update(state=state, unified_map=self.unified_map, dt=0.1,
                                              neighbors=neighbors[1:, ...],
                                              attacked_states=future_states[0],
                                              new_plan=new_plan)
            elif type(controller) is ConstantPlanner:
                next_xyrv = controller.update(state=state, dt=self.dt)
            elif type(controller) is UnicyclePlanner:
                next_xyrv = controller.update(dt=self.dt)
            else:
                raise NotImplementedError

            next_stat = torch.zeros_like(stat)
            next_stat[[0, 1, 3, 4]] = next_xyrv.float()
            next_stat[2] = stat[2]  # height offset is carried over unchanged
            self.stats[iid] = next_stat

            h = self.ground_height(float(next_xyrv[0]), float(next_xyrv[1]))
            # NOTE(review): rotation uses the previous yaw (stat[3]) while the
            # translation uses the new position — confirm this lag is intended.
            if type(controller) is UnicyclePlanner:
                angle = -stat[3]
            else:
                angle = -stat[3] - np.pi / 2 - self.rectify_angle
            b2w = np.eye(4)
            b2w[:3, :3] = SCR.from_euler('y', [angle]).as_matrix()
            b2w[:3, 3] = [float(next_stat[0]), h + float(stat[2]), float(next_stat[1])]
            b2ws[iid] = torch.tensor(b2w).float().cuda()
        return [b2ws, {}]
class UnifiedMap:
    """Vector-map access for one trajdata scene, in world or ego-local coordinates.

    "Local" coordinates are measured from the ego's first-frame position
    (``ego_start_pos``) and heading (``ego_start_yaw``). States are laid out as
    ``[a, b, z, yaw]`` (local) or ``[x, y, z, heading]`` (world).
    """

    def __init__(self, datapath, version, scene_name, cache_location="/app/app_datas/nusc_map_cache"):
        """Load ``scene_name`` from the ``version`` dataset at ``datapath``.

        Args:
            cache_location: trajdata cache directory (defaults to the original path).

        Raises:
            AssertionError: if no scene named ``scene_name`` exists.
        """
        self.datapath = datapath
        self.version = version
        self.dataset = UnifiedDataset(
            desired_data=[self.version],
            data_dirs={self.version: self.datapath},
            cache_location=cache_location,
            only_types=[AgentType.VEHICLE],
            agent_interaction_distances=defaultdict(lambda: 50.0),
            desired_dt=0.1,
            num_workers=4,
            verbose=True,
        )
        self.map_api = MapAPI(self.dataset.cache_path)
        # Stop at the first scene with the requested name (names presumed unique).
        self.scene = next((s for s in self.dataset.scenes() if s.name == scene_name), None)
        assert self.scene is not None, f"Can't find scene {scene_name}"
        self.vector_map = self.map_api.get_map(f"{self.version}:{self.scene.location}")
        self.ego_start_pos, self.ego_start_yaw = self.get_start_pose()
        # A negative start heading is shifted by pi; rectify_angle records that
        # offset so consumers of this map can apply the same correction.
        self.rectify_angle = 0
        if self.ego_start_yaw < 0:
            self.ego_start_yaw += np.pi
            self.rectify_angle = np.pi
        self.PATH_LENGTH = 100  # minimum centerline length get_route accumulates (map xy units)

    def get_start_pose(self):
        """Return the ego's first-frame position (numpy array) and heading (float).

        Raises:
            AssertionError: if the first observed agent is not named 'ego'.
        """
        sim_scene: SimulationScene = SimulationScene(
            env_name=self.version,
            scene_name="sim_scene",
            scene=self.scene,
            dataset=self.dataset,
            init_timestep=0,
            freeze_agents=True,
        )
        obs = sim_scene.reset()
        assert obs.agent_name[0] == 'ego', 'The first agent is not ego'
        # The first ego frame serves as the local origin. This assumes the first
        # front-camera pose was likewise chosen as the rendering origin.
        state = obs.curr_agent_state
        return state.position[0].numpy(), state.heading[0].item()

    def xyzr_local2world(self, stat):
        """Map a local ``[a, b, z, yaw]`` state to a world ``[x, y, 0, heading]`` array."""
        a, b = stat[0], stat[1]
        # NOTE(review): arctan (not arctan2) folds points with b < 0 onto b > 0. It is
        # kept because the rectify_angle handling may depend on it. For b == 0, arctan2
        # gives the same +-pi/2 as arctan(+-inf), and 0 instead of NaN at the origin.
        alpha = np.arctan(a / b) if b != 0 else np.arctan2(a, b)
        beta = self.ego_start_yaw - alpha
        dist = np.linalg.norm(stat[:2])
        world_stat = np.zeros(4)
        world_stat[0] = dist * np.cos(beta) + self.ego_start_pos[0]
        world_stat[1] = dist * np.sin(beta) + self.ego_start_pos[1]
        world_stat[3] = stat[3] + self.ego_start_yaw
        return world_stat

    def batch_xyzr_world2local(self, stat):
        """Convert an (N, 4) array of world ``[x, y, z, heading]`` rows to local coordinates.

        This is the counterpart of xyzr_local2world. Column 2 of the result is zero.
        """
        offset = stat[:, :2] - self.ego_start_pos
        dx, dy = offset[:, 0], offset[:, 1]
        # NOTE(review): same arctan quadrant folding as xyzr_local2world. Where
        # dx == 0, arctan2 reproduces +-pi/2 and avoids NaN for a point at the origin.
        with np.errstate(divide="ignore", invalid="ignore"):
            beta = np.arctan(dy / dx)
        beta = np.where(dx == 0, np.arctan2(dy, dx), beta)
        alpha = self.ego_start_yaw - beta
        dist = np.linalg.norm(offset, axis=1)
        local_stat = np.zeros_like(stat)
        local_stat[:, 0] = dist * np.sin(alpha)
        local_stat[:, 1] = dist * np.cos(alpha)
        local_stat[:, 3] = stat[:, 3] - self.ego_start_yaw
        return local_stat

    @staticmethod
    def _polyline_length(xy):
        """Return the summed segment length of an (N, 2) polyline."""
        return np.linalg.norm(xy[1:] - xy[:-1], axis=1).sum()

    def get_route(self, stat):
        """Return a local-frame route from the agent's current lane, or None.

        Args:
            stat: agent state tensor ``[a, b, height, yaw, v]``.

        Successor lanes are chosen uniformly at random with ``np.random``. Lanes
        are added until at least ``PATH_LENGTH`` of centerline is collected or a
        lane has no successor.
        """
        curr_xyzr = self.xyzr_local2world(stat[:4].numpy())
        lanes = self.vector_map.get_current_lane(curr_xyzr)
        if len(lanes) == 0:
            return None
        curr_lane = lanes[0]
        segments = [self.batch_xyzr_world2local(curr_lane.center.xyzh)]
        total_path_length = self._polyline_length(curr_lane.center.xy)
        while total_path_length < self.PATH_LENGTH:
            next_lanes = list(curr_lane.next_lanes)
            if len(next_lanes) == 0:
                break
            curr_lane = self.vector_map.get_road_lane(next_lanes[np.random.randint(len(next_lanes))])
            segments.append(self.batch_xyzr_world2local(curr_lane.center.xyzh))
            total_path_length += self._polyline_length(curr_lane.center.xy)
        # Stack once instead of re-stacking the growing path every iteration.
        return np.vstack(segments)
|