hyzhou404's picture
private scenes
7f3c2df
from logging import raiseExceptions
import numpy as np
import torch
import pdb
from ..utils import geometry_utils as GeoUtils
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import random
from .forward_sampler import ForwardSampler
# Columns selected from a trajectory row laid out as (x, y, v, a, yaw, r, t)
# to recover a 4-element planner state (x, y, v, yaw).
STATE_INDEX = [0, 1, 2, 4]
# Module-level default torch device string: CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def mean_control_effort_coefficients(x0, dx0, xf, dxf):
    """Return the coefficients of the mean control effort as a polynomial in 1/tf.

    The result `(c4, c3, c2)` corresponds to
    `c4 * tf**-4 + c3 * tf**-3 + c2 * tf**-2`.

    Args:
        x0: start position.
        dx0: start velocity.
        xf: end position.
        dxf: end velocity.
    """
    displacement = x0 - xf
    c4 = 12 * displacement ** 2
    c3 = 12 * (dx0 + dxf) * displacement
    c2 = 4 * dx0 ** 2 + 4 * dx0 * dxf + 4 * dxf ** 2
    return c4, c3, c2
def cubic_spline_coefficients(x0, dx0, xf, dxf, tf):
    """Return `(c0, c1, c2, c3)` of the cubic `c0 + c1*t + c2*t**2 + c3*t**3`.

    The cubic takes value `x0` and slope `dx0` at `t = 0`, and value `xf`
    and slope `dxf` at `t = tf`.
    """
    tf_sq = tf ** 2
    tf_cu = tf ** 3
    c0 = x0
    c1 = dx0
    c2 = -2 * dx0 / tf - dxf / tf - 3 * x0 / tf_sq + 3 * xf / tf_sq
    c3 = dx0 / tf_sq + dxf / tf_sq + 2 * x0 / tf_cu - 2 * xf / tf_cu
    return c0, c1, c2, c3
def compute_interpolating_spline(state_0, state_f, tf):
    """Fit per-axis cubic splines between two batches of states.

    The last axis of each state is indexed as position x (0), position y (1),
    speed (2) and heading (3); speed and heading are converted into x/y
    velocity components that serve as the spline end slopes.

    Args:
        state_0: start states.
        state_f: end states.
        tf: duration of the spline; broadcast to one value per state.

    Returns:
        `(x_coefficients, y_coefficients, tf)`, where each coefficient tensor
        stacks the four cubic coefficients along a new last axis and `tf` has
        the shape of `state_0[..., 0]`.
    """

    def _velocity_components(state):
        # Project the scalar speed onto the x and y axes using the heading.
        speed, heading = state[..., 2], state[..., 3]
        return speed * torch.cos(heading), speed * torch.sin(heading)

    vx_start, vy_start = _velocity_components(state_0)
    vx_end, vy_end = _velocity_components(state_f)
    duration = tf * torch.ones_like(state_0[..., 0])

    x_coeffs = torch.stack(
        cubic_spline_coefficients(state_0[..., 0], vx_start, state_f[..., 0], vx_end, duration), -1)
    y_coeffs = torch.stack(
        cubic_spline_coefficients(state_0[..., 1], vy_start, state_f[..., 1], vy_end, duration), -1)
    return x_coeffs, y_coeffs, duration
