|
import time |
|
import numpy as np |
|
import xml.etree.ElementTree as ET |
|
|
|
from jax import jit, vmap |
|
import jax.numpy as jnp |
|
|
|
|
|
def rot_x(o):
    """Build the 3x3 elementary rotation matrix about the x axis.

    Args:
        o: Scalar angle, passed straight to ``jnp.cos`` / ``jnp.sin``
            (so it is interpreted in radians).

    Returns:
        A ``jnp`` array laid out as::

            [[1,       0,      0],
             [0,  cos(o), sin(o)],
             [0, -sin(o), cos(o)]]
    """
    cos_o = jnp.cos(o)
    sin_o = jnp.sin(o)
    rows = [
        [1, 0, 0],
        [0, cos_o, sin_o],
        [0, -sin_o, cos_o],
    ]
    return jnp.array(rows)
|
|
|
|
|
def rot_y(p):
    """Build the 3x3 elementary rotation matrix about the y axis.

    Args:
        p: Scalar angle, passed straight to ``jnp.cos`` / ``jnp.sin``
            (so it is interpreted in radians).

    Returns:
        A ``jnp`` array laid out as::

            [[cos(p), 0, -sin(p)],
             [0,      1,       0],
             [sin(p), 0,  cos(p)]]
    """
    cos_p = jnp.cos(p)
    sin_p = jnp.sin(p)
    rows = [
        [cos_p, 0, -sin_p],
        [0, 1, 0],
        [sin_p, 0, cos_p],
    ]
    return jnp.array(rows)
|
|
|
|
|
def rot_z(k):
    """Build the 3x3 elementary rotation matrix about the z axis.

    Args:
        k: Scalar angle, passed straight to ``jnp.cos`` / ``jnp.sin``
            (so it is interpreted in radians).

    Returns:
        A ``jnp`` array laid out as::

            [[ cos(k), sin(k), 0],
             [-sin(k), cos(k), 0],
             [      0,      0, 1]]
    """
    cos_k = jnp.cos(k)
    sin_k = jnp.sin(k)
    rows = [
        [cos_k, sin_k, 0],
        [-sin_k, cos_k, 0],
        [0, 0, 1],
    ]
    return jnp.array(rows)
|
|
|
|
|
def rot_zyx(o, p, k):
    """Compose the three elementary rotations into one matrix.

    Args:
        o: Angle handed to ``rot_x``.
        p: Angle handed to ``rot_y``.
        k: Angle handed to ``rot_z``.

    Returns:
        The matrix product ``rot_z(k) @ rot_y(p) @ rot_x(o)``.
    """
    about_z = rot_z(k)
    about_y = rot_y(p)
    about_x = rot_x(o)
    # Keep the left-to-right association of the chained ``@`` expression.
    return (about_z @ about_y) @ about_x
|
|
|
|
|
def rms_dict(_s):
    """Rotate the offset between two 3D positions stored in ``_s``.

    Args:
        _s: Mapping providing ``'omega'``, ``'phi'``, ``'kappa'`` (angles for
            ``rot_zyx``), ``'X'``, ``'Y'``, ``'Z'`` (first position) and
            ``'Xs'``, ``'Ys'``, ``'Zs'`` (second position).
            NOTE(review): the key names suggest an object point and a camera
            station with its exterior orientation -- confirm with the caller.

    Returns:
        ``R @ (M - S)`` where ``R = rot_zyx(omega, phi, kappa)``,
        ``M = (X, Y, Z)`` and ``S = (Xs, Ys, Zs)``.
    """
    rotation = rot_zyx(_s['omega'], _s['phi'], _s['kappa'])
    position = jnp.array([_s[key] for key in ('X', 'Y', 'Z')])
    origin = jnp.array([_s[key] for key in ('Xs', 'Ys', 'Zs')])
    return rotation @ (position - origin)
