# Author: hyzhou404
# Commit 7f3c2df: "private scenes"
import numpy as np
import torch
from scipy.spatial.transform import Rotation as SCR
import roma
from collections import namedtuple
from sim.utils.agent_controller import constant_headaway
from sim.utils import agent_controller
from collections import defaultdict
from trajdata import AgentType, UnifiedDataset
from trajdata.maps import MapAPI
from trajdata.simulation import SimulationScene
from sim.utils.sim_utils import rt2pose, pose2rt
from sim.utils.agent_controller import IDM, AttackPlanner, ConstantPlanner, UnicyclePlanner
import os
import json
# Appearance model and motion controller of one non-unicycle agent:
#   model_path      - directory holding the agent's gs.pth and wlh.json
#   controller      - name of a class in sim.utils.agent_controller
#   controller_args - keyword arguments used to instantiate that class
# The typename matches the bound name so instances can be pickled (pickle
# resolves the class by looking up its __name__ in this module).
Model = namedtuple('Model', ['model_path', 'controller', 'controller_args'])
class planner:
    """Step non-ego agents through the scene and return their body-to-world poses.

    Every entry of ``plan_list`` describes one agent as::

        (a, b, height, yaw, v, model_path, controller, controller_args)

    Agents are stored under the key ``f"agent_{index}"``. ``a`` and ``b`` are
    planar coordinates in the simulator's local frame (in :meth:`plan_traj`
    they become the x and z components of the pose translation), ``height`` is
    an offset added on top of the ground height, ``yaw`` is the heading and
    ``v`` the speed. ``model_path`` is a directory containing ``gs.pth`` and
    ``wlh.json``.

    ``controller`` names a class in ``sim.utils.agent_controller`` that is
    instantiated with ``**controller_args``. For ``"UnicyclePlanner"``,
    ``controller_args`` must instead hold ``uc_id`` and ``speed``. In that
    case a, b, yaw and v are taken from the unicycle model, and only
    ``height`` is read from the entry.
    """

    def __init__(self, plan_list, scene_path=None, dt=0.2, unified_map=None, ground=None):
        """Build per-agent state, controllers and asset paths.

        Args:
            plan_list: iterable of agent tuples (see the class docstring).
            scene_path: directory containing ``unicycle_<uc_id>.pth`` files;
                only needed for ``UnicyclePlanner`` agents.
            dt: step length passed to the controllers.
            unified_map: optional map object. Its ``rectify_angle`` is added to
                the initial yaw of non-unicycle agents, and it is required by
                :meth:`update_agent_route`.
            ground: ``(cam_poses, cam_height, _)`` tuple consumed by
                :meth:`ground_height`.
        """
        self.unified_map = unified_map
        self.ground = ground
        self.PREDICT_STEPS = 20  # horizon of the constant-heading forecast
        self.NUM_NEIGHBORS = 3  # neighbours handed to each controller
        self.rectify_angle = 0
        if self.unified_map is not None:
            self.rectify_angle = self.unified_map.rectify_angle
        # Per-agent state, all keyed by "agent_{index}".
        self.stats, self.route, self.controller, self.ckpts, self.wlhs = {}, {}, {}, {}, {}
        self.dt = dt
        # Re-plan period (in steps) for AttackPlanner agents. It is overridden
        # by the last AttackPlanner entry's controller_args["ATTACK_FREQ"].
        self.ATTACK_FREQ = 3
        for iid, args in enumerate(plan_list):
            key = f'agent_{iid}'
            if args[6] == "UnicyclePlanner":
                self._add_unicycle_agent(key, args, scene_path)
            else:
                self._add_model_agent(key, args)

    @staticmethod
    def _load_wlh(model_path):
        """Load ``<model_path>/wlh.json`` (presumably the agent's width/length/height)."""
        with open(os.path.join(model_path, 'wlh.json')) as f:
            return json.load(f)

    def _add_unicycle_agent(self, key, args, scene_path):
        """Register an agent driven by a pre-fitted unicycle model."""
        model_path, uc_configs = args[5], args[7]
        self.ckpts[key] = os.path.join(model_path, 'gs.pth')
        self.wlhs[key] = self._load_wlh(model_path)
        uc_path = os.path.join(scene_path, f"unicycle_{uc_configs['uc_id']}.pth")
        self.controller[key] = UnicyclePlanner(uc_path, speed=uc_configs['speed'])
        # The initial a, b, yaw and v come from the model at t=0. Only the height
        # offset (args[2]) comes from the plan entry, and no rectify_angle is applied.
        a, b, v, _pitchroll, yaw, _h = self.controller[key].uc_model.forward(0.0)
        self.stats[key] = torch.tensor([a, b, args[2], yaw, v])
        self.route[key] = None

    def _add_model_agent(self, key, args):
        """Register an agent driven by a controller class from agent_controller."""
        model = Model(*args[5:])
        # Build a float tensor explicitly. All-integer entries would otherwise
        # give an int64 tensor: adding a float rectify_angle to it raises, and
        # controller outputs written into it later would be truncated.
        self.stats[key] = torch.tensor(args[:5], dtype=torch.float32)  # a, b, height, yaw, v
        self.stats[key][3] += self.rectify_angle
        self.route[key] = None
        self.ckpts[key] = os.path.join(model.model_path, 'gs.pth')
        self.wlhs[key] = self._load_wlh(model.model_path)
        self.controller[key] = getattr(agent_controller, model.controller)(**model.controller_args)
        if model.controller == "AttackPlanner":
            self.ATTACK_FREQ = model.controller_args["ATTACK_FREQ"]

    def update_ground(self, ground):
        """Replace the ``(cam_poses, cam_height, _)`` tuple used by :meth:`ground_height`."""
        self.ground = ground

    def update_agent_route(self):
        """Query ``unified_map`` for a route for every agent.

        A found path is stored as a tensor of columns ``[a, b, yaw]``: the first
        two and the last column of the map's path. When no path is found, a
        message is printed and the agent's previous route is left untouched.
        """
        assert self.unified_map is not None, "Map shouldn't be None to forecast agent path"
        for iid, stat in self.stats.items():
            path = self.unified_map.get_route(stat)
            if path is None:
                print(f"path not found for {iid} at ", stat)
                continue
            self.route[iid] = torch.from_numpy(np.hstack([path[:, :2], path[:, -1:]]))

    def ground_height(self, u, v):
        """Estimate the ground height below planar position ``(u, v)``.

        ``cam_poses`` is indexed as a stack of 4x4 matrices; they are presumably
        camera-to-world poses, as the variable names suggest. The last pose is
        excluded from the search. The pose whose translation (x, z) is nearest
        to ``(u, v)`` is selected. The query point ``(u, 0, v)`` is moved into
        that camera's frame, its camera-frame y is set to 0, and it is mapped
        back to world. This projects it onto the plane spanned by the camera's
        local x and z axes through the camera centre.

        Returns:
            The world y of the projected point plus ``cam_height``, as a NumPy
            scalar.

        Raises:
            ValueError: if ``cam_poses`` holds fewer than two poses, because
                there is nothing to search.
        """
        cam_poses, cam_height, _ = self.ground
        # Pure NumPy: avoids mixing torch tensors and ndarrays in one expression.
        cam_poses = np.asarray(cam_poses)
        cam_dist = np.sqrt(
            (cam_poses[:-1, 0, 3] - u) ** 2 + (cam_poses[:-1, 2, 3] - v) ** 2
        )
        nearest_c2w = cam_poses[np.argmin(cam_dist, axis=0)]
        nearest_w2c = np.linalg.inv(nearest_c2w)
        uv_local = nearest_w2c[:3, :3] @ np.array([u, 0, v]) + nearest_w2c[:3, 3]
        uv_local[1] = 0
        uv_world = nearest_c2w[:3, :3] @ uv_local + nearest_c2w[:3, 3]
        return uv_world[1] + cam_height

    @staticmethod
    def _is_replan_step(t, dt, freq):
        """Return True when time ``t`` falls on every ``freq``-th step of length ``dt``.

        ``round(t / dt)`` is used instead of ``t // dt``. Floor division of
        floats under-counts when the quotient lands just below an integer
        (``0.6 // 0.2 == 2.0``), which made re-planning fire on the wrong steps.
        """
        return round(t / dt) % freq == 0

    def plan_traj(self, t, ego_stats):
        """Advance every agent by one step and return their new poses.

        Args:
            t: presumably the simulation time in seconds; together with ``dt``
                and ``ATTACK_FREQ`` it decides when an AttackPlanner re-plans.
            ego_stats: ego state laid out like an agent's ``[a, b, yaw, v]``. It
                is stacked with the agents' states, so it must be a tensor of
                that length (TODO confirm the dtype matches the agents').

        Returns:
            ``[b2ws, {}]``, where ``b2ws`` maps each agent key to a 4x4 float32
            CUDA tensor (body-to-world pose).

        Raises:
            NotImplementedError: for a controller type without a branch below.
        """
        # Row 0 is the ego; the following rows are the agents, as a, b, yaw, v.
        all_stats = torch.stack(
            [ego_stats] + [stat[[0, 1, 3, 4]] for stat in self.stats.values()], dim=0
        )
        future_states = constant_headaway(all_stats, num_steps=self.PREDICT_STEPS, dt=self.dt)
        # Every agent picks neighbours from the positions at the start of this
        # step, even though self.stats entries are replaced inside the loop.
        # Replacing values of existing keys while iterating is safe.
        curr_xy_agents = all_stats[:, :2]
        b2ws = {}
        for iid, stat in self.stats.items():
            distance_agents = torch.norm(curr_xy_agents - stat[:2], dim=-1)
            # Position 0 after sorting is the agent itself (distance 0).
            neighbor_idx = torch.argsort(distance_agents)[1:self.NUM_NEIGHBORS + 1]
            neighbors = future_states[neighbor_idx]
            state = stat[[0, 1, 3, 4]]  # a, b, yaw, v
            controller = self.controller[iid]
            # Exact type checks (not isinstance), so a subclass is never routed
            # into its parent's branch.
            if type(controller) is IDM:
                next_xyrv = controller.update(state=state, path=self.route[iid], dt=self.dt,
                                              neighbors=neighbors)
            elif type(controller) is AttackPlanner:
                # NOTE(review): this drops the nearest neighbour on the assumption
                # that it is the attacked ego (row 0 of future_states). If another
                # agent is closer, the ego stays in the list and that agent is
                # dropped instead. It also passes dt=0.1 rather than self.dt.
                # Confirm both are intended.
                safe_neighbors = neighbors[1:, ...]
                next_xyrv = controller.update(state=state, unified_map=self.unified_map, dt=0.1,
                                              neighbors=safe_neighbors, attacked_states=future_states[0],
                                              new_plan=self._is_replan_step(t, self.dt, self.ATTACK_FREQ))
            elif type(controller) is ConstantPlanner:
                next_xyrv = controller.update(state=state, dt=self.dt)
            elif type(controller) is UnicyclePlanner:
                next_xyrv = controller.update(dt=self.dt)
            else:
                raise NotImplementedError
            next_stat = torch.zeros_like(stat)
            next_stat[[0, 1, 3, 4]] = next_xyrv.float()
            next_stat[2] = stat[2]  # height offset is carried over unchanged
            self.stats[iid] = next_stat
            h = self.ground_height(next_xyrv[0].numpy(), next_xyrv[1].numpy())
            # NOTE(review): the rotation uses the previous yaw (stat[3]) while the
            # translation uses the new position. Confirm this one-step heading
            # lag is intended.
            if type(controller) is UnicyclePlanner:
                yaw_angle = -stat[3]
            else:
                yaw_angle = -stat[3] - np.pi / 2 - self.rectify_angle
            b2w = np.eye(4)
            b2w[:3, :3] = SCR.from_euler('y', [yaw_angle]).as_matrix()
            b2w[:3, 3] = np.array([next_stat[0], h + stat[2], next_stat[1]])
            b2ws[iid] = torch.tensor(b2w).float().cuda()
        return [b2ws, {}]