def compute_spline_xyvaqrt(x_coefficients, y_coefficients, tf, N=10):
    """Sample cubic x/y splines at `N` evenly spaced times in `[0, tf]`.

    Args:
        x_coefficients: cubic coefficients `(c0, c1, c2, c3)` of x(t) on the last axis.
        y_coefficients: cubic coefficients of y(t), same layout.
        tf: spline duration per batch element.
        N: number of samples, including both end points.

    Returns:
        Tensor whose last axis holds `(x, y, v, a, yaw, r, t)`: position,
        speed, acceleration along the velocity direction, heading, yaw rate
        and sample time.
    """
    dev = tf.device
    sample_idx = torch.arange(N, device=dev).unsqueeze(0)
    t = sample_idx * tf.unsqueeze(-1) / (N - 1)
    t_col = t[..., None]

    # Monomial bases for the polynomial and its first and second derivatives.
    pos_basis = t_col ** torch.arange(4, device=dev)
    vel_basis = t_col ** torch.tensor([0, 0, 1, 2], device=dev) * torch.arange(4, device=dev)
    acc_basis = t_col ** torch.tensor([0, 0, 0, 1], device=dev) * torch.tensor([0, 0, 2, 6], device=dev)

    cx = x_coefficients.unsqueeze(-1)
    cy = y_coefficients.unsqueeze(-1)

    vel_x = vel_basis @ cx
    vel_y = vel_basis @ cy
    speed = torch.hypot(vel_x, vel_y)
    # Floor the speed used as a divisor so a stationary sample does not divide by zero.
    safe_speed = torch.clip(speed, min=1e-4)

    acc_x = acc_basis @ cx
    acc_y = acc_basis @ cy
    accel = (acc_x * vel_x + acc_y * vel_y) / safe_speed
    yaw_rate = (-acc_x * vel_y + acc_y * vel_x) / (safe_speed ** 2)
    heading = torch.atan2(vel_y, vel_x)

    columns = (
        pos_basis @ cx,
        pos_basis @ cy,
        speed,
        accel,
        heading,
        yaw_rate,
        t_col,
    )
    return torch.cat(columns, -1)
def patch_yaw_low_speed(traj):
    """Unfinished placeholder: indexes `traj` and returns None.

    The indexing result is never used; no yaw patching is implemented here.
    """
    _unused = traj[(Ellipsis,)]
def mean_control_effort_coefficients(x0, dx0, xf, dxf):
    """Return `(c4, c3, c2)` corresponding to `c4 * tf**-4 + c3 * tf**-3 + c2 * tf**-2`.

    NOTE: this re-defines a function of the same name and formula that
    appears earlier in the file.
    """
    gap = x0 - xf
    slope_sum = dx0 + dxf
    quartic = 12 * gap ** 2
    cubic = 12 * slope_sum * gap
    quadratic = 4 * dx0 ** 2 + 4 * dx0 * dxf + 4 * dxf ** 2
    return (quartic, cubic, quadratic)