|
|
|
|
|
def xy_frame(_s):
    """Normalise the rotated offset from ``rms_dict`` by its third component.

    Args:
        _s: Mapping forwarded unchanged to ``rms_dict``.

    Returns:
        A tuple ``(x, y, z)`` with ``x = -RMS[0] / RMS[2]``,
        ``y = -RMS[1] / RMS[2]`` and ``z = -RMS[2]``, where ``RMS`` is the
        vector returned by ``rms_dict(_s)``.
    """
    rotated = rms_dict(_s)
    third = rotated[2]
    # The third entry of the scaled vector is always -1 and is discarded.
    x, y, _ = -rotated / third
    return x, y, -third
|
|
|
|
|
def corr_dist_agi(x, y, _s):
    """Apply lens distortion and map normalised coordinates to pixels.

    The model is a Brown-style frame-camera model as documented by Agisoft:
    a radial polynomial in r^2 plus a tangential (decentering) term, followed
    by an affine mapping to pixel coordinates.

    Fix over the previous revision: the higher-order tangential factor
    ``(1 + p3*r^2 + p4*r^4)`` now multiplies the *whole* tangential term.
    Operator precedence used to apply it only to the ``2*p*x*y`` part.
    Results are unchanged when ``p3 == p4 == 0``.

    Args:
        x: Normalised x coordinate(s).
        y: Normalised y coordinate(s).
        _s: Mapping with radial coefficients ``'k1'``..``'k5'``, tangential
            coefficients ``'p1'``..``'p4'``, affinity/skew ``'b1'``, ``'b2'``,
            focal length ``'f'``, principal-point offsets ``'cx'``, ``'cy'``
            and image size ``'width'``, ``'height'``.

    Returns:
        Tuple ``(fx, fy)`` of pixel coordinates:
        ``fx = width/2 + cx + x'*f + x'*b1 + y'*b2`` and
        ``fy = height/2 + cy + y'*f``.
    """
    r2 = x ** 2 + y ** 2

    # Radial distortion: multiplicative polynomial in r^2.
    radial = (
        1
        + _s['k1'] * r2
        + _s['k2'] * r2 ** 2
        + _s['k3'] * r2 ** 3
        + _s['k4'] * r2 ** 4
        + _s['k5'] * r2 ** 5
    )

    # Tangential distortion: the higher-order factor scales the full term.
    tangential_scale = 1 + _s['p3'] * r2 + _s['p4'] * r2 ** 2
    dtx = (_s['p1'] * (r2 + 2 * x ** 2) + 2 * _s['p2'] * x * y) * tangential_scale
    dty = (_s['p2'] * (r2 + 2 * y ** 2) + 2 * _s['p1'] * x * y) * tangential_scale

    xp = x * radial + dtx
    yp = y * radial + dty

    fx = _s['width'] * 0.5 + _s['cx'] + xp * _s['f'] + xp * _s['b1'] + yp * _s['b2']
    fy = _s['height'] * 0.5 + _s['cy'] + yp * _s['f']
    return fx, fy
|
|
|
|
|
def f_frame_agi(_s):
    """Project via ``xy_frame`` and convert to pixels via ``corr_dist_agi``.

    Args:
        _s: Mapping forwarded to ``xy_frame`` and ``corr_dist_agi``; this
            function itself reads ``'width'``, ``'height'`` and ``'f'``.

    Returns:
        Tuple ``(fx, fy, z, ins)``: the pixel coordinates from
        ``corr_dist_agi``, the ``z`` value from ``xy_frame``, and a flag that
        is true when the undistorted ``(x, y)`` lies inside the half-open
        window ``[-width/f/2, width/f/2) x [-height/f/2, height/f/2)`` and
        ``z > 0``.
    """
    x, y, z = xy_frame(_s)

    focal = _s['f']
    half_w = _s['width'] / focal / 2
    half_h = _s['height'] / focal / 2

    # The window test uses y before its sign is flipped below.
    in_x = (x >= -half_w) & (x < half_w)
    in_y = (y >= -half_h) & (y < half_h)
    ins = in_x & in_y & (z > 0)

    fx, fy = corr_dist_agi(x, -y, _s)
    return fx, fy, z, ins