class UnifiedMap:
    """trajdata-backed vector map of one scene, with local <-> world conversions.

    The simulator's local frame is anchored at the ego's first recorded pose
    (``ego_start_pos``, ``ego_start_yaw``). A local point ``(a, b)`` has ``b``
    along ``ego_start_yaw`` and ``a`` along ``ego_start_yaw - pi/2``. If the
    start yaw is negative, it is shifted by pi and ``rectify_angle`` is set to
    pi so callers can apply the same shift to headings.
    """

    def __init__(self, datapath, version, scene_name, cache_location="/app/app_datas/nusc_map_cache"):
        """Load the dataset, locate ``scene_name`` and fetch its vector map.

        Args:
            datapath: data directory registered for ``version`` in trajdata.
            version: trajdata dataset name, used for ``desired_data`` and as the
                map-name prefix.
            scene_name: name of the scene to simulate. The last scene with this
                name is used.
            cache_location: trajdata cache directory. The default keeps the
                previously hard-coded path.
        """
        self.datapath = datapath
        self.version = version
        self.dataset = UnifiedDataset(
            desired_data=[self.version],
            data_dirs={
                self.version: self.datapath,
            },
            cache_location=cache_location,
            only_types=[AgentType.VEHICLE],
            agent_interaction_distances=defaultdict(lambda: 50.0),
            desired_dt=0.1,
            num_workers=4,
            verbose=True,
        )
        self.map_api = MapAPI(self.dataset.cache_path)
        self.scene = None
        for scene in self.dataset.scenes():
            if scene.name == scene_name:
                self.scene = scene
        assert self.scene is not None, f"Can't find scene {scene_name}"
        self.vector_map = self.map_api.get_map(
            f"{self.version}:{self.scene.location}"
        )
        self.ego_start_pos, self.ego_start_yaw = self.get_start_pose()
        self.rectify_angle = 0
        if self.ego_start_yaw < 0:
            self.ego_start_yaw += np.pi
            self.rectify_angle = np.pi
        self.PATH_LENGTH = 100  # minimum route length for get_route (presumably metres)

    def get_start_pose(self):
        """Return the ego's position (NumPy array) and heading (float) at timestep 0."""
        sim_scene: SimulationScene = SimulationScene(
            env_name=self.version,
            scene_name="sim_scene",
            scene=self.scene,
            dataset=self.dataset,
            init_timestep=0,
            freeze_agents=True,
        )
        obs = sim_scene.reset()
        assert obs.agent_name[0] == 'ego', 'The first agent is not ego'
        # The first ego frame is treated as the origin of the local frame. This
        # assumes the first front-camera pose was also used as the origin.
        ego_start_pos = obs.curr_agent_state.position[0]
        ego_start_yaw = obs.curr_agent_state.heading[0]
        return ego_start_pos.numpy(), ego_start_yaw.item()

    def xyzr_local2world(self, stat):
        """Convert one local state ``(a, b, z, yaw, ...)`` to world ``[x, y, 0, yaw]``.

        Only ``stat[0]``, ``stat[1]`` and ``stat[3]`` are read; the z entry of
        the result is always 0.
        """
        # arctan2 keeps the quadrant, so points behind the ego (b < 0) are not
        # mirrored, and it is defined at the origin. arctan(a / b) had both defects.
        alpha = np.arctan2(stat[0], stat[1])
        beta = self.ego_start_yaw - alpha
        dist = np.linalg.norm(stat[:2])
        world_stat = np.zeros(4)
        world_stat[0] = dist * np.cos(beta) + self.ego_start_pos[0]
        world_stat[1] = dist * np.sin(beta) + self.ego_start_pos[1]
        world_stat[3] = stat[3] + self.ego_start_yaw
        return world_stat

    def batch_xyzr_world2local(self, stat):
        """Convert world rows ``[x, y, z, heading]`` to local ``[a, b, 0, yaw]`` rows.

        This is the inverse of :meth:`xyzr_local2world` for the planar position
        and heading. Column 2 of the result is always 0. ``ego_start_pos`` is
        assumed to hold exactly (x, y) — TODO confirm.
        """
        offset = stat[:, :2] - self.ego_start_pos
        # arctan2 keeps the quadrant for points with a negative x offset.
        beta = np.arctan2(offset[:, 1], offset[:, 0])
        alpha = self.ego_start_yaw - beta
        dist = np.linalg.norm(offset, axis=1)
        local_stat = np.zeros_like(stat)
        local_stat[:, 0] = dist * np.sin(alpha)
        local_stat[:, 1] = dist * np.cos(alpha)
        local_stat[:, 3] = stat[:, 3] - self.ego_start_yaw
        return local_stat

    @staticmethod
    def _polyline_length(xy):
        """Sum of segment lengths of a polyline given as consecutive points."""
        return np.linalg.norm(xy[1:] - xy[:-1], axis=1).sum()

    def get_route(self, stat):
        """Build a local-frame lane path starting at an agent's position.

        Args:
            stat: torch tensor laid out as ``a, b, height, yaw, v``.

        Returns:
            Local rows ``[a, b, 0, yaw]`` along the lane centrelines, or None if
            no lane is found at the agent's position. Starting from the first
            lane found, randomly chosen successor lanes are appended until the
            accumulated world-frame length reaches ``PATH_LENGTH`` or a lane has
            no successors.
        """
        curr_xyzr = self.xyzr_local2world(stat[:4].numpy())
        lanes = self.vector_map.get_current_lane(curr_xyzr)
        if len(lanes) == 0:
            return None
        curr_lane = lanes[0]
        path = self.batch_xyzr_world2local(curr_lane.center.xyzh)
        total_path_length = self._polyline_length(curr_lane.center.xy)
        while total_path_length < self.PATH_LENGTH:
            next_lanes = list(curr_lane.next_lanes)
            if len(next_lanes) == 0:
                break
            next_lane = self.vector_map.get_road_lane(next_lanes[np.random.randint(len(next_lanes))])
            path = np.vstack([path, self.batch_xyzr_world2local(next_lane.center.xyzh)])
            total_path_length += self._polyline_length(next_lane.center.xy)
            curr_lane = next_lane
        return path