Spaces:
Paused
Paused
from logging import raiseExceptions | |
import numpy as np | |
import torch | |
import pdb | |
from ..utils import geometry_utils as GeoUtils | |
import matplotlib.pyplot as plt | |
from scipy.interpolate import interp1d | |
import random | |
from .forward_sampler import ForwardSampler | |
# Columns of a spline trajectory row [x, y, v, a, yaw, r, t] (the order produced by
# compute_spline_xyvaqrt) that together form the planner state [x, y, v, yaw].
STATE_INDEX = [0, 1, 2, 4]

# Torch device string: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def mean_control_effort_coefficients(x0, dx0, xf, dxf):
    """Coefficients of the mean squared acceleration of an interpolating cubic.

    For the cubic that goes from position ``x0`` / velocity ``dx0`` to position
    ``xf`` / velocity ``dxf`` over a horizon ``tf``, the mean of the squared
    acceleration over ``[0, tf]`` equals
    ``c4 * tf**-4 + c3 * tf**-3 + c2 * tf**-2``.

    Returns `(c4, c3, c2)` corresponding to `c4 * tf**-4 + c3 * tf**-3 + c2 * tf**-2`.
    """
    gap = x0 - xf
    c4 = 12 * gap ** 2
    c3 = 12 * (dx0 + dxf) * gap
    # Scaling by 4 is exact in floating point, so factoring it out is bit-identical.
    c2 = 4 * (dx0 ** 2 + dx0 * dxf + dxf ** 2)
    return c4, c3, c2
def cubic_spline_coefficients(x0, dx0, xf, dxf, tf):
    """Coefficients ``(c0, c1, c2, c3)`` of ``c0 + c1*t + c2*t**2 + c3*t**3``.

    The cubic satisfies x(0) = x0, x'(0) = dx0, x(tf) = xf and x'(tf) = dxf.
    Works element-wise on scalars or broadcastable tensors.
    """
    constant_term = x0
    linear_term = dx0
    quadratic_term = -2 * dx0 / tf - dxf / tf - 3 * x0 / tf ** 2 + 3 * xf / tf ** 2
    cubic_term = dx0 / tf ** 2 + dxf / tf ** 2 + 2 * x0 / tf ** 3 - 2 * xf / tf ** 3
    return (constant_term, linear_term, quadratic_term, cubic_term)
def compute_interpolating_spline(state_0, state_f, tf):
    """Fit x(t) and y(t) cubics between two states.

    States are indexed on the last axis as ``[x, y, v, heading]``; the
    boundary velocities are ``v * (cos(heading), sin(heading))``.

    Returns:
        ``(x_coefficients, y_coefficients, tf)`` where each coefficient tensor
        stacks ``(c0, c1, c2, c3)`` on a new last axis and ``tf`` is broadcast
        to the shape of ``state_0[..., 0]``.
    """
    def _velocity_components(state):
        speed, heading = state[..., 2], state[..., 3]
        return speed * torch.cos(heading), speed * torch.sin(heading)

    vx_start, vy_start = _velocity_components(state_0)
    vx_end, vy_end = _velocity_components(state_f)
    tf = tf * torch.ones_like(state_0[..., 0])
    x_coefficients = torch.stack(
        cubic_spline_coefficients(state_0[..., 0], vx_start, state_f[..., 0], vx_end, tf), -1)
    y_coefficients = torch.stack(
        cubic_spline_coefficients(state_0[..., 1], vy_start, state_f[..., 1], vy_end, tf), -1)
    return x_coefficients, y_coefficients, tf