class SplinePlanner(object):
    """Sample candidate trajectories by fitting cubic splines to terminal states.

    States are 4-vectors `(x, y, v, yaw)`. Trajectories carry the seven
    columns `(x, y, v, a, yaw, r, t)` produced by `compute_spline_xyvaqrt`;
    `STATE_INDEX` maps a trajectory row back to a state.
    """

    def __init__(self, device, dx_grid=None, dy_grid=None, acce_grid=None, dyaw_grid=None, max_steer=0.5, max_rvel=8,
                 acce_bound=(-6, 4), vbound=(-2.0, 30), spline_order=3, N_seg=10, low_speed_threshold=2.0, seed=0):
        """Store the sampling grids and dynamic limits.

        Args:
            device: torch device on which the grids are placed.
            dx_grid: longitudinal terminal offsets; defaults to `[0.]`.
            dy_grid: lateral terminal offsets; defaults to `[-3, -1.5, 0, 1.5, 3]`.
            acce_grid: sampled accelerations; defaults to `[-1, 0, 1]`.
            dyaw_grid: sampled heading offsets; defaults to `[-pi/6, 0, pi/6]`.
            max_steer: bound on yaw rate relative to (floored) speed.
            max_rvel: bound on `|yaw rate * speed|`.
            acce_bound: `(min, max)` acceleration.
            vbound: `(min, max)` speed.
            spline_order: must be 3.
            N_seg: default number of samples per trajectory.
            low_speed_threshold: speed below which the forward sampler is also used.
            seed: passed to `torch.manual_seed` (this seeds torch's global RNG).
        """
        self.spline_order = spline_order
        self.device = device
        assert spline_order == 3

        self.dx_grid = torch.tensor([0.] if dx_grid is None else dx_grid).to(self.device)
        self.dy_grid = torch.tensor([-3., -1.5, 0, 1.5, 3.] if dy_grid is None else dy_grid).to(self.device)
        # Lateral offsets used when terminals are placed along lanes.
        self.dy_grid_lane = torch.tensor([-2., 0, 2., ]).to(self.device)
        self.acce_grid = torch.tensor([-1., 0., 1.] if acce_grid is None else acce_grid).to(self.device)
        self.dyaw_grid = torch.tensor(
            [-np.pi / 6, 0, np.pi / 6] if dyaw_grid is None else dyaw_grid).to(self.device)

        self.max_steer = max_steer
        self.max_rvel = max_rvel
        self.psi_bound = [-np.pi * 0.75, np.pi * 0.75]
        # Copy into fresh lists so instances never share (or alias) the defaults.
        self.acce_bound = list(acce_bound)
        self.vbound = list(vbound)
        self.N_seg = N_seg
        self.low_speed_threshold = low_speed_threshold
        self.forward_sampler = ForwardSampler(acce_grid=self.acce_grid, dhm_grid=torch.linspace(-0.7, 0.7, 9),
                                              dhf_grid=[-0.4, 0, 0.4], dt=0.1, device=self.device)
        torch.manual_seed(seed)

    @staticmethod
    def _pair_with_terminals(x0_set, xf_set):
        """Flatten `(nodes, terminals, state)` terminals and repeat each parent to match.

        Returns `(x0_tiled, xf_flat)` where row `k` of `x0_tiled` is the parent
        of row `k` of `xf_flat`. `repeat_interleave` is required here: the
        reshape groups terminals by parent, so parents must be repeated
        consecutively rather than cycled.
        """
        x0_tiled = x0_set.repeat_interleave(xf_set.shape[1], 0)
        xf_flat = xf_set.reshape(-1, xf_set.shape[-1])
        return x0_tiled, xf_flat

    def calc_trajectories(self, x0, tf, xf, N=None):
        """Return spline trajectories from `x0` to every terminal in `xf`.

        Args:
            x0: a single state (1-D) or one state per terminal (same rank as `xf`).
            tf: trajectory duration.
            xf: terminal states.
            N: samples per trajectory; defaults to `self.N_seg`.

        Raises:
            ValueError: if `x0` is neither 1-D nor of the same rank as `xf`.
        """
        if N is None:
            N = self.N_seg
        if x0.ndim == 1:
            x0_tile = x0.tile(xf.shape[0], 1)
            xc, yc, tf = compute_interpolating_spline(x0_tile, xf, tf)
        elif x0.ndim == xf.ndim:
            xc, yc, tf = compute_interpolating_spline(x0, xf, tf)
        else:
            raise ValueError("wrong dimension for x0")
        return compute_spline_xyvaqrt(xc, yc, tf, N)

    def gen_terminals_lane(self, x0, tf, lanes):
        """Generate terminal states placed along interpolated lanes.

        Args:
            x0: start state, 1-D or 2-D (batch of states).
            tf: horizon.
            lanes: iterable of `(f, p_start)` pairs; `f` maps arc lengths to
                `(x, y, yaw)` rows and `p_start` is the lane start pose.
                Falls back to `gen_terminals` when None or empty.

        Raises:
            ValueError: if `x0` is neither 1-D nor 2-D.
        """
        if lanes is None or len(lanes) == 0:
            return self.gen_terminals(x0, tf)

        gs = [self.dx_grid.shape[0], self.dy_grid_lane.shape[0], self.acce_grid.shape[0]]
        dx = self.dx_grid[:, None, None, None].repeat(1, gs[1], gs[2], 1).flatten()
        dy = self.dy_grid_lane[None, :, None, None].repeat(gs[0], 1, gs[2], 1).flatten()
        dv = self.acce_grid[None, None, :, None].repeat(
            gs[0], gs[1], 1, 1).flatten() * tf

        xf = list()
        if x0.ndim == 1:
            for lane in lanes:
                f, p_start = lane
                if isinstance(p_start, np.ndarray):
                    p_start = torch.from_numpy(p_start).to(x0.device)
                elif isinstance(p_start, torch.Tensor):
                    p_start = p_start.to(x0.device)
                # Project the offset from the lane start onto the lane start heading.
                offset = x0[:2] - p_start[:2]
                s_offset = offset[0] * \
                    torch.cos(p_start[2]) + offset[1] * torch.sin(p_start[2])
                ds = dx + dv / 2 * tf + x0[2:3] * tf
                ss = ds + s_offset
                xyyaw = torch.from_numpy(f(ss.cpu().numpy())).type(
                    torch.float).to(x0.device)
                xyyaw[..., 1] += dy
                xf.append(
                    torch.cat((xyyaw[:, :2], dv.reshape(-1, 1) + x0[2:3], xyyaw[:, 2:]), -1))
            # adding the end points not fixated on lane
            # NOTE(review): assumed to run once after the lane loop (`ds` does
            # not depend on the lane) -- confirm against the original layout.
            xf_straight = torch.stack([ds, dy, dv + x0[2], x0[3].tile(ds.shape[0])], -1)
            xf.append(xf_straight)
        elif x0.ndim == 2:
            for lane in lanes:
                f, p_start = lane
                if isinstance(p_start, np.ndarray):
                    p_start = torch.from_numpy(p_start).to(x0.device)
                elif isinstance(p_start, torch.Tensor):
                    p_start = p_start.to(x0.device)
                offset = x0[:, :2] - p_start[None, :2]
                s_offset = offset[:, 0] * torch.cos(p_start[2]) + offset[:, 1] * torch.sin(p_start[2])
                ds = (dx + dv / 2 * tf).unsqueeze(0) + x0[:, 2:3] * tf
                ss = ds + s_offset.unsqueeze(-1)
                xyyaw = torch.from_numpy(f(ss.cpu().numpy())).type(
                    torch.float).to(x0.device)
                xyyaw[..., 1] += dy
                xf.append(torch.cat((xyyaw[..., :2], dv.tile(
                    x0.shape[0], 1).unsqueeze(-1) + x0[:, None, 2:3], xyyaw[..., 2:]), -1))
            # adding the end points not fixated on lane
            # NOTE(review): assumed to run once after the lane loop -- confirm.
            xf_straight = torch.stack([ds, dy.tile(x0.shape[0], 1), dv.tile(x0.shape[0], 1) + x0[:, None, 2],
                                       x0[:, None, 3].tile(1, ds.shape[1])], -1)
            xf.append(xf_straight)
        else:
            raise ValueError("x0 must have dimension 1 or 2")
        return torch.cat(xf, -2)

    def gen_terminals(self, x0, tf):
        """Generate terminal states on the `(dx, dy, acceleration, dyaw)` grid.

        Offsets are built in the frame of `x0` (longitudinal travel under
        constant acceleration plus `dx`, lateral `dy`), rotated by the heading
        of `x0` and translated to its position.

        Raises:
            ValueError: if `x0` is neither 1-D nor 2-D.
        """
        gs = [self.dx_grid.shape[0], self.dy_grid.shape[0],
              self.acce_grid.shape[0], self.dyaw_grid.shape[0]]
        dx = self.dx_grid[:, None, None, None].repeat(
            1, gs[1], gs[2], gs[3]).flatten()
        dy = self.dy_grid[None, :, None, None].repeat(
            gs[0], 1, gs[2], gs[3]).flatten()
        dv = tf * self.acce_grid[None, None, :,
                                 None].repeat(gs[0], gs[1], 1, gs[3]).flatten()
        dyaw = self.dyaw_grid[None, None, None, :].repeat(
            gs[0], gs[1], gs[2], 1).flatten()
        delta_x = torch.stack([dx, dy, dv, dyaw], -1)

        if x0.ndim == 1:
            xy = torch.cat(
                (delta_x[:, 0:1] + delta_x[:, 2:3] / 2 * tf + x0[2:3] * tf, delta_x[:, 1:2]), -1)
            rotated_delta_xy = GeoUtils.batch_rotate_2D(xy, x0[3])
            # Terminal heading is the bearing of the displacement plus the sampled offset.
            refpsi = torch.arctan2(rotated_delta_xy[..., 1], rotated_delta_xy[..., 0])
            rotated_xy = rotated_delta_xy + x0[:2]
            return torch.cat((rotated_xy, delta_x[:, 2:3] + x0[2:3], delta_x[:, 3:] + refpsi.unsqueeze(-1)), -1)
        elif x0.ndim == 2:
            delta_x = torch.tile(delta_x, [x0.shape[0], 1, 1])
            xy = torch.cat(
                (delta_x[:, :, 0:1] + delta_x[:, :, 2:3] / 2 * tf + x0[:, None, 2:3] * tf, delta_x[:, :, 1:2]), -1)
            rotated_delta_xy = GeoUtils.batch_rotate_2D(xy, x0[:, 3:4])
            refpsi = torch.arctan2(rotated_delta_xy[..., 1], rotated_delta_xy[..., 0])
            rotated_xy = rotated_delta_xy + x0[:, None, :2]
            return torch.cat(
                (rotated_xy, delta_x[:, :, 2:3] + x0[:, None, 2:3], delta_x[:, :, 3:] + refpsi.unsqueeze(-1)), -1)
        else:
            raise ValueError("x0 must have dimension 1 or 2")

    def feasible_flag(self, traj, xf):
        """Return a boolean mask of trajectories that respect all limits.

        A trajectory passes when every sample is within the speed, heading,
        acceleration, `|r * v|` and steering bounds, and its last sample
        matches the requested terminal `xf` to within 5e-3.
        """
        diff = traj[..., -1, STATE_INDEX] - xf
        speed = traj[..., 2]
        accel = traj[..., 3]
        yaw = traj[..., 4]
        yaw_rate = traj[..., 5]
        within_limits = (
            (speed >= self.vbound[0]) & (speed < self.vbound[1])
            & (yaw >= self.psi_bound[0]) & (yaw < self.psi_bound[1])
            & (accel >= self.acce_bound[0]) & (accel <= self.acce_bound[1])
            & (torch.abs(yaw_rate * speed) <= self.max_rvel)
            # Speed is floored at 0.5 so the steering bound does not collapse near standstill.
            & (torch.clip(torch.abs(speed), min=0.5) * self.max_steer >= torch.abs(yaw_rate))
        ).all(1)
        reaches_terminal = (diff.abs() < 5e-3).all(-1)
        return within_limits & reaches_terminal

    def gen_trajectories(self, x0, tf, lanes=None, dyn_filter=True, N=None, lane_only=False):
        """Sample trajectories starting at `x0`.

        Args:
            x0: start state; the low-speed branch reads `x0[2]`, so a 1-D
                state is expected. The caller's tensor is not modified.
            tf: horizon.
            lanes: optional lanes used to place additional terminals.
            dyn_filter: drop trajectories rejected by `feasible_flag`.
            N: samples per trajectory; defaults to `self.N_seg`.
            lane_only: use lane terminals only (requires `lanes`).

        Returns:
            `(traj, xf)`: trajectories without their first time step, and the
            state at the last step of each trajectory.
        """
        if N is None:
            N = self.N_seg
        if lanes is not None:
            if isinstance(lanes, torch.Tensor):
                lanes = lanes.cpu().numpy()
            lane_interp = [GeoUtils.interp_lanes(lane) for lane in lanes]
            xf_lane = self.gen_terminals_lane(
                x0, tf, lane_interp)
        else:
            xf_lane = None

        if lane_only:
            assert xf_lane is not None
            xf_set = xf_lane
        else:
            xf_set = self.gen_terminals(x0, tf)
            if xf_lane is not None:
                xf_set = torch.cat((xf_lane, xf_set), 0)

        # Clamp speeds away from zero on a copy so the caller's tensor is left untouched.
        x0 = x0.clone()
        x0[..., 2] = torch.clip(x0[..., 2], min=1e-3)
        xf_set[..., 2] = torch.clip(xf_set[..., 2], min=1e-3)

        # x, y, v, a, yaw,r, t
        traj = self.calc_trajectories(x0, tf, xf_set, N)
        if dyn_filter:
            feas_flag = self.feasible_flag(traj, xf_set)
            traj = traj[feas_flag]
        traj = traj[..., 1:, :]  # remove the first time step

        if x0[2] < self.low_speed_threshold:
            # call forward sampler when velocity is low
            extra_traj = self.forward_sampler.sample_trajs(x0.unsqueeze(0), int(tf / self.forward_sampler.dt)).squeeze(
                0)
            # Resample the forward-sampled trajectories onto the spline time grid.
            f = interp1d(np.arange(1, extra_traj.shape[-2] + 1) * self.forward_sampler.dt, extra_traj.cpu().numpy(),
                         axis=-2)
            extra_traj = torch.from_numpy(f(np.arange(1, N) * tf / N)).to(self.device)
            traj = torch.cat((traj, extra_traj), 0)
        return traj, traj[..., -1, STATE_INDEX]

    @staticmethod
    def get_similarity_flag(x0, x1, thres=(2.0, 0.5, 2.0, np.pi / 12)):
        """Flag each entry of `x1` that is close to at least one entry of `x0`.

        Two states are close when every component differs by less than the
        matching entry of `thres`.
        """
        thres = torch.tensor(thres, device=x0.device)
        diff = x0.unsqueeze(-3) - x1.unsqueeze(-2)
        flag = diff.abs() < thres
        return flag.all(-1).any(-1)

    def gen_terminals_hardcoded(self, x0_set, tf):
        """Return two fixed terminals per state: drive straight and hard brake.

        The brake deceleration is `-v0 / tf`, limited below by
        `self.acce_bound[0]`.
        """
        X0, Y0, v0, psi0 = x0_set[..., 0:1], x0_set[..., 1:2], x0_set[..., 2:3], x0_set[..., 3:]
        xf_set = list()
        # drive straight
        xf_straight = torch.cat((X0 + v0 * tf * torch.cos(psi0), Y0 + v0 * tf * torch.sin(psi0), v0, psi0),
                                -1).unsqueeze(1)
        xf_set.append(xf_straight)
        # hard brake
        decel = torch.clip(-v0 / tf, min=self.acce_bound[0])
        xf_brake = torch.cat((X0 + (v0 + decel * 0.5 * tf) * tf * torch.cos(psi0),
                              Y0 + (v0 + decel * 0.5 * tf) * tf * torch.sin(psi0), v0 + decel * tf, psi0),
                             -1).unsqueeze(1)
        xf_set.append(xf_brake)
        return torch.cat(xf_set, 1)

    def gen_trajectory_batch(self, x0_set, tf, lanes=None, dyn_filter=True, N=None, max_children=None):
        """Sample child trajectories for every state in `x0_set`.

        Candidates get an importance score: random in [0, 1) for grid and
        forward-sampled candidates, 1 for lane terminals and 2 for the
        hard-coded ones. Infeasible candidates are zeroed out, and at most
        `max_children` highest-scoring candidates are kept per node.

        Args:
            x0_set: batch of start states (2-D). Not modified.
            tf: horizon.
            lanes: optional lanes used to place additional terminals.
            dyn_filter: apply `feasible_flag` to the spline candidates.
            N: samples per trajectory; defaults to `self.N_seg`.
            max_children: optional cap on trajectories kept per node.

        Returns:
            A list with one trajectory tensor per node, first time step removed.
        """
        if N is None:
            N = self.N_seg
        device = x0_set.device

        xf_set_sample = self.gen_terminals(x0_set, tf)
        importance_score = torch.rand(xf_set_sample.shape[:2], device=device)
        xf_set_hardcoded = self.gen_terminals_hardcoded(x0_set, tf)
        xf_set = torch.cat((xf_set_sample, xf_set_hardcoded), 1)
        importance_score = torch.cat((importance_score, 2 * torch.ones(xf_set_hardcoded.shape[:2], device=device)), 1)
        if lanes is not None:
            lane_interp = [GeoUtils.interp_lanes(lane) for lane in lanes]
            xf_set_lane = self.gen_terminals_lane(x0_set, tf, lane_interp)
            xf_set = torch.cat((xf_set, xf_set_lane), -2)
            importance_score = torch.cat((importance_score, torch.ones(xf_set_lane.shape[:2], device=x0_set.device)), 1)

        # Clamp speeds away from zero on a copy so the caller's tensor is left untouched.
        x0_set = x0_set.clone()
        x0_set[..., 2] = torch.clip(x0_set[..., 2], min=1e-3)
        xf_set[..., 2] = torch.clip(xf_set[..., 2], min=1e-3)

        num_node = x0_set.shape[0]
        num = xf_set.shape[1]
        x0_tiled, xf_tiled = self._pair_with_terminals(x0_set, xf_set)
        traj = self.calc_trajectories(x0_tiled, tf, xf_tiled, N)
        if dyn_filter:
            feas_flag = self.feasible_flag(traj, xf_tiled)
        else:
            feas_flag = torch.ones(
                num * num_node, dtype=torch.bool).to(x0_set.device)
        feas_flag = feas_flag.reshape(num_node, num)
        traj = traj.reshape(num_node, num, *traj.shape[1:])

        if (x0_set[:, 2] < self.low_speed_threshold).any():
            # Add forward-sampled candidates when any node is slow.
            extra_traj = self.forward_sampler.sample_trajs(x0_set, int(tf / self.forward_sampler.dt))
            f = interp1d(np.arange(1, extra_traj.shape[-2] + 1) * self.forward_sampler.dt, extra_traj.cpu().numpy(),
                         axis=-2, bounds_error=False, fill_value="extrapolate")
            extra_traj = torch.from_numpy(f(np.arange(0, N) * tf / N)).to(self.device)
            traj = torch.cat((traj, extra_traj), 1)
            extra_importance_score = torch.rand(extra_traj.shape[:2], device=device)
            importance_score = torch.cat((importance_score, extra_importance_score), 1)
            # Forward-sampled candidates are always treated as feasible; keep the mask boolean.
            extra_feas = torch.ones(extra_traj.shape[:2], dtype=torch.bool, device=device)
            feas_flag = torch.cat((feas_flag, extra_feas), 1)

        importance_score = importance_score * feas_flag
        chosen_idx = [torch.where(importance_score[i])[0].tolist() for i in range(num_node)]
        if max_children is not None:
            chosen_idx = [idx if len(idx) <= max_children else torch.topk(importance_score[i], max_children)[1] for
                          i, idx in enumerate(chosen_idx)]
        traj_batch = [traj[i, chosen_idx[i], 1:] for i in range(num_node)]
        return traj_batch

    def gen_trajectory_tree(self, x0, tf, n_layers, dyn_filter=True, N=None):
        """Expand a tree of trajectories for `n_layers` layers from state `x0`.

        Each layer generates grid terminals for all current nodes, fits a
        spline from every node to its own terminals, and (optionally) keeps
        the feasible ones; the surviving terminals become the next nodes.

        Returns:
            `(trajs, nodes)`: per-layer trajectory tensors and per-layer
            terminal states.
        """
        if N is None:
            N = self.N_seg
        trajs = list()
        nodes = [x0[None, :]]
        for i in range(n_layers):
            xf = self.gen_terminals(nodes[i], tf)
            # Each parent must be repeated consecutively to line up with its
            # own terminals once `xf` is flattened.
            x0i, xf = self._pair_with_terminals(nodes[i], xf)
            traj = self.calc_trajectories(x0i, tf, xf, N)
            if dyn_filter:
                feas_flag = self.feasible_flag(traj, xf)
                traj = traj[feas_flag]
                xf = xf[feas_flag]
            trajs.append(traj)
            nodes.append(xf.reshape(-1, xf.shape[-1]))
        return trajs, nodes[1:]
if __name__ == "__main__":
    # Visual smoke test: sample trajectories from a single state and plot a
    # two-layer trajectory tree. Uses the module-level `device` so the demo
    # also runs on machines without CUDA.
    planner = SplinePlanner(device)
    x0 = torch.tensor([1., 2., 1., 0.], device=device)
    tf = 5
    traj, xf = planner.gen_trajectories(x0, tf)
    trajs = planner.gen_trajectory_batch(xf, tf)
    # x, y, v, a, yaw,r, t = traj
    msize = 12
    trajs, nodes = planner.gen_trajectory_tree(x0, tf, 2)

    x0 = x0.cpu().numpy()
    plt.figure(figsize=(20, 10))
    plt.plot(x0[0], x0[1], marker="o", color="b", markersize=msize)
    for layer_nodes, layer_traj in zip(nodes, trajs):
        layer_nodes = layer_nodes.cpu().numpy()
        layer_traj = layer_traj.cpu().numpy()
        x = layer_traj[..., 0]
        y = layer_traj[..., 1]
        # One black line per trajectory, one blue marker per terminal node.
        plt.plot(x.T, y.T, color="k")
        for p in layer_nodes:
            plt.plot(p[0], p[1], marker="o", color="b", markersize=msize)
    plt.show()