|
|
|
|
|
def read_camera_file(filepath, offset):
    """Read per-photo position and orientation from a whitespace-separated file.

    Lines starting with ``#`` and lines with fewer than 16 fields are skipped.
    For every remaining line, field 0 is the photo id, fields 1-3 are
    coordinates and fields 4-6 are angles in degrees.

    Args:
        filepath: Path of the text file to read.
        offset: Indexable with three numbers; ``offset[0..2]`` are subtracted
            from fields 1-3 respectively.

    Returns:
        Dict mapping photo id to a dict with keys ``"Xs"``, ``"Ys"``,
        ``"Zs"`` (shifted coordinates) and ``"omega"``, ``"phi"``,
        ``"kappa"`` (angles converted to radians with ``np.radians``).
    """
    cameras = {}
    with open(filepath, 'r') as handle:
        for raw_line in handle:
            if raw_line.startswith("#"):
                continue
            fields = raw_line.strip().split()
            if len(fields) < 16:
                continue

            entry = {}
            # Fields 1..3: position, shifted by the matching offset component.
            for axis, key in enumerate(("Xs", "Ys", "Zs")):
                entry[key] = float(fields[1 + axis]) - offset[axis]
            # Fields 4..6: angles, degrees -> radians.
            for column, key in enumerate(("omega", "phi", "kappa"), start=4):
                entry[key] = np.radians(float(fields[column]))

            cameras[fields[0]] = entry
    return cameras
|
|
|
def parse_calibration_xml(file_path):
    """Load camera calibration parameters from an XML file.

    Only direct children of the XML root are inspected. A child whose tag is
    one of the known parameter names overrides the default; every other tag
    (for example ``projection`` or ``date``) is ignored.

    Changes over the previous revision: the unreachable
    ``['projection', 'date']`` branch is gone (those tags are never keys of
    the result, so they were already filtered out), and elements without text
    (e.g. ``<k4/>``) keep their default instead of raising ``TypeError``.

    Args:
        file_path: Anything accepted by ``xml.etree.ElementTree.parse``
            (a path or a file object).

    Returns:
        Dict with integer ``'width'`` and ``'height'`` and float ``'f'``,
        ``'cx'``, ``'cy'``, ``'k1'``..``'k5'``, ``'p1'``..``'p4'``, ``'b1'``,
        ``'b2'``. Parameters absent from the file are ``0`` / ``0.0``.

    Raises:
        ValueError: If a recognised element holds non-numeric text.
    """
    root = ET.parse(file_path).getroot()

    calibration_data = {
        'width': 0,
        'height': 0,
        'f': 0.0,
        'cx': 0.0,
        'cy': 0.0,
        'k1': 0.0,
        'k2': 0.0,
        'k3': 0.0,
        'k4': 0.0,
        'k5': 0.0,
        'p1': 0.0,
        'p2': 0.0,
        'p3': 0.0,
        'p4': 0.0,
        'b1': 0.0,
        'b2': 0.0,
    }
    integer_tags = ('width', 'height')

    for element in root:
        tag = element.tag
        if tag not in calibration_data:
            continue
        text = element.text
        # An empty element carries no value: treat it like a missing one.
        if text is None or not text.strip():
            continue
        calibration_data[tag] = int(text) if tag in integer_tags else float(text)

    return calibration_data