def compute_spline_xyvaqrt(x_coefficients, y_coefficients, tf, N=10):
    """Sample cubic x(t), y(t) splines at ``N`` evenly spaced times in ``[0, tf]``.

    Args:
        x_coefficients, y_coefficients: ``(B, 4)`` cubic coefficients
            ``(c0, c1, c2, c3)``.
        tf: ``(B,)`` horizon of each spline.
        N: number of samples (times are ``k * tf / (N - 1)``, k = 0..N-1).

    Returns:
        ``(B, N, 7)`` tensor with columns ``[x, y, v, a, yaw, r, t]``: position,
        speed, longitudinal acceleration, heading, yaw rate and sample time.
        Speed is clipped to 1e-4 in the denominators of ``a`` and ``r``.
    """
    dev = tf.device
    powers = torch.arange(4).to(dev)
    t = torch.arange(N).unsqueeze(0).to(dev) * tf.unsqueeze(-1) / (N - 1)
    t_col = t[..., None]
    # Rows of [1, t, t^2, t^3] and their first / second time derivatives.
    pos_basis = t_col ** powers
    vel_basis = t_col ** torch.tensor([0, 0, 1, 2]).to(dev) * powers
    acc_basis = t_col ** torch.tensor([0, 0, 0, 1]).to(dev) * torch.tensor([0, 0, 2, 6]).to(dev)

    cx = x_coefficients.unsqueeze(-1)
    cy = y_coefficients.unsqueeze(-1)
    vx, vy = vel_basis @ cx, vel_basis @ cy
    ax, ay = acc_basis @ cx, acc_basis @ cy

    speed = torch.hypot(vx, vy)
    safe_speed = torch.clip(speed, min=1e-4)
    accel = (ax * vx + ay * vy) / safe_speed
    yaw_rate = (-ax * vy + ay * vx) / (safe_speed ** 2)
    yaw = torch.atan2(vy, vx)

    columns = (pos_basis @ cx, pos_basis @ cy, speed, accel, yaw, yaw_rate, t.unsqueeze(-1))
    return torch.cat(columns, -1)
def patch_yaw_low_speed(traj):
    """Unfinished placeholder.

    NOTE(review): the body only indexes ``traj`` with ``(Ellipsis,)`` and then
    returns ``None``; no yaw patching happens despite the name. Implement it or
    remove it before relying on it.
    """
    _unused = traj[...,]
    return None
def mean_control_effort_coefficients(x0, dx0, xf, dxf):
    """Returns `(c4, c3, c2)` corresponding to `c4 * tf**-4 + c3 * tf**-3 + c2 * tf**-2`.

    The polynomial is the mean squared acceleration over ``[0, tf]`` of the
    cubic joining (x0, dx0) to (xf, dxf).

    NOTE(review): this is a verbatim redefinition of the function of the same
    name near the top of the module; being later, it is the binding that
    survives import. One of the two copies can be deleted.
    """
    offset = x0 - xf
    speed_sum = dx0 + dxf
    return (
        12 * offset ** 2,
        12 * speed_sum * offset,
        4 * dx0 ** 2 + 4 * dx0 * dxf + 4 * dxf ** 2,
    )
