import io
import tempfile
import traceback
import zipfile
from collections import defaultdict
from itertools import combinations
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
import pycolmap
from PIL import Image as PImage
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN

from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping


class Config:
    """Tuning constants for the wireframe extraction pipeline."""

    # Image-space parameters (compared against pixel coordinates).
    EDGE_THRESHOLD = 15.5  # max vertex-to-fitted-edge distance for linking vertices
    SEARCH_RADIUS = 14  # half-size of the window searched for sparse depth samples
    MORPHOLOGY_KERNEL = 3  # side of the square closing kernel applied to edge masks

    # 3D parameters (compared against distances between reconstructed points).
    MERGE_THRESHOLD = 0.65  # max distance for merging same-type vertices across views
    PRUNE_DISTANCE = 3.0  # max distance from a vertex to its nearest SfM point
    MIN_EDGE_LENGTH = 0.15
    MAX_EDGE_LENGTH = 27.0
    REFINEMENT_CLUSTER_EPS = 0.45  # DBSCAN eps used by metric_aware_refine
    REFINEMENT_MIN_SAMPLES = 1  # DBSCAN min_samples used by metric_aware_refine


def empty_solution() -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """Return a placeholder wireframe: two vertices at the origin joined by one edge."""
    return np.zeros((2, 3)), [(0, 1)]


def _class_mask(seg: np.ndarray, rgb: Any) -> np.ndarray:
    """Return the cv2.inRange mask (0/255) of pixels whose colour equals ``rgb``."""
    color = np.array(rgb)
    return cv2.inRange(seg, color - 0.5, color + 0.5)


def _detect_vertices(gest_seg_np: np.ndarray) -> List[Dict[str, Any]]:
    """Return one vertex (blob centroid) per connected blob of each vertex class."""
    vertices = []
    for v_type in ("apex", "eave_end_point"):
        mask = _class_mask(gest_seg_np, gestalt_color_mapping[v_type])
        if not mask.any():
            continue
        num_labels, _, _, centroids = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
        # Label 0 is the background. ``centroids`` already holds the pixel-mean
        # of each component, so no per-blob moment computation is needed.
        for cx, cy in centroids[1:num_labels]:
            vertices.append({"xy": np.array([cx, cy]), "type": v_type})
    return vertices


def _fit_edge_segments(gest_seg_np: np.ndarray):
    """Yield the (p1, p2) endpoints of a line fitted to each connected edge blob."""
    close_kernel = np.ones((Config.MORPHOLOGY_KERNEL, Config.MORPHOLOGY_KERNEL), np.uint8)
    open_kernel = np.ones((2, 2), np.uint8)
    for edge_class in ("eave", "ridge", "rake", "valley"):
        if edge_class not in gestalt_color_mapping:
            continue
        mask = _class_mask(gest_seg_np, gestalt_color_mapping[edge_class])
        # Close small gaps in thin edge strokes, then open to drop isolated specks.
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, open_kernel)
        if not mask.any():
            continue
        num_labels, labels, _, _ = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
        for lbl in range(1, num_labels):
            ys, xs = np.where(labels == lbl)
            if len(xs) < 2:
                continue
            pts_for_fit = np.column_stack([xs, ys]).astype(np.float32)
            line_params = cv2.fitLine(
                pts_for_fit, distType=cv2.DIST_L2, param=0, reps=0.01, aeps=0.01
            )
            vx, vy, x0, y0 = line_params.ravel()
            # Clip the infinite fitted line to the extent of the blob's pixels.
            proj = (xs - x0) * vx + (ys - y0) * vy
            p1 = np.array([x0 + proj.min() * vx, y0 + proj.min() * vy])
            p2 = np.array([x0 + proj.max() * vx, y0 + proj.max() * vy])
            yield p1, p2


