import io
import tempfile
import zipfile
from collections import defaultdict
from typing import Tuple, List, Dict, Any

import cv2
import numpy as np
import pycolmap
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN

from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping


class Config:
    """Configuration for the wireframe extraction pipeline."""
    EDGE_THRESHOLD = 15.5        # px: max vertex-to-edge-line distance in the image
    MERGE_THRESHOLD = 0.65       # world units: 3D distance below which vertices merge
    PRUNE_DISTANCE = 3.0         # world units: max distance to the nearest SfM point
    SEARCH_RADIUS = 14           # px: window half-width for sparse-depth lookups
    MIN_EDGE_LENGTH = 0.15       # world units: shorter 3D edges are discarded
    MAX_EDGE_LENGTH = 27.0       # world units: longer 3D edges are discarded
    MORPHOLOGY_KERNEL = 3        # px: closing kernel size for edge-class masks
    REFINEMENT_CLUSTER_EPS = 0.45    # world units: DBSCAN eps for vertex snapping
    REFINEMENT_MIN_SAMPLES = 1       # DBSCAN min_samples (1 => every point clusters)


def empty_solution() -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """Returns a minimal placeholder wireframe (two zero vertices, one edge)."""
    return np.zeros((2, 3)), [(0, 1)]


def get_vertices_and_edges_from_segmentation(
    gest_seg_np: np.ndarray, edge_th: float = Config.EDGE_THRESHOLD
) -> Tuple[List[Dict[str, Any]], List[Tuple[int, int]]]:
    """
    Detects roof vertices and edges from a Gestalt segmentation mask.

    Args:
        gest_seg_np: Segmentation mask as an (H, W, 3) numpy array.
        edge_th: Max pixel distance for associating an edge line with a vertex.

    Returns:
        vertices: List of detected vertices with 2D coordinates and type.
        connections: List of edge connections (vertex index pairs).
    """
    vertices, connections = [], []
    if not isinstance(gest_seg_np, np.ndarray):
        gest_seg_np = np.array(gest_seg_np)

    # Vertex detection: each connected component of an apex/eave colour
    # contributes one vertex at the component's centroid.
    for v_type in ("apex", "eave_end_point"):
        color = np.array(gestalt_color_mapping[v_type])
        mask = cv2.inRange(gest_seg_np, color - 0.5, color + 0.5)
        if mask.sum() == 0:
            continue

        numLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
            mask, 8, cv2.CV_32S
        )
        for i in range(1, numLabels):  # label 0 is the background
            mask_i = (labels == i).astype(np.uint8)
            M = cv2.moments(mask_i)
            if M["m00"] > 0:
                cx, cy = M["m10"] / M["m00"], M["m01"] / M["m00"]
            else:
                ys, xs = np.where(mask_i)
                if len(xs) > 0:
                    cx, cy = np.mean(xs), np.mean(ys)
                else:
                    continue
            vertices.append({"xy": np.array([cx, cy]), "type": v_type})

    if not vertices:
        return [], []

    apex_pts = np.array([v['xy'] for v in vertices])
    if len(apex_pts) < 2:
        # A single vertex cannot form an edge.
        return vertices, connections

    # Edge detection: fit a line segment to each connected component of an
    # edge-class colour, then connect every vertex pair lying near it.
    edge_classes = ['eave', 'ridge', 'rake', 'valley']
    for edge_class in edge_classes:
        if edge_class not in gestalt_color_mapping:
            continue
        edge_color = np.array(gestalt_color_mapping[edge_class])
        mask_raw = cv2.inRange(gest_seg_np, edge_color - 0.5, edge_color + 0.5)

        # Close small gaps in the mask, then remove isolated speckles.
        kernel = np.ones((Config.MORPHOLOGY_KERNEL, Config.MORPHOLOGY_KERNEL), np.uint8)
        mask = cv2.morphologyEx(mask_raw, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8))

        if mask.sum() == 0:
            continue

        numLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
            mask, 8, cv2.CV_32S
        )
        for lbl in range(1, numLabels):
            ys, xs = np.where(labels == lbl)
            if len(xs) < 2:
                continue
            pts_for_fit = np.column_stack([xs, ys]).astype(np.float32)
            line_params = cv2.fitLine(
                pts_for_fit, distType=cv2.DIST_L2, param=0, reps=0.01, aeps=0.01
            )
            vx, vy, x0, y0 = line_params.ravel()
            # Project the component pixels onto the fitted line to recover
            # the segment endpoints.
            proj = (xs - x0) * vx + (ys - y0) * vy
            p1 = np.array([x0 + proj.min() * vx, y0 + proj.min() * vy])
            p2 = np.array([x0 + proj.max() * vx, y0 + proj.max() * vy])
            dists = np.array(
                [point_to_segment_dist(apex_pts[i], p1, p2) for i in range(len(apex_pts))]
            )
            near_indices = np.where(dists <= edge_th)[0]
            if len(near_indices) < 2:
                continue
            for i in range(len(near_indices)):
                for j in range(i + 1, len(near_indices)):
                    conn = tuple(sorted((int(near_indices[i]), int(near_indices[j]))))
                    if conn not in connections:
                        connections.append(conn)
    return vertices, connections
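

# Smoke-test sketch for get_vertices_and_edges_from_segmentation on a
# synthetic mask (illustrative only; assumes 'apex' and 'ridge' are keys in
# gestalt_color_mapping, so no real segmentation model is needed). Two apex
# blobs joined by a ridge stripe should yield two vertices and one edge:
#
#   img = np.zeros((64, 64, 3), np.uint8)
#   img[10:14, 10:14] = gestalt_color_mapping['apex']
#   img[10:14, 50:54] = gestalt_color_mapping['apex']
#   img[11:13, 14:50] = gestalt_color_mapping['ridge']
#   verts, conns = get_vertices_and_edges_from_segmentation(img)
#   # Expected: len(verts) == 2 and conns == [(0, 1)]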


def get_uv_depth(
    vertices: List[Dict[str, Any]],
    depth_fitted: np.ndarray,
    sparse_depth: np.ndarray,
    search_radius: int = Config.SEARCH_RADIUS,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Assigns a depth to each vertex via a Gaussian-weighted average of valid
    sparse-depth samples around it, falling back to the dense depth map.

    Args:
        vertices: List of detected vertices.
        depth_fitted: Dense depth map (already scale-fitted).
        sparse_depth: Sparse depth map from COLMAP projections.
        search_radius: Half-width in pixels of the search window.

    Returns:
        uv: 2D coordinates of the vertices.
        vertex_depth: Depth value for each vertex.
    """
    uv = np.array([vert['xy'] for vert in vertices], dtype=np.float32)
    uv_int = np.round(uv).astype(np.int32)
    H, W = depth_fitted.shape[:2]
    uv_int[:, 0] = np.clip(uv_int[:, 0], 0, W - 1)
    uv_int[:, 1] = np.clip(uv_int[:, 1], 0, H - 1)
    vertex_depth = np.zeros(len(vertices), dtype=np.float32)

    for i, (x_i, y_i) in enumerate(uv_int):
        x0, x1 = max(0, x_i - search_radius), min(W, x_i + search_radius + 1)
        y0, y1 = max(0, y_i - search_radius), min(H, y_i + search_radius + 1)
        region = sparse_depth[y0:y1, x0:x1]
        valid_y, valid_x = np.where(region > 0)

        if valid_y.size > 0:
            # Gaussian weighting with sigma = search_radius / 3, so samples
            # near the window border contribute almost nothing.
            global_x, global_y = x0 + valid_x, y0 + valid_y
            dist_sq = (global_x - x_i) ** 2 + (global_y - y_i) ** 2
            weights = np.exp(-dist_sq / (2 * (search_radius / 3) ** 2))
            vertex_depth[i] = np.sum(weights * region[valid_y, valid_x]) / np.sum(weights)
        else:
            # No sparse samples nearby: fall back to the dense depth map.
            vertex_depth[i] = depth_fitted[y_i, x_i]

    return uv, vertex_depth
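

# Isolated behaviour check for get_uv_depth (synthetic 5x5 maps, hypothetical
# values). With a single valid sparse sample at the query pixel, the weighted
# average reduces to that sample:
#
#   dense = np.full((5, 5), 2.0, np.float32)
#   sparse = np.zeros((5, 5), np.float32)
#   sparse[2, 2] = 3.0
#   uv, d = get_uv_depth([{'xy': np.array([2.0, 2.0])}], dense, sparse,
#                        search_radius=2)
#   # d[0] == 3.0; with sparse all zeros it would fall back to dense, 2.0.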


def read_colmap_rec(colmap_data: bytes) -> pycolmap.Reconstruction:
    """Reads a COLMAP reconstruction from a zipped binary archive."""
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(io.BytesIO(colmap_data), "r") as zf:
            zf.extractall(tmpdir)
        # pycolmap loads the model into memory, so the temporary directory
        # can be discarded as soon as this block exits.
        rec = pycolmap.Reconstruction(tmpdir)
    return rec
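

# Usage sketch (hypothetical file path; in the challenge data the archive
# arrives as raw bytes under the 'colmap_binary' key of an entry):
#
#   with open("colmap_rec.zip", "rb") as f:
#       rec = read_colmap_rec(f.read())
#   print(rec.summary())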


def convert_entry_to_human_readable(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Converts a raw dataset entry to a human-readable format."""
    out = {}
    for k, v in entry.items():
        if k == 'colmap_binary':
            out[k] = read_colmap_rec(v)
        elif k in ['K', 'R', 't']:
            try:
                out[k] = np.array(v)
            except ValueError:
                out[k] = v
        else:
            out[k] = v
    out['__key__'] = entry.get('order_id', 'unknown_id')
    return out


def get_house_mask(ade20k_seg: np.ndarray) -> np.ndarray:
    """Returns a boolean mask of house/building regions from an ADE20K segmentation."""
    house_classes_ade20k = [
        'wall', 'house', 'building;edifice', 'door;double;door', 'windowpane;window'
    ]
    np_seg = np.array(ade20k_seg)
    full_mask = np.zeros(np_seg.shape[:2], dtype=bool)
    for c in house_classes_ade20k:
        if c in ade20k_color_mapping:
            color = np.array(ade20k_color_mapping[c])
            mask = cv2.inRange(np_seg, color - 0.5, color + 0.5)
            full_mask |= mask > 0
    return full_mask
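

# Quick check on a synthetic ADE20K-style image (assumes 'wall' is a key in
# ade20k_color_mapping, as the class list above implies):
#
#   seg = np.zeros((32, 32, 3), np.uint8)
#   seg[4:12, 4:12] = ade20k_color_mapping['wall']
#   m = get_house_mask(seg)
#   # m[5, 5] is True, m[0, 0] is False.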


def point_to_segment_dist(pt: np.ndarray, seg_p1: np.ndarray, seg_p2: np.ndarray) -> float:
    """Computes the Euclidean distance from a point to a line segment."""
    if np.allclose(seg_p1, seg_p2):
        # Degenerate segment: distance to the single endpoint.
        return float(np.linalg.norm(pt - seg_p1))
    seg_vec = seg_p2 - seg_p1
    pt_vec = pt - seg_p1
    seg_len2 = seg_vec.dot(seg_vec)
    # Clamp the projection parameter to [0, 1] so the foot of the
    # perpendicular stays on the segment.
    t = max(0.0, min(1.0, pt_vec.dot(seg_vec) / seg_len2))
    proj = seg_p1 + t * seg_vec
    return float(np.linalg.norm(pt - proj))
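

# Sanity check with hand-picked values: the point (0, 1) sits one unit above
# the midpoint of the segment from (-1, 0) to (1, 0):
#
#   >>> point_to_segment_dist(np.array([0.0, 1.0]),
#   ...                       np.array([-1.0, 0.0]), np.array([1.0, 0.0]))
#   1.0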


def get_sparse_depth(
    colmap_rec: pycolmap.Reconstruction, img_id_substring: str, depth: np.ndarray
) -> Tuple[np.ndarray, bool, Any]:
    """Projects COLMAP 3D points into the image to create a sparse depth map."""
    H, W = depth.shape
    found_img = None
    for col_img in colmap_rec.images.values():
        if img_id_substring in col_img.name:
            found_img = col_img
            break
    if found_img is None:
        return np.zeros((H, W), dtype=np.float32), False, None
    # Keep only the 3D points actually observed in this image.
    points_xyz = [
        p3D.xyz for pid, p3D in colmap_rec.points3D.items() if found_img.has_point3D(pid)
    ]
    if not points_xyz:
        return np.zeros((H, W), dtype=np.float32), False, found_img
    points_xyz = np.array(points_xyz)
    # 4x4 world-to-camera transform (cam_from_world.matrix() is 3x4).
    mat4x4 = np.eye(4)
    mat4x4[:3, :4] = found_img.cam_from_world.matrix()
    uv, z_vals = [], []
    for xyz in points_xyz:
        proj = found_img.project_point(xyz)
        if proj is not None:
            u_i, v_i = int(round(proj[0])), int(round(proj[1]))
            if 0 <= u_i < W and 0 <= v_i < H:
                uv.append((u_i, v_i))
                # Transform into camera coordinates to read the depth
                # (the homogeneous w stays 1 for a rigid transform).
                p_cam = mat4x4 @ np.array([xyz[0], xyz[1], xyz[2], 1.0])
                z_vals.append(p_cam[2] / p_cam[3])
    uv, z_vals = np.array(uv, dtype=int), np.array(z_vals)
    depth_out = np.zeros((H, W), dtype=np.float32)
    if len(uv) > 0:
        depth_out[uv[:, 1], uv[:, 0]] = z_vals
    return depth_out, True, found_img


def fit_scale_robust_median(
    depth: np.ndarray, sparse_depth: np.ndarray, validity_mask: np.ndarray = None
) -> Tuple[float, np.ndarray]:
    """Fits a scale factor between dense and sparse depth using the median ratio."""
    # Restrict the fit to plausible depths (0.1 to 50 world units) present
    # in both maps.
    mask = (sparse_depth > 0.1) & (depth > 0.1) & (sparse_depth < 50) & (depth < 50)
    if validity_mask is not None:
        mask &= validity_mask
    X, Y = depth[mask], sparse_depth[mask]
    if len(X) < 5:
        # Too few correspondences for a meaningful fit.
        return 1.0, depth
    ratios = Y / X
    alpha = np.median(ratios)
    return alpha, alpha * depth
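

# Synthetic scale-recovery check (hypothetical constant maps): a dense map
# of 2.0 against sparse observations of 5.0 should yield alpha = 2.5.
#
#   dense = np.full((4, 4), 2.0)
#   sparse = np.full((4, 4), 5.0)
#   alpha, fitted = fit_scale_robust_median(dense, sparse)
#   # alpha == 2.5 and fitted is uniformly 5.0. The median makes the fit
#   # robust to a minority of outlier sparse points.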


def get_fitted_dense_depth(
    depth: np.ndarray, colmap_rec: pycolmap.Reconstruction, img_id: str, ade20k_seg: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, bool, Any]:
    """Fits the dense (monocular) depth map to the sparse COLMAP depth."""
    depth_np = np.array(depth) / 1000.0  # millimetres -> world units
    depth_sparse, found_sparse, col_img = get_sparse_depth(colmap_rec, img_id, depth_np)
    if not found_sparse:
        return depth_np, np.zeros_like(depth_np), False, None
    # Fit the scale only on house pixels, where the wireframe lives.
    house_mask = get_house_mask(ade20k_seg)
    _, depth_fitted = fit_scale_robust_median(depth_np, depth_sparse, validity_mask=house_mask)
    return depth_fitted, depth_sparse, True, col_img


def project_vertices_to_3d(
    uv: np.ndarray, depth_vert: np.ndarray, col_img: Any
) -> np.ndarray:
    """Back-projects 2D vertices to 3D world points using camera intrinsics and depth."""
    # Inverse pinhole projection:
    #   x_cam = (u - cx) / fx * z,  y_cam = (v - cy) / fy * z,  z_cam = z.
    xy_local = np.ones((len(uv), 3))
    K = col_img.camera.calibration_matrix()
    xy_local[:, 0] = (uv[:, 0] - K[0, 2]) / K[0, 0]
    xy_local[:, 1] = (uv[:, 1] - K[1, 2]) / K[1, 1]
    vertices_3d_local = xy_local * depth_vert[..., None]
    # Invert the world-to-camera transform to map camera points to world space.
    world_to_cam = np.eye(4)
    world_to_cam[:3] = col_img.cam_from_world.matrix()
    cam_to_world = np.linalg.inv(world_to_cam)
    vertices_3d_homogeneous = cv2.convertPointsToHomogeneous(vertices_3d_local)
    vertices_3d = cv2.transform(vertices_3d_homogeneous, cam_to_world)
    return cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
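

# The back-projection above, shown in isolation with a hypothetical pinhole
# camera (fx = fy = 500, principal point at (320, 240)):
#
#   K = np.array([[500.0, 0.0, 320.0],
#                 [0.0, 500.0, 240.0],
#                 [0.0, 0.0, 1.0]])
#   u, v, z = 420.0, 240.0, 10.0
#   x_cam = (u - K[0, 2]) / K[0, 0] * z   # -> 2.0
#   y_cam = (v - K[1, 2]) / K[1, 1] * z   # -> 0.0
#   # The camera-space point (2.0, 0.0, 10.0) is then moved into world
#   # coordinates by the inverse of cam_from_world.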


def merge_vertices_3d(
    vert_edge_per_image: Dict[int, Tuple[List[Dict[str, Any]], List[Tuple[int, int]], np.ndarray]],
    th: float = Config.MERGE_THRESHOLD,
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """
    Merges 3D vertices and edges across multiple views.

    Args:
        vert_edge_per_image: Per-image tuples of (vertices, edges, 3D vertices).
        th: Distance threshold for merging vertices of the same type.

    Returns:
        new_vertices: Merged 3D vertices.
        new_connections: Merged edge connections.
    """
    all_3d_vertices, connections_3d, types = [], [], []
    cur_start = 0
    # Concatenate per-view vertices, re-indexing each view's edges into the
    # global vertex numbering.
    for vertices, connections, vertices_3d in vert_edge_per_image.values():
        if len(vertices) == 0 or len(vertices_3d) == 0:
            continue
        # Type code: 1 for apex, 0 for eave end point.
        types += [int(v['type'] == 'apex') for v in vertices]
        all_3d_vertices.append(vertices_3d)
        connections_3d += [(x + cur_start, y + cur_start) for (x, y) in connections]
        cur_start += len(vertices_3d)
    if len(all_3d_vertices) == 0:
        return np.array([]), []
    all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
    distmat = cdist(all_3d_vertices, all_3d_vertices)
    # same_types == 0 exactly when two vertices share a type code.
    types = np.array(types).reshape(-1, 1)
    same_types = cdist(types, types)
    mask_to_merge = (distmat <= th) & (same_types == 0)
    new_vertices, new_connections = [], []
    # Build merge groups from the rows of the mask, then take the union of
    # all groups containing each vertex.
    to_merge = sorted(set(tuple(a.nonzero()[0].tolist()) for a in mask_to_merge))
    to_merge_final = defaultdict(list)
    for i in range(len(all_3d_vertices)):
        for j in to_merge:
            if i in j:
                to_merge_final[i] += j
    for k, v in to_merge_final.items():
        to_merge_final[k] = list(set(v))
    already_there, merged = set(), []
    for k, v in to_merge_final.items():
        if k in already_there:
            continue
        merged.append(v)
        for vv in v:
            already_there.add(vv)
    # Replace each merged group by its mean position.
    old_idx_to_new = {}
    count = 0
    for idxs in merged:
        if len(idxs) > 0:
            new_vertices.append(all_3d_vertices[idxs].mean(axis=0))
            for idx in idxs:
                old_idx_to_new[idx] = count
            count += 1
    if len(new_vertices) == 0:
        return np.array([]), []
    new_vertices = np.array(new_vertices)
    for conn in connections_3d:
        if conn[0] in old_idx_to_new and conn[1] in old_idx_to_new:
            new_con = tuple(sorted((old_idx_to_new[conn[0]], old_idx_to_new[conn[1]])))
            if new_con[0] != new_con[1] and new_con not in new_connections:
                new_connections.append(new_con)
    return new_vertices, new_connections
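

# Merge behaviour on two hypothetical single-edge views. The apex (and eave)
# detections differ by 0.1 world units, well under MERGE_THRESHOLD, so the
# four input vertices collapse to two and the duplicated edge to one:
#
#   def _view(a, b):
#       verts = [{'xy': np.zeros(2), 'type': 'apex'},
#                {'xy': np.zeros(2), 'type': 'eave_end_point'}]
#       return verts, [(0, 1)], np.array([a, b])
#
#   merged_v, merged_e = merge_vertices_3d({
#       0: _view([0.0, 0.0, 3.0], [2.0, 0.0, 3.0]),
#       1: _view([0.1, 0.0, 3.0], [2.0, 0.0, 3.1]),
#   })
#   # -> merged_v.shape == (2, 3), merged_e == [(0, 1)]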


def prune_too_far(
    all_3d_vertices: np.ndarray,
    connections_3d: List[Tuple[int, int]],
    colmap_rec: pycolmap.Reconstruction,
    th: float = Config.PRUNE_DISTANCE,
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """Prunes 3D vertices (and their edges) lying too far from any COLMAP point."""
    if len(all_3d_vertices) == 0:
        return all_3d_vertices, connections_3d
    xyz_sfm = np.array([v.xyz for v in colmap_rec.points3D.values()])
    if len(xyz_sfm) == 0:
        return all_3d_vertices, connections_3d
    distmat = cdist(all_3d_vertices, xyz_sfm)
    mask = distmat.min(axis=1) <= th
    if not np.any(mask):
        return np.empty((0, 3)), []
    # Re-index the surviving vertices and keep only edges between survivors.
    old_to_new_idx = {old: new for new, old in enumerate(np.where(mask)[0])}
    new_vertices = all_3d_vertices[mask]
    new_connections = []
    for u, v in connections_3d:
        if u in old_to_new_idx and v in old_to_new_idx:
            new_connections.append((old_to_new_idx[u], old_to_new_idx[v]))
    return new_vertices, new_connections


def metric_aware_refine(
    vertices: np.ndarray, edges: List[Tuple[int, int]]
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """
    Refines vertices and edges using DBSCAN clustering and edge-length constraints.

    Args:
        vertices: 3D vertex coordinates.
        edges: List of edge connections.

    Returns:
        final_vertices: Refined 3D vertices.
        final_edges: Refined edge connections.
    """
    if len(vertices) < 2:
        return vertices, edges
    clustering = DBSCAN(
        eps=Config.REFINEMENT_CLUSTER_EPS, min_samples=Config.REFINEMENT_MIN_SAMPLES
    ).fit(vertices)
    labels = clustering.labels_
    refined_vertices = vertices.copy()
    refined_centers = {}
    # Snap every cluster member onto the cluster mean. With min_samples=1
    # DBSCAN assigns no noise label (-1), but guard against it anyway.
    for label in set(labels):
        if label == -1:
            continue
        refined_centers[label] = np.mean(vertices[labels == label], axis=0)
    for i, label in enumerate(labels):
        if label in refined_centers:
            refined_vertices[i] = refined_centers[label]
    # Keep only edges of plausible metric length between distinct positions.
    final_edges_set = set()
    for u, v in edges:
        if u >= len(refined_vertices) or v >= len(refined_vertices):
            continue
        p1 = refined_vertices[u]
        p2 = refined_vertices[v]
        dist = np.linalg.norm(p1 - p2)
        if Config.MIN_EDGE_LENGTH <= dist <= Config.MAX_EDGE_LENGTH:
            if not np.allclose(p1, p2, atol=1e-3):
                final_edges_set.add(tuple(sorted((u, v))))
    if not final_edges_set:
        return np.empty((0, 3)), []
    # Drop vertices no surviving edge references, then re-index.
    used_idxs = set(u for u, v in final_edges_set) | set(v for u, v in final_edges_set)
    sorted_used_idxs = sorted(used_idxs)
    final_map = {old_id: new_id for new_id, old_id in enumerate(sorted_used_idxs)}
    final_vertices = np.array([refined_vertices[old_id] for old_id in sorted_used_idxs])
    final_edges = [(final_map[u], final_map[v]) for u, v in final_edges_set]
    return final_vertices, final_edges
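

# Refinement sketch on three hypothetical vertices: the first two lie within
# REFINEMENT_CLUSTER_EPS of each other, so DBSCAN snaps them onto a common
# center and the edge between them degenerates and is dropped:
#
#   v = np.array([[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [5.0, 0.0, 0.0]])
#   e = [(0, 1), (0, 2)]
#   rv, re = metric_aware_refine(v, e)
#   # -> the zero-length (0, 1) edge is gone; rv has 2 vertices and
#   #    re == [(0, 1)] connects the snapped cluster to the far vertex.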


def predict_wireframe(entry: Dict[str, Any]) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """
    Main prediction function for the S23DR wireframe challenge.

    Args:
        entry: Input data entry.

    Returns:
        final_vertices: 3D vertices of the wireframe.
        final_edges: Edge connections.
    """
    try:
        good_entry = convert_entry_to_human_readable(entry)
        colmap_rec = good_entry['colmap_binary']
        vert_edge_per_image = {}
        for i, (gest_img, img_id, ade_seg, depth_img) in enumerate(
            zip(
                good_entry['gestalt'],
                good_entry['image_ids'],
                good_entry['ade'],
                good_entry['depth'],
            )
        ):
            # Resize the gestalt segmentation to the depth-map resolution so
            # 2D vertex coordinates index the depth maps correctly.
            depth_size = (depth_img.width, depth_img.height)
            gest_seg_np = np.array(gest_img.resize(depth_size)).astype(np.uint8)
            vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np)
            if not vertices:
                vert_edge_per_image[i] = ([], [], [])
            else:
                depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(
                    depth_img, colmap_rec, img_id, ade_seg
                )
                if found_sparse and col_img is not None:
                    uv, depth_vert = get_uv_depth(vertices, depth_fitted, depth_sparse)
                    vertices_3d = project_vertices_to_3d(uv, depth_vert, col_img)
                    vert_edge_per_image[i] = (vertices, connections, vertices_3d)
                else:
                    vert_edge_per_image[i] = (vertices, connections, [])
        all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image)
        all_3d_vertices, connections_3d = prune_too_far(
            all_3d_vertices, connections_3d, colmap_rec
        )
        final_vertices, final_edges = metric_aware_refine(all_3d_vertices, connections_3d)
        if len(final_vertices) < 2 or len(final_edges) < 1:
            return empty_solution()
        return final_vertices, final_edges
    except Exception as e:
        print(f"An error occurred in the main prediction pipeline: {e}")
        import traceback
        traceback.print_exc()
        return empty_solution()
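

# End-to-end usage sketch (assumes `entry` is shaped like an S23DR dataset
# sample: lists of PIL images under 'gestalt', 'ade' and 'depth', matching
# 'image_ids', and the zipped COLMAP model under 'colmap_binary'):
#
#   vertices, edges = predict_wireframe(entry)
#   print(vertices.shape, len(edges))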