class SplinePlanner(object):
    """Sample cubic-spline trajectories from a start state to grids of terminal states.

    States are 4-vectors ``[x, y, v, yaw]`` (speed at index 2, heading at
    index 3). Each candidate trajectory is a pair of cubic splines x(t), y(t)
    sampled at ``N`` evenly spaced times over ``[0, tf]``; every sample row has
    the columns ``[x, y, v, a, yaw, r, t]`` produced by ``compute_spline_xyvaqrt``.
    Candidates can be filtered by speed, heading, acceleration, yaw-rate and
    curvature bounds (``feasible_flag``). When the start speed is below
    ``low_speed_threshold``, rollouts from a ``ForwardSampler`` are added.
    """

    def __init__(self, device, dx_grid=None, dy_grid=None, acce_grid=None, dyaw_grid=None, max_steer=0.5, max_rvel=8,
                 acce_bound=[-6, 4], vbound=[-2.0, 30], spline_order=3, N_seg=10, low_speed_threshold=2.0, seed=0):
        """Build the terminal-state grids and the low-speed forward sampler.

        Args:
            device: torch device on which the grids are allocated.
            dx_grid: extra longitudinal offsets added to the travelled distance
                (default ``[0.]``).
            dy_grid: lateral offsets of the terminal point
                (default ``[-3, -1.5, 0, 1.5, 3]``).
            acce_grid: constant accelerations; the terminal speed change is
                ``acce * tf`` (default ``[-1, 0, 1]``).
            dyaw_grid: heading offsets added to the bearing of the terminal
                point (default ``[-pi/6, 0, pi/6]``).
            max_steer: bound on ``|r| / max(|v|, 0.5)`` in ``feasible_flag``.
            max_rvel: bound on ``|r * v|`` in ``feasible_flag``.
            acce_bound: ``[min, max]`` acceleration accepted by ``feasible_flag``;
                ``min`` also limits the hard-brake terminal.
            vbound: ``[min, max)`` speed accepted by ``feasible_flag``.
            spline_order: must be 3 (only cubic splines are implemented).
            N_seg: default number of samples per trajectory.
            low_speed_threshold: start speeds below this add forward-sampler rollouts.
            seed: passed to ``torch.manual_seed``.
        """
        self.spline_order = spline_order
        self.device = device
        assert spline_order == 3, "only cubic splines (spline_order=3) are supported"

        def _grid(values, default):
            # Use the caller's grid when given, otherwise the built-in default.
            return torch.tensor(default if values is None else values).to(self.device)

        self.dx_grid = _grid(dx_grid, [0.])
        self.dy_grid = _grid(dy_grid, [-3., -1.5, 0, 1.5, 3.])
        # Lateral offsets used for lane-following terminals.
        self.dy_grid_lane = torch.tensor([-2., 0, 2.]).to(self.device)
        self.acce_grid = _grid(acce_grid, [-1., 0., 1.])
        self.dyaw_grid = _grid(dyaw_grid, [-np.pi / 6, 0, np.pi / 6])
        self.max_steer = max_steer
        self.max_rvel = max_rvel
        self.psi_bound = [-np.pi * 0.75, np.pi * 0.75]
        # Copy the bounds so instances never share (or alias) the mutable default / caller lists.
        self.acce_bound = list(acce_bound)
        self.vbound = list(vbound)
        self.N_seg = N_seg
        self.low_speed_threshold = low_speed_threshold
        self.forward_sampler = ForwardSampler(acce_grid=self.acce_grid, dhm_grid=torch.linspace(-0.7, 0.7, 9),
                                              dhf_grid=[-0.4, 0, 0.4], dt=0.1, device=self.device)
        # NOTE: this seeds torch's *global* RNG, affecting all later torch random calls.
        torch.manual_seed(seed)

    def calc_trajectories(self, x0, tf, xf, N=None):
        """Fit splines from ``x0`` to every terminal in ``xf`` and sample them.

        Args:
            x0: start state ``(4,)`` (shared by every row of ``xf``) or a tensor
                with the same ndim as ``xf``.
            tf: horizon.
            xf: terminal states ``[..., 4]``.
            N: samples per trajectory (defaults to ``self.N_seg``).

        Returns:
            Samples with columns ``[x, y, v, a, yaw, r, t]`` on the last axis,
            e.g. ``(M, N, 7)`` for ``xf`` of shape ``(M, 4)``.

        Raises:
            ValueError: if ``x0`` is neither 1-D nor of the same ndim as ``xf``.
        """
        if N is None:
            N = self.N_seg
        if x0.ndim == 1:
            x0_tile = x0.tile(xf.shape[0], 1)
            xc, yc, tf = compute_interpolating_spline(x0_tile, xf, tf)
        elif x0.ndim == xf.ndim:
            xc, yc, tf = compute_interpolating_spline(x0, xf, tf)
        else:
            raise ValueError("wrong dimension for x0")
        return compute_spline_xyvaqrt(xc, yc, tf, N)

    @staticmethod
    def _lane_start_to_tensor(p_start, device):
        """Move a lane start pose given as ndarray or tensor onto ``device``; other types pass through."""
        if isinstance(p_start, np.ndarray):
            return torch.from_numpy(p_start).to(device)
        if isinstance(p_start, torch.Tensor):
            return p_start.to(device)
        return p_start

    def gen_terminals_lane(self, x0, tf, lanes):
        """Terminal states that follow lane centerlines, plus lane-free straight terminals.

        Each entry of ``lanes`` is an ``(f, p_start)`` pair: ``f`` maps arc
        lengths (numpy) to ``[x, y, yaw]`` rows and ``p_start`` is the lane start
        pose ``[x, y, yaw, ...]``. For every (dx, dy, acce) grid combination the
        travelled arc length is ``dx + 0.5 * acce * tf**2 + v0 * tf``, measured
        from the projection of ``x0`` onto the lane start heading.

        Falls back to ``gen_terminals`` when ``lanes`` is None or empty.

        Returns:
            ``(M, 4)`` for 1-D ``x0`` or ``(B, M, 4)`` for 2-D ``x0``, rows
            ``[x, y, v, yaw]``.

        Raises:
            ValueError: if ``x0`` is not 1-D or 2-D.
        """
        if lanes is None or len(lanes) == 0:
            return self.gen_terminals(x0, tf)
        gs = [self.dx_grid.shape[0], self.dy_grid_lane.shape[0], self.acce_grid.shape[0]]
        dx = self.dx_grid[:, None, None, None].repeat(1, gs[1], gs[2], 1).flatten()
        dy = self.dy_grid_lane[None, :, None, None].repeat(gs[0], 1, gs[2], 1).flatten()
        dv = self.acce_grid[None, None, :, None].repeat(gs[0], gs[1], 1, 1).flatten() * tf
        xf = list()
        if x0.ndim == 1:
            # Travelled distance does not depend on the lane, so compute it once.
            ds = dx + dv / 2 * tf + x0[2:3] * tf
            for f, p_start in lanes:
                p_start = self._lane_start_to_tensor(p_start, x0.device)
                offset = x0[:2] - p_start[:2]
                s_offset = offset[0] * torch.cos(p_start[2]) + offset[1] * torch.sin(p_start[2])
                ss = ds + s_offset
                xyyaw = torch.from_numpy(f(ss.cpu().numpy())).type(torch.float).to(x0.device)
                # NOTE(review): dy is added to the y coordinate returned by f, not along the lane
                # normal; it is a true lateral offset only where the lane heading is ~0. Confirm.
                xyyaw[..., 1] += dy
                xf.append(torch.cat((xyyaw[:, :2], dv.reshape(-1, 1) + x0[2:3], xyyaw[:, 2:]), -1))
            # Adding the end points not fixated on lane.
            # NOTE(review): (ds, dy) are used directly as x/y without translating by x0[:2] or
            # rotating by x0's heading (unlike gen_terminals); presumes an agent-centric x0.
            xf_straight = torch.stack([ds, dy, dv + x0[2], x0[3].tile(ds.shape[0])], -1)
            xf.append(xf_straight)
        elif x0.ndim == 2:
            ds = (dx + dv / 2 * tf).unsqueeze(0) + x0[:, 2:3] * tf
            for f, p_start in lanes:
                p_start = self._lane_start_to_tensor(p_start, x0.device)
                offset = x0[:, :2] - p_start[None, :2]
                s_offset = offset[:, 0] * torch.cos(p_start[2]) + offset[:, 1] * torch.sin(p_start[2])
                ss = ds + s_offset.unsqueeze(-1)
                xyyaw = torch.from_numpy(f(ss.cpu().numpy())).type(torch.float).to(x0.device)
                # NOTE(review): same global-y lateral offset caveat as the 1-D branch.
                xyyaw[..., 1] += dy
                xf.append(torch.cat((xyyaw[..., :2],
                                     dv.tile(x0.shape[0], 1).unsqueeze(-1) + x0[:, None, 2:3],
                                     xyyaw[..., 2:]), -1))
            # Adding the end points not fixated on lane (same frame caveat as the 1-D branch).
            xf_straight = torch.stack([ds, dy.tile(x0.shape[0], 1), dv.tile(x0.shape[0], 1) + x0[:, None, 2],
                                       x0[:, None, 3].tile(1, ds.shape[1])], -1)
            xf.append(xf_straight)
        else:
            raise ValueError("x0 must have dimension 1 or 2")
        return torch.cat(xf, -2)

    def gen_terminals(self, x0, tf):
        """Terminal states ``[x, y, v, yaw]`` for every (dx, dy, acce, dyaw) grid combination.

        The offset ``(dx + 0.5 * acce * tf**2 + v0 * tf, dy)`` is rotated by the
        start yaw via ``GeoUtils.batch_rotate_2D`` and translated by the start
        position. Terminal speed is ``v0 + acce * tf``; terminal yaw is ``dyaw``
        plus the bearing of the rotated offset.
        NOTE(review): when the travelled distance is negative (strong braking)
        that bearing points backwards, flipping the target heading by ~pi.

        Returns:
            ``(M, 4)`` for 1-D ``x0`` or ``(B, M, 4)`` for 2-D ``x0``.

        Raises:
            ValueError: if ``x0`` is not 1-D or 2-D.
        """
        gs = [self.dx_grid.shape[0], self.dy_grid.shape[0], self.acce_grid.shape[0], self.dyaw_grid.shape[0]]
        dx = self.dx_grid[:, None, None, None].repeat(1, gs[1], gs[2], gs[3]).flatten()
        dy = self.dy_grid[None, :, None, None].repeat(gs[0], 1, gs[2], gs[3]).flatten()
        dv = tf * self.acce_grid[None, None, :, None].repeat(gs[0], gs[1], 1, gs[3]).flatten()
        dyaw = self.dyaw_grid[None, None, None, :].repeat(gs[0], gs[1], gs[2], 1).flatten()
        delta_x = torch.stack([dx, dy, dv, dyaw], -1)
        if x0.ndim == 1:
            xy = torch.cat((delta_x[:, 0:1] + delta_x[:, 2:3] / 2 * tf + x0[2:3] * tf, delta_x[:, 1:2]), -1)
            rotated_delta_xy = GeoUtils.batch_rotate_2D(xy, x0[3])
            refpsi = torch.arctan2(rotated_delta_xy[..., 1], rotated_delta_xy[..., 0])
            rotated_xy = rotated_delta_xy + x0[:2]
            return torch.cat((rotated_xy, delta_x[:, 2:3] + x0[2:3], delta_x[:, 3:] + refpsi.unsqueeze(-1)), -1)
        elif x0.ndim == 2:
            delta_x = torch.tile(delta_x, [x0.shape[0], 1, 1])
            xy = torch.cat(
                (delta_x[:, :, 0:1] + delta_x[:, :, 2:3] / 2 * tf + x0[:, None, 2:3] * tf, delta_x[:, :, 1:2]), -1)
            rotated_delta_xy = GeoUtils.batch_rotate_2D(xy, x0[:, 3:4])
            refpsi = torch.arctan2(rotated_delta_xy[..., 1], rotated_delta_xy[..., 0])
            rotated_xy = rotated_delta_xy + x0[:, None, :2]
            return torch.cat(
                (rotated_xy, delta_x[:, :, 2:3] + x0[:, None, 2:3], delta_x[:, :, 3:] + refpsi.unsqueeze(-1)), -1)
        else:
            raise ValueError("x0 must have dimension 1 or 2")

    def feasible_flag(self, traj, xf):
        """Boolean mask of trajectories that satisfy all bounds and reach their terminal.

        Args:
            traj: ``(M, N, 7)`` samples ``[x, y, v, a, yaw, r, t]``.
            xf: ``(M, 4)`` terminal states ``[x, y, v, yaw]``.

        Returns:
            ``(M,)`` bool tensor: speed in ``vbound``, yaw in ``psi_bound``,
            acceleration in ``acce_bound``, ``|r * v| <= max_rvel`` and
            ``|r| <= max(|v|, 0.5) * max_steer`` at every sample, and the last
            sample within 5e-3 of ``xf`` in ``[x, y, v, yaw]``.
        """
        diff = traj[..., -1, STATE_INDEX] - xf
        # Sample yaw comes from atan2 (in [-pi, pi]) while terminal yaws may lie outside that
        # range; compare headings modulo 2*pi. Column 3 of diff is yaw.
        diff[..., 3] = torch.remainder(diff[..., 3] + np.pi, 2 * np.pi) - np.pi
        feas_flag = ((traj[..., 2] >= self.vbound[0]) & (traj[..., 2] < self.vbound[1]) &
                     (traj[..., 4] >= self.psi_bound[0]) & (traj[..., 4] < self.psi_bound[1]) &
                     (traj[..., 3] >= self.acce_bound[0]) & (traj[..., 3] <= self.acce_bound[1]) &
                     (torch.abs(traj[..., 5] * traj[..., 2]) <= self.max_rvel) &
                     (torch.clip(torch.abs(traj[..., 2]), min=0.5) * self.max_steer >= torch.abs(traj[..., 5]))
                     ).all(1) & (diff.abs() < 5e-3).all(-1)
        return feas_flag

    def _forward_sampler_trajs(self, x0_batch, tf, t_query):
        """Roll out ``self.forward_sampler`` from ``x0_batch`` and linearly resample it at ``t_query``.

        The rollout's k-th step (k = 1..T) is taken to be at time ``k * dt``;
        queries outside that range are linearly extrapolated. Returns a CPU
        tensor built from the numpy interpolation result.
        """
        dt = self.forward_sampler.dt
        # round() instead of int(): tf / dt is often a hair below an integer (e.g. 0.7 / 0.1).
        n_steps = int(round(float(tf) / dt))
        extra_traj = self.forward_sampler.sample_trajs(x0_batch, n_steps)
        f = interp1d(np.arange(1, extra_traj.shape[-2] + 1) * dt, extra_traj.cpu().numpy(),
                     axis=-2, bounds_error=False, fill_value="extrapolate")
        return torch.from_numpy(f(t_query))

    def gen_trajectories(self, x0, tf, lanes=None, dyn_filter=True, N=None, lane_only=False):
        """Generate candidate trajectories from a single start state.

        Args:
            x0: start state ``(4,)`` ``[x, y, v, yaw]``; not modified.
            tf: horizon.
            lanes: optional lanes (tensor or sequence) converted with
                ``GeoUtils.interp_lanes`` to lane-following terminals.
            dyn_filter: drop trajectories rejected by ``feasible_flag``.
            N: samples per spline (defaults to ``self.N_seg``).
            lane_only: use only lane terminals (requires ``lanes``).

        Returns:
            ``(traj, terminals)``: ``traj`` holds the spline samples without the
            t=0 row, i.e. times ``k * tf / (N - 1)`` for k = 1..N-1, plus — when the
            start speed is below ``low_speed_threshold`` — forward-sampler rollouts
            resampled on the same times (assumed to share the 7-column layout);
            ``terminals`` is ``traj[..., -1, STATE_INDEX]``.
        """
        if N is None:
            N = self.N_seg
        if lanes is not None:
            if isinstance(lanes, torch.Tensor):
                lanes = lanes.cpu().numpy()
            lane_interp = [GeoUtils.interp_lanes(lane) for lane in lanes]
            xf_lane = self.gen_terminals_lane(x0, tf, lane_interp)
        else:
            xf_lane = None
        if lane_only:
            assert xf_lane is not None, "lane_only=True requires lanes"
            xf_set = xf_lane
        else:
            xf_set = self.gen_terminals(x0, tf)
            if xf_lane is not None:
                xf_set = torch.cat((xf_lane, xf_set), 0)
        # Work on a copy so the speed clip does not write into the caller's tensor.
        x0 = x0.clone()
        x0[..., 2] = torch.clip(x0[..., 2], min=1e-3)
        xf_set[..., 2] = torch.clip(xf_set[..., 2], min=1e-3)
        # x, y, v, a, yaw, r, t
        traj = self.calc_trajectories(x0, tf, xf_set, N)
        if dyn_filter:
            traj = traj[self.feasible_flag(traj, xf_set)]
        traj = traj[..., 1:, :]  # remove the first time step
        if x0[2] < self.low_speed_threshold:
            # Call the forward sampler when velocity is low; resample it on the spline time grid.
            t_query = np.arange(1, N) * tf / (N - 1)
            extra_traj = self._forward_sampler_trajs(x0.unsqueeze(0), tf, t_query).squeeze(0)
            traj = torch.cat((traj, extra_traj.to(device=traj.device, dtype=traj.dtype)), 0)
        return traj, traj[..., -1, STATE_INDEX]

    @staticmethod
    def get_similarity_flag(x0, x1, thres=(2.0, 0.5, 2.0, np.pi / 12)):
        """For each state in ``x1``, whether any state in ``x0`` is within ``thres``.

        Similarity means ``|x0 - x1| < thres`` holds element-wise on the last
        axis. With ``x0`` of shape ``(A, D)`` and ``x1`` of shape ``(B, D)`` the
        result is a ``(B,)`` bool tensor.
        """
        thres = torch.tensor(thres, device=x0.device)
        diff = x0.unsqueeze(-3) - x1.unsqueeze(-2)
        flag = diff.abs() < thres
        return flag.all(-1).any(-1)

    def gen_terminals_hardcoded(self, x0_set, tf):
        """Two fixed terminals per start state: keep speed straight ahead, and brake.

        The brake terminal uses the deceleration that stops in ``tf``, limited
        by ``acce_bound[0]``. Returns ``(K, 2, 4)`` for ``x0_set`` of shape ``(K, 4)``.
        """
        X0, Y0, v0, psi0 = x0_set[..., 0:1], x0_set[..., 1:2], x0_set[..., 2:3], x0_set[..., 3:]
        xf_set = list()
        # drive straight
        xf_straight = torch.cat((X0 + v0 * tf * torch.cos(psi0), Y0 + v0 * tf * torch.sin(psi0), v0, psi0),
                                -1).unsqueeze(1)
        xf_set.append(xf_straight)
        # hard brake
        decel = torch.clip(-v0 / tf, min=self.acce_bound[0])
        xf_brake = torch.cat((X0 + (v0 + decel * 0.5 * tf) * tf * torch.cos(psi0),
                              Y0 + (v0 + decel * 0.5 * tf) * tf * torch.sin(psi0), v0 + decel * tf, psi0),
                             -1).unsqueeze(1)
        xf_set.append(xf_brake)
        return torch.cat(xf_set, 1)

    def gen_trajectory_batch(self, x0_set, tf, lanes=None, dyn_filter=True, N=None, max_children=None):
        """Generate and select child trajectories for each of ``K`` start states.

        Candidate terminals get importance scores: uniform random for grid
        terminals, 2 for hard-coded ones, 1 for lane terminals, uniform random
        for forward-sampler rollouts (added for every node when any start speed
        is below ``low_speed_threshold``). Infeasible candidates are scored 0
        and dropped; with ``max_children`` only the top-scoring ones are kept.

        Args:
            x0_set: ``(K, 4)`` start states; not modified.

        Returns:
            List of ``K`` tensors ``(n_i, N - 1, 7)``: selected trajectories with
            the t=0 sample removed.
        """
        if N is None:
            N = self.N_seg
        device = x0_set.device
        xf_set_sample = self.gen_terminals(x0_set, tf)
        importance_score = torch.rand(xf_set_sample.shape[:2], device=device)
        xf_set_hardcoded = self.gen_terminals_hardcoded(x0_set, tf)
        xf_set = torch.cat((xf_set_sample, xf_set_hardcoded), 1)
        importance_score = torch.cat((importance_score, 2 * torch.ones(xf_set_hardcoded.shape[:2], device=device)), 1)
        if lanes is not None:
            # Same lane preprocessing as gen_trajectories.
            if isinstance(lanes, torch.Tensor):
                lanes = lanes.cpu().numpy()
            lane_interp = [GeoUtils.interp_lanes(lane) for lane in lanes]
            xf_set_lane = self.gen_terminals_lane(x0_set, tf, lane_interp)
            xf_set = torch.cat((xf_set, xf_set_lane), -2)
            importance_score = torch.cat((importance_score, torch.ones(xf_set_lane.shape[:2], device=device)), 1)
        # Work on a copy so the speed clip does not write into the caller's tensor.
        x0_set = x0_set.clone()
        x0_set[..., 2] = torch.clip(x0_set[..., 2], min=1e-3)
        xf_set[..., 2] = torch.clip(xf_set[..., 2], min=1e-3)

        num_node = x0_set.shape[0]
        num = xf_set.shape[1]
        # Row j of the flattened terminals belongs to node j // num, hence repeat_interleave.
        x0_tiled = x0_set.repeat_interleave(num, 0)
        xf_tiled = xf_set.reshape(-1, xf_set.shape[-1])
        traj = self.calc_trajectories(x0_tiled, tf, xf_tiled, N)
        if dyn_filter:
            feas_flag = self.feasible_flag(traj, xf_tiled)
        else:
            feas_flag = torch.ones(num * num_node, dtype=torch.bool, device=device)
        feas_flag = feas_flag.reshape(num_node, num)
        traj = traj.reshape(num_node, num, *traj.shape[1:])
        if (x0_set[:, 2] < self.low_speed_threshold).any():
            # Resample rollouts on the spline time grid k * tf / (N - 1), k = 0..N-1.
            t_query = np.arange(0, N) * tf / (N - 1)
            extra_traj = self._forward_sampler_trajs(x0_set, tf, t_query)
            extra_traj = extra_traj.to(device=traj.device, dtype=traj.dtype)
            traj = torch.cat((traj, extra_traj), 1)
            extra_importance_score = torch.rand(extra_traj.shape[:2], device=device)
            importance_score = torch.cat((importance_score, extra_importance_score), 1)
            feas_flag = torch.cat((feas_flag, torch.ones(extra_traj.shape[:2], dtype=torch.bool, device=device)), 1)
        importance_score = importance_score * feas_flag
        chosen_idx = [torch.where(importance_score[i])[0].tolist() for i in range(num_node)]
        if max_children is not None:
            chosen_idx = [idx if len(idx) <= max_children else torch.topk(importance_score[i], max_children)[1]
                          for i, idx in enumerate(chosen_idx)]
        return [traj[i, chosen_idx[i], 1:] for i in range(num_node)]

    def gen_trajectory_tree(self, x0, tf, n_layers, dyn_filter=True, N=None):
        """Expand a tree of spline trajectories ``n_layers`` deep from ``x0`` (shape ``(4,)``).

        The terminals reached in layer ``i`` become the start nodes of layer
        ``i + 1``.

        Returns:
            ``(trajs, nodes)``: one ``(M_i, N, 7)`` trajectory tensor and one
            ``(M_i, 4)`` terminal tensor per layer, row-aligned.
        """
        if N is None:
            N = self.N_seg
        trajs = list()
        nodes = [x0[None, :]]
        for i in range(n_layers):
            xf = self.gen_terminals(nodes[i], tf)
            # xf is (num_nodes, M, 4); flattened row-major its rows are grouped by parent node,
            # so each parent must repeat M times consecutively (torch.tile would interleave them).
            x0i = nodes[i].repeat_interleave(xf.shape[1], 0)
            xf = xf.reshape(-1, xf.shape[-1])
            traj = self.calc_trajectories(x0i, tf, xf, N)
            if dyn_filter:
                feas_flag = self.feasible_flag(traj, xf)
                traj = traj[feas_flag]
                xf = xf[feas_flag]
            trajs.append(traj)
            nodes.append(xf.reshape(-1, xf.shape[-1]))
        return trajs, nodes[1:]
if __name__ == "__main__":
    # Demo: sample trajectories from one start state and plot a two-layer trajectory tree.
    # Uses the module-level `device` so the demo also runs on machines without CUDA.
    planner = SplinePlanner(device)
    x0 = torch.tensor([1., 2., 1., 0.]).to(device)
    tf = 5
    traj, xf = planner.gen_trajectories(x0, tf)
    # Exercise the batched expansion from the terminals reached above.
    batch_trajs = planner.gen_trajectory_batch(xf, tf)
    # Trajectory columns: x, y, v, a, yaw, r, t
    msize = 12
    trajs, nodes = planner.gen_trajectory_tree(x0, tf, 2)
    x0 = x0.cpu().numpy()
    plt.figure(figsize=(20, 10))
    plt.plot(x0[0], x0[1], marker="o", color="b", markersize=msize)
    for node, layer_traj in zip(nodes, trajs):
        node = node.cpu().numpy()
        layer_traj = layer_traj.cpu().numpy()
        x = layer_traj[..., 0]
        y = layer_traj[..., 1]
        plt.plot(x.T, y.T, color="k")
        for p in node:
            plt.plot(p[0], p[1], marker="o", color="b", markersize=msize)
    plt.show()