def get_vertices_and_edges_from_segmentation(
    gest_seg_np: np.ndarray, edge_th: float = Config.EDGE_THRESHOLD
) -> Tuple[List[Dict[str, Any]], List[Tuple[int, int]]]:
    """
    Detects roof vertices and edges from a Gestalt segmentation mask.

    Vertices are the centroids of "apex" and "eave_end_point" blobs. For every
    connected blob of an edge class, a line segment is fitted. Every pair of
    vertices lying within ``edge_th`` of that segment is connected.

    Args:
        gest_seg_np: Segmentation mask as a numpy array (converted if not one).
        edge_th: Distance threshold for associating edges to vertices.

    Returns:
        vertices: List of dicts with ``"xy"`` (np.array([x, y])) and ``"type"``.
        connections: Unique ``(i, j)`` vertex index pairs with ``i < j``.
    """
    if not isinstance(gest_seg_np, np.ndarray):
        gest_seg_np = np.array(gest_seg_np)

    vertices = _detect_vertices(gest_seg_np)
    if not vertices:
        return [], []
    if len(vertices) < 2:
        # A single vertex cannot form an edge.
        return vertices, []

    vertex_pts = np.array([v["xy"] for v in vertices])
    connections: List[Tuple[int, int]] = []
    seen = set()
    for p1, p2 in _fit_edge_segments(gest_seg_np):
        dists = np.array([point_to_segment_dist(pt, p1, p2) for pt in vertex_pts])
        near_indices = np.flatnonzero(dists <= edge_th)  # ascending, so pairs are sorted
        for a, b in combinations(near_indices.tolist(), 2):
            if (a, b) not in seen:
                seen.add((a, b))
                connections.append((a, b))
    return vertices, connections