|
|
|
|
|
def get_pixel_values(image, i, j, in_bounds, nb_classes=10):
    """Sample ``image`` at ``(j, i)`` and leave NaN rows where ``in_bounds`` is false.

    The previous revision gathered the pixels twice and wrote them twice;
    the result is identical with a single gather and a single scatter.

    Parameters:
        image (jnp.ndarray): Array indexed as ``image[j, i]``; each sampled
            entry must fit a row of length ``nb_classes`` (expected layout is
            (height, width, nb_classes) -- TODO confirm with the caller).
        i (jnp.ndarray): 1D integer indices along the second axis of ``image``.
        j (jnp.ndarray): 1D integer indices along the first axis of ``image``.
        in_bounds (jnp.ndarray): 1D boolean mask, same length as ``i``;
            only entries where it is true are read from ``image``.
        nb_classes (int): Number of columns of the returned array.

    Returns:
        jnp.ndarray: Array of shape ``(i.shape[0], nb_classes)`` holding the
        sampled values on rows where ``in_bounds`` is true and ``jnp.nan``
        on every other row.

    Note:
        Boolean-mask indexing produces data-dependent shapes, so this
        function has to run on concrete arrays (outside ``jax.jit``).
    """
    sampled = image[j[in_bounds], i[in_bounds]]
    pixel_values = jnp.full((i.shape[0], nb_classes), jnp.nan)
    return pixel_values.at[in_bounds].set(sampled)
|
|
|
|
|
def compute_depth_map(i, j, z, depth_map, buffer_size, threshold=0.05):
    """Splat point depths into a depth map and flag points that stay on top.

    Each point writes its depth into a square window of side
    ``2 * buffer_size + 1`` centred on ``(j, i)``; window indices are clamped
    to the map, and a pixel keeps the minimum of its current value and every
    depth written to it.

    Args:
        i: 1D integer column indices of the points.
        j: 1D integer row indices of the points.
        z: 1D depths, one per point.
        depth_map: 2D array (rows, columns) holding the current depths.
        buffer_size: Half-width of the square window, in pixels.
        threshold: Largest allowed ``|depth_map[j, i] - z|`` for a point to
            count as visible.

    Returns:
        Tuple ``(depth_map, visibility)``: the updated map and a boolean array
        with one entry per point.
    """
    n_rows, n_cols = depth_map.shape

    window = jnp.arange(-buffer_size, buffer_size + 1)
    col_shift, row_shift = jnp.meshgrid(window, window, indexing='ij')
    col_shift = col_shift.reshape(-1)
    row_shift = row_shift.reshape(-1)
    footprint = col_shift.shape[0]

    # One row per point, one column per window cell; flattened for the scatter.
    cols = (i[:, None] + col_shift).clip(0, n_cols - 1).reshape(-1)
    rows = (j[:, None] + row_shift).clip(0, n_rows - 1).reshape(-1)
    depths = jnp.repeat(z[:, None], footprint, axis=1).reshape(-1)

    updated = depth_map.at[rows, cols].min(depths)

    visible = jnp.abs(updated[j, i] - z) <= threshold
    return updated, visible
|
|
|
|
|
def project_classes_into_image(i, j, classes, img_classes, buffer_size):
    """Paint each point's class into a square window of a 2D class image.

    Every point writes its class into the window of side
    ``2 * buffer_size + 1`` centred on ``(j, i)``; window indices are clamped
    to the image. Where windows of different points overlap, the surviving
    value is whichever one ``.at[...].set`` keeps for duplicate indices.

    Args:
        i: 1D integer column indices of the points.
        j: 1D integer row indices of the points.
        classes: 1D class values, one per point.
        img_classes: 2D array (rows, columns) to paint into.
        buffer_size: Half-width of the square window, in pixels.

    Returns:
        The updated class image (same shape as ``img_classes``).
    """
    rows_total, cols_total = img_classes.shape

    shifts = jnp.arange(-buffer_size, buffer_size + 1)
    shift_i, shift_j = (
        grid.ravel() for grid in jnp.meshgrid(shifts, shifts, indexing='ij')
    )
    cells_per_point = len(shift_i)

    # Shape (points, cells) before flattening: one window per point.
    target_cols = (i[:, None] + shift_i).clip(0, cols_total - 1)
    target_rows = (j[:, None] + shift_j).clip(0, rows_total - 1)
    labels = jnp.repeat(classes[:, None], cells_per_point, axis=1)

    return img_classes.at[target_rows.ravel(), target_cols.ravel()].set(labels.ravel())
|
|