# File size: 10,725 Bytes
# 7f3c2df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import numpy as np
import torch
from scipy.spatial.transform import Rotation as SCR
import roma
from collections import namedtuple
from sim.utils.agent_controller import constant_headaway
from sim.utils import agent_controller
from collections import defaultdict
from trajdata import AgentType, UnifiedDataset
from trajdata.maps import MapAPI
from trajdata.simulation import SimulationScene
from sim.utils.sim_utils import rt2pose, pose2rt
from sim.utils.agent_controller import IDM, AttackPlanner, ConstantPlanner, UnicyclePlanner
import os
import json

# Per-agent model description unpacked from the tail of a plan_list entry
# (``args[5:]``): checkpoint directory, controller class name and the keyword
# arguments passed to that controller's constructor.
# The typename matches the bound name so instances pickle and repr correctly.
Model = namedtuple('Model', ['model_path', 'controller', 'controller_args'])


class planner:
    """Step the scripted (non-ego) agents of a simulated scene.

    Every agent is stored under the key ``agent_<index>`` (index = position in
    ``plan_list``) in the following per-agent dictionaries:

    * ``stats``      - 5-element tensor ``[a, b, height, yaw, v]``
    * ``route``      - path tensor filled by ``update_agent_route`` (else None)
    * ``controller`` - controller instance that advances the agent
    * ``ckpts``      - path of the agent's ``gs.pth`` file
    * ``wlhs``       - parsed content of the agent's ``wlh.json`` file
    """

    def __init__(self, plan_list, scene_path=None, dt=0.2, unified_map=None, ground=None):
        """Build controllers and initial states for every entry of ``plan_list``.

        Args:
            plan_list: one sequence per agent laid out as
                ``(a, b, height, yaw, v, model_path, controller, controller_args)``.
            scene_path: directory holding ``unicycle_<uc_id>.pth`` files; required
                as soon as one agent uses ``"UnicyclePlanner"``.
            dt: step length handed to the controllers.
            unified_map: optional map object; its ``rectify_angle`` is added to the
                yaw of non-unicycle agents and it is used for route lookup.
            ground: optional ``(cam_poses, cam_height, _)`` triple, see
                ``ground_height``.

        Raises:
            ValueError: a UnicyclePlanner agent is requested without ``scene_path``.
        """
        self.unified_map = unified_map
        self.ground = ground
        self.PREDICT_STEPS = 20
        self.NUM_NEIGHBORS = 3

        self.rectify_angle = 0
        if self.unified_map is not None:
            self.rectify_angle = self.unified_map.rectify_angle

        # plan_list: a, b, height, yaw, v, model_path, controller, controller_args: dict
        self.stats, self.route, self.controller, self.ckpts, self.wlhs = {}, {}, {}, {}, {}
        self.dt = dt
        self.ATTACK_FREQ = 3
        for iid, args in enumerate(plan_list):
            key = f"agent_{iid}"
            self.route[key] = None
            if args[6] == "UnicyclePlanner":
                if scene_path is None:
                    raise ValueError("scene_path is required for UnicyclePlanner agents")
                self._load_agent_assets(key, args[5])
                uc_configs = args[7]
                self.controller[key] = UnicyclePlanner(
                    os.path.join(scene_path, f"unicycle_{uc_configs['uc_id']}.pth"),
                    speed=uc_configs['speed'],
                )
                # The initial state comes from the unicycle model evaluated at 0.0;
                # only the height is taken from the plan entry.
                a, b, v, _pitchroll, yaw, _h = self.controller[key].uc_model.forward(0.0)
                self.stats[key] = torch.tensor([a, b, args[2], yaw, v])
            else:
                model = Model(*args[5:])
                stat = torch.tensor(args[:5])  # a, b, height, yaw, v
                stat[3] += self.rectify_angle
                self.stats[key] = stat
                self._load_agent_assets(key, model.model_path)
                # The controller class is looked up by name in agent_controller.
                self.controller[key] = getattr(agent_controller, model.controller)(**model.controller_args)
                if model.controller == "AttackPlanner":
                    self.ATTACK_FREQ = model.controller_args["ATTACK_FREQ"]

    def _load_agent_assets(self, key, model_dir):
        """Record the ``gs.pth`` path and load ``wlh.json`` found in ``model_dir``."""
        self.ckpts[key] = os.path.join(model_dir, 'gs.pth')
        with open(os.path.join(model_dir, 'wlh.json'), encoding="utf-8") as f:
            self.wlhs[key] = json.load(f)

    def update_ground(self, ground):
        """Replace the ``(cam_poses, cam_height, _)`` triple used by ``ground_height``."""
        self.ground = ground

    def update_agent_route(self):
        """Query the map for a route for every agent and store it in ``self.route``.

        Each stored route is a tensor made of the first two columns and the last
        column of the path returned by ``unified_map.get_route``. Agents for which
        no path is found keep their previous route and a message is printed.
        """
        assert self.unified_map is not None, "Map shouldn't be None to forecast agent path"
        for iid, stat in self.stats.items():
            path = self.unified_map.get_route(stat)
            if path is None:
                # Report the agent that failed, not the whole state dictionary.
                print(f"path not found for {iid} at ", stat)
                continue
            self.route[iid] = torch.from_numpy(np.hstack([path[:, :2], path[:, -1:]]))

    def ground_height(self, u, v):
        """Return the ground height below the planar position ``(u, v)``.

        ``self.ground`` must be a ``(cam_poses, cam_height, _)`` triple where
        ``cam_poses`` is an array of 4x4 matrices indexed ``[i, row, col]`` (treated
        as camera-to-world transforms). The pose closest to ``(u, v)`` in the
        plane spanned by axes 0 and 2 is selected, the query point is moved into
        that camera frame, its axis-1 component is zeroed, and the point is moved
        back; the resulting axis-1 coordinate plus ``cam_height`` is returned.

        The whole computation is done with NumPy so it does not depend on
        implicit torch/NumPy conversions.

        Raises:
            ValueError: ``self.ground`` has not been set.
        """
        if self.ground is None:
            raise ValueError("ground must be set before querying ground_height")
        cam_poses, cam_height, _ = self.ground
        cam_poses = np.asarray(cam_poses)
        u, v = float(u), float(v)

        # The last pose is excluded from the search, as in the original code
        # (reason not visible here); fall back to all poses if only one exists.
        candidates = cam_poses[:-1] if len(cam_poses) > 1 else cam_poses
        planar_offset = candidates[:, [0, 2], 3] - np.array([u, v])
        nearest_cam_idx = int(np.argmin(np.linalg.norm(planar_offset, axis=1)))
        nearest_c2w = cam_poses[nearest_cam_idx]

        nearest_w2c = np.linalg.inv(nearest_c2w)
        uv_local = nearest_w2c[:3, :3] @ np.array([u, 0.0, v]) + nearest_w2c[:3, 3]
        uv_local[1] = 0
        uv_world = nearest_c2w[:3, :3] @ uv_local + nearest_c2w[:3, 3]

        return uv_world[1] + cam_height

    def plan_traj(self, t, ego_stats):
        """Advance every agent by one step and return their body-to-world poses.

        Args:
            t: current simulation time/step; only used to decide when an
                AttackPlanner should re-plan.
            ego_stats: ego state stacked in front of the agents' ``[a, b, yaw, v]``
                states (so it must have the same layout).

        Returns:
            ``[b2ws, {}]`` where ``b2ws`` maps each agent key to a 4x4 float
            tensor on the CUDA device.

        Raises:
            NotImplementedError: an agent's controller type is not handled.
        """
        all_stats = torch.stack(
            [ego_stats, *(stat[[0, 1, 3, 4]] for stat in self.stats.values())],  # a, b, yaw, v
            dim=0,
        )
        future_states = constant_headaway(all_stats, num_steps=self.PREDICT_STEPS, dt=self.dt)
        curr_xy_agents = all_stats[:, :2]

        b2ws = {}
        for iid, stat in self.stats.items():
            # Find the closest neighbors. Index 0 of the sorted distances is
            # skipped (presumably the agent itself, at distance 0).
            distance_agents = torch.norm(curr_xy_agents - stat[:2], dim=-1)
            neighbor_idx = torch.argsort(distance_agents)[1:self.NUM_NEIGHBORS + 1]
            neighbors = future_states[neighbor_idx]

            controller = self.controller[iid]
            # Exact type checks are kept on purpose: the controller class
            # hierarchy is not visible from this file.
            if type(controller) is IDM:
                next_xyrv = controller.update(state=stat[[0, 1, 3, 4]], path=self.route[iid], dt=self.dt,
                                              neighbors=neighbors)
            elif type(controller) is AttackPlanner:
                # future_states[0] is the rollout of the ego (first stacked row).
                safe_neighbors = neighbors[1:, ...]
                # NOTE(review): dt is fixed to 0.1 here while other controllers
                # use self.dt - confirm this is intended.
                # NOTE(review): float floor division is fragile
                # (e.g. 0.8 // 0.2 == 3.0); confirm how t is produced.
                next_xyrv = controller.update(state=stat[[0, 1, 3, 4]], unified_map=self.unified_map, dt=0.1,
                                              neighbors=safe_neighbors, attacked_states=future_states[0],
                                              new_plan=((t // self.dt) % self.ATTACK_FREQ == 0))
            elif type(controller) is ConstantPlanner:
                next_xyrv = controller.update(state=stat[[0, 1, 3, 4]], dt=self.dt)
            elif type(controller) is UnicyclePlanner:
                next_xyrv = controller.update(dt=self.dt)
            else:
                raise NotImplementedError

            # Height (index 2) is carried over unchanged.
            next_stat = torch.zeros_like(stat)
            next_stat[[0, 1, 3, 4]] = next_xyrv.float()
            next_stat[2] = stat[2]
            self.stats[iid] = next_stat

            b2w = np.eye(4)
            h = self.ground_height(next_xyrv[0].numpy(), next_xyrv[1].numpy())
            # The rotation uses the yaw from before this update, as in the
            # original code; non-unicycle agents get an extra -pi/2 offset and
            # have the map's rectify angle removed again.
            if type(controller) is UnicyclePlanner:
                # b2w[:3, :3] = SCR.from_euler('xzy', [pitch_roll[0], pitch_roll[1], stat[3]]).as_matrix()
                b2w[:3, :3] = SCR.from_euler('y', [-stat[3]]).as_matrix()
            else:
                b2w[:3, :3] = SCR.from_euler('y', [-stat[3] - np.pi / 2 - self.rectify_angle]).as_matrix()
            # Planar (a, b) go to axes 0 and 2; axis 1 is ground height + agent height.
            b2w[:3, 3] = np.array([next_stat[0], h + stat[2], next_stat[1]])
            b2ws[iid] = torch.tensor(b2w).float().cuda()

        return [b2ws, {}]


class UnifiedMap:
    """Vector-map access for one scene plus ego-local <-> world conversions.

    The local frame is anchored at the ego pose of the first frame
    (``ego_start_pos`` / ``ego_start_yaw``). States are laid out as
    ``[x, y, z, yaw]``; the z column is always written as 0 by the converters.
    """

    def __init__(self, datapath, version, scene_name, cache_location="/app/app_datas/nusc_map_cache"):
        """Load the dataset, locate ``scene_name`` and fetch its vector map.

        Args:
            datapath: data directory registered for ``version``.
            version: dataset identifier passed as ``desired_data`` and used as
                the map-name prefix.
            scene_name: name of the scene to simulate.
            cache_location: trajdata cache directory (defaults to the path
                previously hard-coded here).

        Raises:
            AssertionError: no scene named ``scene_name`` exists.
        """
        self.datapath = datapath
        self.version = version

        self.dataset = UnifiedDataset(
            desired_data=[self.version],
            data_dirs={
                self.version: self.datapath,
            },
            cache_location=cache_location,
            only_types=[AgentType.VEHICLE],
            agent_interaction_distances=defaultdict(lambda: 50.0),
            desired_dt=0.1,
            num_workers=4,
            verbose=True,
        )

        self.map_api = MapAPI(self.dataset.cache_path)

        # If several scenes share the name, the last one wins (as before).
        self.scene = None
        for scene in self.dataset.scenes():
            if scene.name == scene_name:
                self.scene = scene
        assert self.scene is not None, f"Can't find scene {scene_name}"
        self.vector_map = self.map_api.get_map(
            f"{self.version}:{self.scene.location}"
        )
        self.ego_start_pos, self.ego_start_yaw = self.get_start_pose()
        # A negative start yaw is shifted by pi; the shift is exposed as
        # rectify_angle so callers can apply or undo it on agent headings.
        self.rectify_angle = 0
        if self.ego_start_yaw < 0:
            self.ego_start_yaw += np.pi
            self.rectify_angle = np.pi
        self.PATH_LENGTH = 100

    def get_start_pose(self):
        """Return ``(position, heading)`` of the ego at the first timestep.

        A frozen-agent ``SimulationScene`` is reset at timestep 0 and the state
        of its first agent, which must be named ``'ego'``, is read.
        """
        sim_scene: SimulationScene = SimulationScene(
            env_name=self.version,
            scene_name="sim_scene",
            scene=self.scene,
            dataset=self.dataset,
            init_timestep=0,
            freeze_agents=True,
        )
        obs = sim_scene.reset()
        assert obs.agent_name[0] == 'ego', 'The first agent is not ego'
        # The position of the first ego frame is taken as the origin.
        # This assumption holds when the first-frame front camera pose is the origin.
        ego_start_pos = obs.curr_agent_state.position[0]
        ego_start_yaw = obs.curr_agent_state.heading[0]
        return ego_start_pos.numpy(), ego_start_yaw.item()

    def xyzr_local2world(self, stat):
        """Convert one local ``[x, y, z, yaw]`` state into world coordinates.

        Args:
            stat: NumPy array with at least four entries.

        Returns:
            A length-4 float array ``[x, y, 0, yaw]`` in the world frame.
        """
        dist = np.linalg.norm(stat[:2])
        # NOTE(review): arctan (not arctan2) ignores the sign of stat[1]; kept
        # as is because it may be tied to rectify_angle - confirm.
        with np.errstate(divide="ignore", invalid="ignore"):
            alpha = np.arctan(stat[0] / stat[1])
        if np.isnan(alpha):
            # 0/0: the point sits on the local origin, any angle gives the
            # same (zero) offset, so avoid propagating NaN.
            alpha = 0.0
        beta = self.ego_start_yaw - alpha

        world_stat = np.zeros(4)
        world_stat[0] = dist * np.cos(beta) + self.ego_start_pos[0]
        world_stat[1] = dist * np.sin(beta) + self.ego_start_pos[1]
        world_stat[3] = stat[3] + self.ego_start_yaw

        return world_stat

    def batch_xyzr_world2local(self, stat):
        """Convert a batch of world ``[x, y, z, yaw]`` rows into the local frame.

        Args:
            stat: array of shape ``(N, 4)`` (columns 0, 1 and 3 are read).

        Returns:
            An array shaped like ``stat`` holding ``[x, y, 0, yaw]`` rows.
        """
        offset = stat[:, :2] - self.ego_start_pos
        dist = np.linalg.norm(offset, axis=1)
        # NOTE(review): same arctan/arctan2 remark as in xyzr_local2world.
        with np.errstate(divide="ignore", invalid="ignore"):
            beta = np.arctan(offset[:, 1] / offset[:, 0])
        # Rows lying exactly on the ego start position yield 0/0; their
        # distance is 0, so any finite angle maps them to the local origin.
        beta = np.where(np.isnan(beta), 0.0, beta)
        alpha = self.ego_start_yaw - beta

        local_stat = np.zeros_like(stat)
        local_stat[:, 0] = dist * np.sin(alpha)
        local_stat[:, 1] = dist * np.cos(alpha)
        local_stat[:, 3] = stat[:, 3] - self.ego_start_yaw

        return local_stat

    def get_route(self, stat):
        """Build a local-frame lane-centre path starting at the agent's lane.

        Args:
            stat: tensor ``[a, b, height, yaw, v]``; the first four entries are
                converted to world coordinates for the lane query.

        Returns:
            Stacked local ``[x, y, 0, yaw]`` rows of the current lane followed
            by randomly chosen successor lanes until the accumulated centre-line
            length reaches ``PATH_LENGTH`` or no successor exists; ``None`` when
            no current lane is found.
        """
        # stat: a, b, height, yaw, v
        curr_xyzr = self.xyzr_local2world(stat[:4].numpy())

        # lanes = self.vector_map.get_current_lane(curr_xyzr, max_dist=5, max_heading_error=np.pi/3)
        lanes = self.vector_map.get_current_lane(curr_xyzr)
        if len(lanes) == 0:
            return None

        curr_lane = lanes[0]
        path = self.batch_xyzr_world2local(curr_lane.center.xyzh)
        total_path_length = np.linalg.norm(curr_lane.center.xy[1:] - curr_lane.center.xy[:-1], axis=1).sum()
        # Randomly pick successor lanes until PATH_LENGTH is reached.
        while total_path_length < self.PATH_LENGTH:
            next_lanes = list(curr_lane.next_lanes)
            if len(next_lanes) == 0:
                break
            next_lane = self.vector_map.get_road_lane(next_lanes[np.random.randint(len(next_lanes))])
            path = np.vstack([path, self.batch_xyzr_world2local(next_lane.center.xyzh)])
            total_path_length += np.linalg.norm(next_lane.center.xy[1:] - next_lane.center.xy[:-1], axis=1).sum()
            curr_lane = next_lane
        return path