def get_uv_depth(
    vertices: List[Dict[str, Any]],
    depth_fitted: np.ndarray,
    sparse_depth: np.ndarray,
    search_radius: int = Config.SEARCH_RADIUS,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Assigns depth to each vertex using a weighted search in the sparse depth map.

    The depth of a vertex is the Gaussian-weighted (sigma = radius / 3) mean of
    the positive sparse-depth pixels within a square window around it. If the
    window contains no sparse sample, the dense fitted depth at the vertex is used.

    Args:
        vertices: List of detected vertices (dicts with an ``"xy"`` entry).
        depth_fitted: Dense depth map.
        sparse_depth: Sparse depth map (0 where no sample exists).
        search_radius: Half-size of the search window, in pixels.

    Returns:
        uv: (N, 2) float32 array of vertex coordinates.
        vertex_depth: (N,) float32 array of depth values.
    """
    uv = np.array([vert["xy"] for vert in vertices], dtype=np.float32).reshape(-1, 2)
    vertex_depth = np.zeros(len(uv), dtype=np.float32)
    if len(uv) == 0:
        return uv, vertex_depth

    H, W = depth_fitted.shape[:2]
    uv_int = np.round(uv).astype(np.int32)
    uv_int[:, 0] = np.clip(uv_int[:, 0], 0, W - 1)
    uv_int[:, 1] = np.clip(uv_int[:, 1], 0, H - 1)

    # Floor the Gaussian denominator so search_radius == 0 yields weight 1 at
    # the centre pixel instead of 0 / 0 = NaN.
    two_sigma_sq = max(2 * (search_radius / 3) ** 2, 1e-12)
    for i, (x_i, y_i) in enumerate(uv_int):
        x0, x1 = max(0, x_i - search_radius), min(W, x_i + search_radius + 1)
        y0, y1 = max(0, y_i - search_radius), min(H, y_i + search_radius + 1)
        region = sparse_depth[y0:y1, x0:x1]
        valid_y, valid_x = np.where(region > 0)
        if valid_y.size == 0:
            vertex_depth[i] = depth_fitted[y_i, x_i]
            continue
        dist_sq = (x0 + valid_x - x_i) ** 2 + (y0 + valid_y - y_i) ** 2
        weights = np.exp(-dist_sq / two_sigma_sq)
        vertex_depth[i] = np.sum(weights * region[valid_y, valid_x]) / np.sum(weights)
    return uv, vertex_depth


def read_colmap_rec(colmap_data: bytes) -> pycolmap.Reconstruction:
    """Reads a COLMAP reconstruction from the bytes of a zip archive."""
    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(io.BytesIO(colmap_data), "r") as zf:
            zf.extractall(tmpdir)
        # Load before the temporary directory is deleted at the end of the with-block.
        rec = pycolmap.Reconstruction(tmpdir)
    return rec


def convert_entry_to_human_readable(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``entry`` with COLMAP data decoded and K/R/t as arrays.

    Adds ``"__key__"`` taken from ``entry["order_id"]`` (or ``"unknown_id"``).
    """
    out = {}
    for key, value in entry.items():
        if key == "colmap_binary":
            out[key] = read_colmap_rec(value)
        elif key in ("K", "R", "t"):
            try:
                out[key] = np.array(value)
            except ValueError:
                # Irregular values that numpy cannot turn into an array are kept as-is.
                out[key] = value
        else:
            out[key] = value
    out["__key__"] = entry.get("order_id", "unknown_id")
    return out


def get_house_mask(ade20k_seg: np.ndarray) -> np.ndarray:
    """Returns a boolean mask of house/building pixels from an ADE20K segmentation.

    Always returns a bool array of the segmentation's height x width, even when
    none of the house classes is present in the colour mapping.
    """
    house_classes_ade20k = [
        "wall",
        "house",
        "building;edifice",
        "door;double;door",
        "windowpane;window",
    ]
    np_seg = np.array(ade20k_seg)
    full_mask = np.zeros(np_seg.shape[:2], dtype=bool)
    for class_name in house_classes_ade20k:
        if class_name in ade20k_color_mapping:
            full_mask |= _class_mask(np_seg, ade20k_color_mapping[class_name]) > 0
    return full_mask


def point_to_segment_dist(pt: np.ndarray, seg_p1: np.ndarray, seg_p2: np.ndarray) -> float:
    """Computes the Euclidean distance from a point to the line segment [seg_p1, seg_p2]."""
    if np.allclose(seg_p1, seg_p2):
        return np.linalg.norm(pt - seg_p1)
    seg_vec = seg_p2 - seg_p1
    pt_vec = pt - seg_p1
    seg_len2 = seg_vec.dot(seg_vec)
    # Parameter of the orthogonal projection, clamped onto the segment.
    t = max(0, min(1, pt_vec.dot(seg_vec) / seg_len2))
    proj = seg_p1 + t * seg_vec
    return np.linalg.norm(pt - proj)


def _cam_from_world_matrix(image: Any) -> np.ndarray:
    """Return the 3x4 world-to-camera matrix of a pycolmap image.

    ``cam_from_world`` is read as an attribute. If it is callable (some newer
    pycolmap releases expose it as a method), it is called first.
    """
    cam_from_world = image.cam_from_world
    if callable(cam_from_world):
        cam_from_world = cam_from_world()
    return np.asarray(cam_from_world.matrix())


def get_sparse_depth(
    colmap_rec: pycolmap.Reconstruction, img_id_substring: str, depth: np.ndarray
) -> Tuple[np.ndarray, bool, Any]:
    """Projects COLMAP 3D points into the image to create a sparse depth map.

    The image used is the first reconstruction image whose name contains
    ``img_id_substring``. Only 3D points observed by that image are projected.

    Returns:
        (sparse depth map of ``depth.shape``, found flag, the matched image or None).
        The flag is False when no image matches or it observes no 3D points.
    """
    H, W = depth.shape
    found_img = next(
        (col_img for col_img in colmap_rec.images.values() if img_id_substring in col_img.name),
        None,
    )
    if found_img is None:
        return np.zeros((H, W), dtype=np.float32), False, None

    points_xyz = [
        p3D.xyz for pid, p3D in colmap_rec.points3D.items() if found_img.has_point3D(pid)
    ]
    if not points_xyz:
        return np.zeros((H, W), dtype=np.float32), False, found_img

    # The extrinsics are shared by every point, so fetch them once.
    world_to_cam = _cam_from_world_matrix(found_img)
    uv, z_vals = [], []
    for xyz in np.array(points_xyz):
        proj = found_img.project_point(xyz)
        if proj is None:
            continue
        u_i, v_i = int(round(proj[0])), int(round(proj[1]))
        if 0 <= u_i < W and 0 <= v_i < H:
            uv.append((u_i, v_i))
            # Camera-frame z: third row of [R | t] applied to the point.
            z_vals.append(world_to_cam[2, :3] @ xyz + world_to_cam[2, 3])

    depth_out = np.zeros((H, W), dtype=np.float32)
    if uv:
        uv_arr = np.array(uv, dtype=int)
        depth_out[uv_arr[:, 1], uv_arr[:, 0]] = np.array(z_vals)
    return depth_out, True, found_img


def fit_scale_robust_median(
    depth: np.ndarray, sparse_depth: np.ndarray, validity_mask: Optional[np.ndarray] = None
) -> Tuple[float, np.ndarray]:
    """Fits a scale factor between dense and sparse depth using the median ratio.

    Only pixels where both depths lie strictly between 0.1 and 50 (and, if
    given, ``validity_mask`` is truthy) contribute. With fewer than 5 such
    pixels the depth is returned unscaled with factor 1.0.

    Args:
        depth: Dense depth map.
        sparse_depth: Sparse depth map of the same shape (0 where missing).
        validity_mask: Optional mask of usable pixels. Any dtype is accepted; if
            its shape differs from ``depth`` it is resized with nearest-neighbour.

    Returns:
        (scale factor, scaled dense depth).
    """
    mask = (sparse_depth > 0.1) & (depth > 0.1) & (sparse_depth < 50) & (depth < 50)
    if validity_mask is not None:
        valid = np.asarray(validity_mask).astype(bool)
        if valid.shape != depth.shape[:2]:
            valid = cv2.resize(
                valid.astype(np.uint8),
                (depth.shape[1], depth.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            ).astype(bool)
        mask &= valid
    X, Y = depth[mask], sparse_depth[mask]
    if len(X) < 5:
        return 1.0, depth
    alpha = np.median(Y / X)
    return alpha, alpha * depth


def get_fitted_dense_depth(
    depth: np.ndarray, colmap_rec: pycolmap.Reconstruction, img_id: str, ade20k_seg: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, bool, Any]:
    """Fits the dense depth map to the sparse COLMAP depth.

    The raw depth is divided by 1000 before fitting. The scale is estimated only
    over house pixels from the ADE20K segmentation.

    Returns:
        (fitted dense depth, sparse depth, found flag, matched COLMAP image or None).
    """
    depth_np = np.array(depth) / 1000.0
    depth_sparse, found_sparse, col_img = get_sparse_depth(colmap_rec, img_id, depth_np)
    if not found_sparse:
        return depth_np, np.zeros_like(depth_np), False, None
    house_mask = get_house_mask(ade20k_seg)
    _, depth_fitted = fit_scale_robust_median(depth_np, depth_sparse, validity_mask=house_mask)
    return depth_fitted, depth_sparse, True, col_img


def project_vertices_to_3d(uv: np.ndarray, depth_vert: np.ndarray, col_img: Any) -> np.ndarray:
    """Back-projects 2D vertices to 3D world points using the camera intrinsics and depth.

    Returns:
        (N, 3) array of world-space points (empty (0, 3) for empty input).
    """
    uv = np.asarray(uv, dtype=np.float64).reshape(-1, 2)
    K = col_img.camera.calibration_matrix()
    rays = np.ones((len(uv), 3))
    rays[:, 0] = (uv[:, 0] - K[0, 2]) / K[0, 0]
    rays[:, 1] = (uv[:, 1] - K[1, 2]) / K[1, 1]
    points_cam = rays * np.asarray(depth_vert, dtype=np.float64).reshape(-1)[:, None]

    world_to_cam = np.eye(4)
    world_to_cam[:3] = _cam_from_world_matrix(col_img)
    cam_to_world = np.linalg.inv(world_to_cam)
    # The last row of cam_to_world is [0, 0, 0, 1], so the homogeneous transform
    # reduces to rotation + translation.
    return points_cam @ cam_to_world[:3, :3].T + cam_to_world[:3, 3]


def merge_vertices_3d(
    vert_edge_per_image: Dict[int, Tuple[List[Dict[str, Any]], List[Tuple[int, int]], np.ndarray]],
    th: float = Config.MERGE_THRESHOLD,
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """
    Merges 3D vertices and edges across multiple views.

    Two vertices are "mergeable" if they have the same apex/non-apex type and
    lie within ``th`` of each other. Groups are formed greedily in index order.
    The group of an unassigned vertex i is every vertex mergeable with some
    vertex mergeable with i (a two-hop neighbourhood, not a full transitive
    closure). Each group becomes one vertex at its mean position.

    Args:
        vert_edge_per_image: Per-image (vertices, connections, vertices_3d).
            Entries with no vertices or no 3D vertices are skipped.
        th: Distance threshold for merging.

    Returns:
        new_vertices: (M, 3) array of merged 3D vertices ((0, 3) if none).
        new_connections: Unique ``(i, j)`` tuples with ``i < j`` and no self-loops.
    """
    all_3d_vertices, connections_3d, types = [], [], []
    offset = 0
    for vertices, connections, vertices_3d in vert_edge_per_image.values():
        if len(vertices) == 0 or len(vertices_3d) == 0:
            continue
        types.extend(int(v["type"] == "apex") for v in vertices)
        all_3d_vertices.append(vertices_3d)
        connections_3d.extend((x + offset, y + offset) for x, y in connections)
        offset += len(vertices_3d)

    if not all_3d_vertices:
        return np.empty((0, 3)), []
    points = np.concatenate(all_3d_vertices, axis=0)
    if len(points) == 0:
        return np.empty((0, 3)), []

    type_arr = np.asarray(types)
    mergeable = (cdist(points, points) <= th) & (type_arr[:, None] == type_arr[None, :])

    groups = []
    assigned = np.zeros(len(points), dtype=bool)
    for i in range(len(points)):
        if assigned[i]:
            continue
        # Union of the rows of all vertices mergeable with i (mergeable is symmetric).
        group = np.flatnonzero(mergeable[mergeable[i]].any(axis=0))
        if group.size == 0:
            # Only possible for non-finite points, which are dropped.
            continue
        groups.append(group)
        assigned[group] = True

    if not groups:
        return np.empty((0, 3)), []

    new_vertices = np.array([points[group].mean(axis=0) for group in groups])
    old_idx_to_new = {}
    for new_idx, group in enumerate(groups):
        for old_idx in group.tolist():
            # NOTE(review): groups may overlap; a vertex in several groups maps
            # to the last one, matching the original greedy behaviour.
            old_idx_to_new[old_idx] = new_idx

    new_connections: List[Tuple[int, int]] = []
    seen = set()
    for a, b in connections_3d:
        if a in old_idx_to_new and b in old_idx_to_new:
            na, nb = sorted((old_idx_to_new[a], old_idx_to_new[b]))
            if na != nb and (na, nb) not in seen:
                seen.add((na, nb))
                new_connections.append((na, nb))
    return new_vertices, new_connections


def prune_too_far(
    all_3d_vertices: np.ndarray,
    connections_3d: List[Tuple[int, int]],
    colmap_rec: pycolmap.Reconstruction,
    th: float = Config.PRUNE_DISTANCE,
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """Prunes 3D vertices (and their edges) farther than ``th`` from every COLMAP point.

    Inputs are returned unchanged when there are no vertices or no COLMAP points.
    """
    if len(all_3d_vertices) == 0:
        return all_3d_vertices, connections_3d
    xyz_sfm = np.array([p.xyz for p in colmap_rec.points3D.values()])
    if len(xyz_sfm) == 0:
        return all_3d_vertices, connections_3d

    # Nearest-neighbour query instead of a full vertex x point distance matrix,
    # which can be very large for dense reconstructions.
    nearest_dist, _ = cKDTree(xyz_sfm).query(all_3d_vertices, k=1)
    keep = nearest_dist <= th
    if not np.any(keep):
        return np.empty((0, 3)), []

    old_to_new_idx = {int(old): new for new, old in enumerate(np.flatnonzero(keep))}
    new_connections = [
        (old_to_new_idx[u], old_to_new_idx[v])
        for u, v in connections_3d
        if u in old_to_new_idx and v in old_to_new_idx
    ]
    return all_3d_vertices[keep], new_connections


def metric_aware_refine(
    vertices: np.ndarray, edges: List[Tuple[int, int]]
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """
    Refines vertices and edges using clustering and edge length constraints.

    Vertices are clustered with DBSCAN and each cluster is collapsed into a
    single vertex at the cluster mean. DBSCAN noise points (label -1) stay
    as-is. Edges are re-indexed onto the collapsed vertices. An edge is dropped
    if it becomes a self-loop, or if its length falls outside
    [MIN_EDGE_LENGTH, MAX_EDGE_LENGTH]. Finally, vertices with no remaining
    edge are removed.

    Args:
        vertices: (N, 3) array of vertex coordinates.
        edges: List of ``(u, v)`` index pairs into ``vertices``.

    Returns:
        final_vertices: Refined 3D vertices ((0, 3) if no edge survives).
        final_edges: Sorted list of ``(i, j)`` index pairs with ``i < j``.
    """
    if len(vertices) < 2:
        return vertices, edges

    vertices = np.asarray(vertices)
    n = len(vertices)
    labels = DBSCAN(
        eps=Config.REFINEMENT_CLUSTER_EPS, min_samples=Config.REFINEMENT_MIN_SAMPLES
    ).fit(vertices).labels_

    refined_vertices = vertices.copy()
    # representative[i] is the index that stands in for vertex i after
    # clustering: the lowest index in its cluster, or i itself for noise.
    representative = np.arange(n)
    for label in np.unique(labels):
        if label == -1:
            continue
        members = np.flatnonzero(labels == label)
        refined_vertices[members] = vertices[members].mean(axis=0)
        representative[members] = members[0]

    final_edges_set = set()
    for u, v in edges:
        if not (0 <= u < n and 0 <= v < n):
            continue
        a, b = sorted((int(representative[u]), int(representative[v])))
        if a == b:
            continue
        p1, p2 = refined_vertices[a], refined_vertices[b]
        dist = np.linalg.norm(p1 - p2)
        if Config.MIN_EDGE_LENGTH <= dist <= Config.MAX_EDGE_LENGTH and not np.allclose(
            p1, p2, atol=1e-3
        ):
            final_edges_set.add((a, b))

    if not final_edges_set:
        return np.empty((0, 3)), []

    used_idxs = sorted({idx for edge in final_edges_set for idx in edge})
    final_map = {old_id: new_id for new_id, old_id in enumerate(used_idxs)}
    final_vertices = refined_vertices[used_idxs]
    final_edges = [(final_map[u], final_map[v]) for u, v in sorted(final_edges_set)]
    return final_vertices, final_edges


def _process_view(
    gest_img: Any, img_id: str, ade_seg: Any, depth_img: Any, colmap_rec: pycolmap.Reconstruction
) -> Tuple[List[Dict[str, Any]], List[Tuple[int, int]], Any]:
    """Detect the 2D wireframe in one view and lift its vertices to 3D.

    Returns (vertices, connections, vertices_3d). ``vertices_3d`` is an empty
    list when there are no vertices or no matching sparse COLMAP depth.
    """
    depth_h, depth_w = np.asarray(depth_img).shape[:2]
    # Bring the segmentation to the depth map's resolution, so that vertex pixel
    # coordinates index the depth map directly. NEAREST keeps class colours
    # exact; interpolation would blend colours that the exact-match masks reject.
    gest_seg_np = np.array(
        gest_img.resize((depth_w, depth_h), resample=PImage.NEAREST)
    ).astype(np.uint8)

    vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np)
    if not vertices:
        return [], [], []

    depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(
        depth_img, colmap_rec, img_id, ade_seg
    )
    if not found_sparse or col_img is None:
        return vertices, connections, []

    uv, depth_vert = get_uv_depth(vertices, depth_fitted, depth_sparse)
    return vertices, connections, project_vertices_to_3d(uv, depth_vert, col_img)


def predict_wireframe(entry: Dict[str, Any]) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
    """
    Main prediction function for the S23DR wireframe challenge.

    Per view: detect the 2D wireframe and lift it to 3D. Then merge across
    views, prune vertices far from the SfM points, and refine. Any exception
    is printed with its traceback, and ``empty_solution()`` is returned.

    Args:
        entry: Input data entry (expects 'colmap_binary', 'gestalt',
            'image_ids', 'ade' and 'depth' keys).

    Returns:
        final_vertices: 3D vertices of the wireframe.
        final_edges: Edge connections.
    """
    try:
        good_entry = convert_entry_to_human_readable(entry)
        colmap_rec = good_entry["colmap_binary"]

        views = zip(
            good_entry["gestalt"],
            good_entry["image_ids"],
            good_entry["ade"],
            good_entry["depth"],
        )
        vert_edge_per_image = {
            i: _process_view(gest_img, img_id, ade_seg, depth_img, colmap_rec)
            for i, (gest_img, img_id, ade_seg, depth_img) in enumerate(views)
        }

        all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image)
        all_3d_vertices, connections_3d = prune_too_far(
            all_3d_vertices, connections_3d, colmap_rec
        )
        final_vertices, final_edges = metric_aware_refine(all_3d_vertices, connections_3d)

        if len(final_vertices) < 2 or len(final_edges) < 1:
            return empty_solution()
        return final_vertices, final_edges
    except Exception as e:
        print(f"An error occurred in the main prediction pipeline: {e}")
        traceback.print_exc()
        return empty_solution()