# S23DR_solution_2025 / solution.py
import io
import tempfile
import traceback
import zipfile
from collections import defaultdict
from typing import Tuple, List, Dict, Any
import cv2
import numpy as np
import pycolmap
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN
from hoho2025.color_mappings import ade20k_color_mapping, gestalt_color_mapping
class Config:
"""Configuration for wireframe extraction pipeline."""
    EDGE_THRESHOLD = 15.5          # px: max vertex-to-edge-segment distance in the image
    MERGE_THRESHOLD = 0.65         # world units: max 3D distance when merging vertices across views
    PRUNE_DISTANCE = 3.0           # world units: max 3D distance to the nearest COLMAP point
    SEARCH_RADIUS = 14             # px: half-window for the sparse-depth lookup
    MIN_EDGE_LENGTH = 0.15         # world units: shorter edges are dropped
    MAX_EDGE_LENGTH = 27.0         # world units: longer edges are dropped
    MORPHOLOGY_KERNEL = 3          # px: kernel size for edge-mask cleanup
    REFINEMENT_CLUSTER_EPS = 0.45  # world units: DBSCAN eps for vertex snapping
    REFINEMENT_MIN_SAMPLES = 1     # DBSCAN min_samples; 1 means no point is labeled noise
def empty_solution() -> Tuple[np.ndarray, List[Tuple[int, int]]]:
"""Returns an empty wireframe solution."""
return np.zeros((2, 3)), [(0, 1)]
def get_vertices_and_edges_from_segmentation(
gest_seg_np: np.ndarray, edge_th: float = Config.EDGE_THRESHOLD
) -> Tuple[List[Dict[str, Any]], List[Tuple[int, int]]]:
"""
Detects roof vertices and edges from a Gestalt segmentation mask.
Args:
gest_seg_np: Segmentation mask as a numpy array.
edge_th: Distance threshold for associating edges to vertices.
Returns:
vertices: List of detected vertices with coordinates and type.
connections: List of edge connections (vertex index pairs).
"""
vertices, connections = [], []
if not isinstance(gest_seg_np, np.ndarray):
gest_seg_np = np.array(gest_seg_np)
# Vertex detection
    # The vertex type name doubles as its key in the Gestalt color mapping.
    for v_type in ("apex", "eave_end_point"):
        color = np.array(gestalt_color_mapping[v_type])
mask = cv2.inRange(gest_seg_np, color - 0.5, color + 0.5)
if mask.sum() == 0:
continue
        num_labels, labels, _stats, _centroids = cv2.connectedComponentsWithStats(
            mask, 8, cv2.CV_32S
        )
        for i in range(1, num_labels):
mask_i = (labels == i).astype(np.uint8)
M = cv2.moments(mask_i)
if M["m00"] > 0:
cx, cy = M["m10"] / M["m00"], M["m01"] / M["m00"]
else:
ys, xs = np.where(mask_i)
if len(xs) > 0:
cx, cy = np.mean(xs), np.mean(ys)
else:
continue
vertices.append({"xy": np.array([cx, cy]), "type": v_type})
    if not vertices:
        return [], []
    apex_pts = np.array([v['xy'] for v in vertices])
    if len(apex_pts) < 2:
        # With fewer than two vertices, no edge can connect a pair.
        return vertices, connections
# Edge detection
edge_classes = ['eave', 'ridge', 'rake', 'valley']
for edge_class in edge_classes:
if edge_class not in gestalt_color_mapping:
continue
edge_color = np.array(gestalt_color_mapping[edge_class])
mask_raw = cv2.inRange(gest_seg_np, edge_color - 0.5, edge_color + 0.5)
        # Clean the edge mask: close small gaps, then remove isolated speckle.
        kernel = np.ones((Config.MORPHOLOGY_KERNEL, Config.MORPHOLOGY_KERNEL), np.uint8)
        mask = cv2.morphologyEx(mask_raw, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8))
if mask.sum() == 0:
continue
        num_labels, labels, _stats, _centroids = cv2.connectedComponentsWithStats(
            mask, 8, cv2.CV_32S
        )
        for lbl in range(1, num_labels):
ys, xs = np.where(labels == lbl)
if len(xs) < 2:
continue
pts_for_fit = np.column_stack([xs, ys]).astype(np.float32)
line_params = cv2.fitLine(
pts_for_fit, distType=cv2.DIST_L2, param=0, reps=0.01, aeps=0.01
)
vx, vy, x0, y0 = line_params.ravel()
            # Project the component's pixels onto the fitted line; the extreme
            # projections give the segment endpoints.
            proj = ((xs - x0) * vx + (ys - y0) * vy)
            p1 = np.array([x0 + proj.min() * vx, y0 + proj.min() * vy])
            p2 = np.array([x0 + proj.max() * vx, y0 + proj.max() * vy])
dists = np.array([point_to_segment_dist(apex_pts[i], p1, p2) for i in range(len(apex_pts))])
near_indices = np.where(dists <= edge_th)[0]
if len(near_indices) < 2:
continue
for i in range(len(near_indices)):
for j in range(i + 1, len(near_indices)):
conn = tuple(sorted((near_indices[i], near_indices[j])))
if conn not in connections:
connections.append(conn)
return vertices, connections
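# Usage sketch for the function above (hypothetical `seg` array; the real
# pipeline passes the resized Gestalt segmentation from predict_wireframe):
#   vertices, connections = get_vertices_and_edges_from_segmentation(seg)
#   vertices    -> [{"xy": array([cx, cy]), "type": "apex"}, ...]
#   connections -> [(0, 1), (0, 2), ...]  (index pairs into `vertices`)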
def get_uv_depth(
vertices: List[Dict[str, Any]],
depth_fitted: np.ndarray,
sparse_depth: np.ndarray,
search_radius: int = Config.SEARCH_RADIUS,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Assigns depth to each vertex using a weighted search in the sparse depth map.
Args:
vertices: List of detected vertices.
depth_fitted: Dense depth map.
sparse_depth: Sparse depth map.
search_radius: Search radius for depth assignment.
Returns:
uv: 2D coordinates of vertices.
vertex_depth: Depth values for each vertex.
"""
uv = np.array([vert['xy'] for vert in vertices], dtype=np.float32)
uv_int = np.round(uv).astype(np.int32)
H, W = depth_fitted.shape[:2]
uv_int[:, 0] = np.clip(uv_int[:, 0], 0, W - 1)
uv_int[:, 1] = np.clip(uv_int[:, 1], 0, H - 1)
vertex_depth = np.zeros(len(vertices), dtype=np.float32)
for i, (x_i, y_i) in enumerate(uv_int):
x0, x1 = max(0, x_i - search_radius), min(W, x_i + search_radius + 1)
y0, y1 = max(0, y_i - search_radius), min(H, y_i + search_radius + 1)
region = sparse_depth[y0:y1, x0:x1]
valid_y, valid_x = np.where(region > 0)
if valid_y.size > 0:
global_x, global_y = x0 + valid_x, y0 + valid_y
dist_sq = (global_x - x_i) ** 2 + (global_y - y_i) ** 2
weights = np.exp(-dist_sq / (2 * (search_radius / 3) ** 2))
vertex_depth[i] = np.sum(weights * region[valid_y, valid_x]) / np.sum(weights)
else:
vertex_depth[i] = depth_fitted[y_i, x_i]
return uv, vertex_depth
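# For reference, the weighting used above: a valid sparse sample at pixel
# distance d from the vertex contributes with Gaussian weight
#   w = exp(-d^2 / (2 * sigma^2)),  sigma = search_radius / 3,
# and the vertex depth is sum(w * z) / sum(w). The scaled dense depth is only
# a fallback when no sparse point falls inside the search window.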
def read_colmap_rec(colmap_data: bytes) -> pycolmap.Reconstruction:
"""Reads a COLMAP reconstruction from a zipped binary."""
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(io.BytesIO(colmap_data), "r") as zf:
zf.extractall(tmpdir)
rec = pycolmap.Reconstruction(tmpdir)
return rec
def convert_entry_to_human_readable(entry: Dict[str, Any]) -> Dict[str, Any]:
"""Converts a raw entry to a human-readable format."""
out = {}
for k, v in entry.items():
if k == 'colmap_binary':
out[k] = read_colmap_rec(v)
elif k in ['K', 'R', 't']:
try:
out[k] = np.array(v)
except ValueError:
out[k] = v
else:
out[k] = v
out['__key__'] = entry.get('order_id', 'unknown_id')
return out
def get_house_mask(ade20k_seg: np.ndarray) -> np.ndarray:
"""Returns a mask for house/building regions from ADE20K segmentation."""
house_classes_ade20k = [
'wall', 'house', 'building;edifice', 'door;double;door', 'windowpane;window'
]
np_seg = np.array(ade20k_seg)
    full_mask = np.zeros(np_seg.shape[:2], dtype=bool)
    for c in house_classes_ade20k:
        if c in ade20k_color_mapping:
            color = np.array(ade20k_color_mapping[c])
            mask = cv2.inRange(np_seg, color - 0.5, color + 0.5)
            full_mask |= mask.astype(bool)
    # Always return a boolean mask so downstream `&=` operations cannot hit a
    # uint8-to-bool casting error when no class matches.
    return full_mask
def point_to_segment_dist(pt: np.ndarray, seg_p1: np.ndarray, seg_p2: np.ndarray) -> float:
"""Computes the distance from a point to a line segment."""
if np.allclose(seg_p1, seg_p2):
return np.linalg.norm(pt - seg_p1)
seg_vec = seg_p2 - seg_p1
pt_vec = pt - seg_p1
seg_len2 = seg_vec.dot(seg_vec)
t = max(0, min(1, pt_vec.dot(seg_vec) / seg_len2))
proj = seg_p1 + t * seg_vec
return np.linalg.norm(pt - proj)
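# Worked example: for pt = (1, 1) and segment (0, 0)-(2, 0), t clamps to 0.5,
# the projection is (1, 0), and the returned distance is 1.0.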
def get_sparse_depth(
colmap_rec: pycolmap.Reconstruction, img_id_substring: str, depth: np.ndarray
) -> Tuple[np.ndarray, bool, Any]:
"""Projects COLMAP 3D points into the image to create a sparse depth map."""
H, W = depth.shape
found_img = None
for img_id_c, col_img in colmap_rec.images.items():
if img_id_substring in col_img.name:
found_img = col_img
break
if found_img is None:
return np.zeros((H, W), dtype=np.float32), False, None
points_xyz = [
p3D.xyz for pid, p3D in colmap_rec.points3D.items() if found_img.has_point3D(pid)
]
if not points_xyz:
return np.zeros((H, W), dtype=np.float32), False, found_img
points_xyz = np.array(points_xyz)
uv, z_vals = [], []
for xyz in points_xyz:
proj = found_img.project_point(xyz)
if proj is not None:
u_i, v_i = int(round(proj[0])), int(round(proj[1]))
if 0 <= u_i < W and 0 <= v_i < H:
uv.append((u_i, v_i))
mat4x4 = np.eye(4)
mat4x4[:3, :4] = found_img.cam_from_world.matrix()
p_cam = mat4x4 @ np.array([xyz[0], xyz[1], xyz[2], 1.0])
z_vals.append(p_cam[2] / p_cam[3])
uv, z_vals = np.array(uv, dtype=int), np.array(z_vals)
depth_out = np.zeros((H, W), dtype=np.float32)
if len(uv) > 0:
depth_out[uv[:, 1], uv[:, 0]] = z_vals
return depth_out, True, found_img
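# Note: the value stored at each projected pixel is the point's z-coordinate
# in the camera frame (plane depth), not the Euclidean ray distance, so it is
# directly comparable to the dense depth map it is fitted against below.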
def fit_scale_robust_median(
depth: np.ndarray, sparse_depth: np.ndarray, validity_mask: np.ndarray = None
) -> Tuple[float, np.ndarray]:
"""Fits a scale factor between dense and sparse depth using the median ratio."""
    # Keep only plausible depth values (0.1 to 50 in scene units) in both maps.
    mask = (sparse_depth > 0.1) & (depth > 0.1) & (sparse_depth < 50) & (depth < 50)
if validity_mask is not None:
mask &= validity_mask
X, Y = depth[mask], sparse_depth[mask]
if len(X) < 5:
return 1.0, depth
ratios = Y / X
alpha = np.median(ratios)
return alpha, alpha * depth
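# Example of the scale fit: if every valid sparse sample were exactly twice
# the dense prediction, all ratios Y / X would equal 2.0 and alpha = 2.0, so
# the returned map is 2.0 * depth. Taking the median instead of a
# least-squares fit keeps a handful of bad sparse points from skewing alpha.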
def get_fitted_dense_depth(
depth: np.ndarray, colmap_rec: pycolmap.Reconstruction, img_id: str, ade20k_seg: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, bool, Any]:
"""Fits the dense depth map to the sparse COLMAP depth."""
    depth_np = np.array(depth) / 1000.0  # stored depth appears to be millimetres; convert to metres
depth_sparse, found_sparse, col_img = get_sparse_depth(colmap_rec, img_id, depth_np)
if not found_sparse:
return depth_np, np.zeros_like(depth_np), False, None
house_mask = get_house_mask(ade20k_seg)
    _scale, depth_fitted = fit_scale_robust_median(depth_np, depth_sparse, validity_mask=house_mask)
return depth_fitted, depth_sparse, True, col_img
def project_vertices_to_3d(
uv: np.ndarray, depth_vert: np.ndarray, col_img: Any
) -> np.ndarray:
"""Projects 2D vertices to 3D using camera intrinsics and depth."""
xy_local = np.ones((len(uv), 3))
K = col_img.camera.calibration_matrix()
xy_local[:, 0] = (uv[:, 0] - K[0, 2]) / K[0, 0]
xy_local[:, 1] = (uv[:, 1] - K[1, 2]) / K[1, 1]
vertices_3d_local = xy_local * depth_vert[..., None]
world_to_cam = np.eye(4)
world_to_cam[:3] = col_img.cam_from_world.matrix()
cam_to_world = np.linalg.inv(world_to_cam)
vertices_3d_homogeneous = cv2.convertPointsToHomogeneous(vertices_3d_local)
vertices_3d = cv2.transform(vertices_3d_homogeneous, cam_to_world)
return cv2.convertPointsFromHomogeneous(vertices_3d).reshape(-1, 3)
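# For reference, the back-projection used above: with intrinsics
#   K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
# and vertex depth z, the camera-space point is
#   X = z * (u - cx) / fx,  Y = z * (v - cy) / fy,  Z = z,
# which cam_to_world then maps into world coordinates.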
def merge_vertices_3d(
vert_edge_per_image: Dict[int, Tuple[List[Dict[str, Any]], List[Tuple[int, int]], np.ndarray]],
th: float = Config.MERGE_THRESHOLD,
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
"""
Merges 3D vertices and edges across multiple views.
Args:
vert_edge_per_image: Dictionary of per-image vertices, edges, and 3D vertices.
th: Distance threshold for merging.
Returns:
new_vertices: Merged 3D vertices.
new_connections: Merged edge connections.
"""
all_3d_vertices, connections_3d, types = [], [], []
cur_start = 0
for cimg_idx, (vertices, connections, vertices_3d) in vert_edge_per_image.items():
if len(vertices) == 0 or len(vertices_3d) == 0:
continue
types += [int(v['type'] == 'apex') for v in vertices]
all_3d_vertices.append(vertices_3d)
connections_3d += [(x + cur_start, y + cur_start) for (x, y) in connections]
cur_start += len(vertices_3d)
if len(all_3d_vertices) == 0:
return np.array([]), []
    all_3d_vertices = np.concatenate(all_3d_vertices, axis=0)
distmat = cdist(all_3d_vertices, all_3d_vertices)
types = np.array(types).reshape(-1, 1)
same_types = cdist(types, types)
mask_to_merge = (distmat <= th) & (same_types == 0)
    new_vertices, new_connections = [], []
    # Each row of mask_to_merge lists the vertices within `th` of that vertex
    # (same type only); deduplicate identical neighborhood sets before grouping.
    to_merge = sorted(set(tuple(row.nonzero()[0].tolist()) for row in mask_to_merge))
to_merge_final = defaultdict(list)
for i in range(len(all_3d_vertices)):
for j in to_merge:
if i in j:
to_merge_final[i] += j
for k, v in to_merge_final.items():
to_merge_final[k] = list(set(v))
already_there, merged = set(), []
for k, v in to_merge_final.items():
if k in already_there:
continue
merged.append(v)
for vv in v:
already_there.add(vv)
old_idx_to_new = {}
count = 0
for idxs in merged:
if len(idxs) > 0:
new_vertices.append(all_3d_vertices[idxs].mean(axis=0))
for idx in idxs:
old_idx_to_new[idx] = count
count += 1
if len(new_vertices) == 0:
return np.array([]), []
new_vertices = np.array(new_vertices)
for conn in connections_3d:
if conn[0] in old_idx_to_new and conn[1] in old_idx_to_new:
            new_con = tuple(sorted((old_idx_to_new[conn[0]], old_idx_to_new[conn[1]])))
            if new_con[0] != new_con[1] and new_con not in new_connections:
                new_connections.append(new_con)
return new_vertices, new_connections
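# Note on the merge above: grouping is greedy rather than an exact union-find.
# Each group spans the two-hop neighborhood of its seed vertex, so long chains
# of vertices spaced just under `th` apart may split into several averaged
# vertices, and a vertex claimed by overlapping groups keeps its last mapping.
# Only vertices of the same type (apex vs. eave end point) are ever merged.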
def prune_too_far(
all_3d_vertices: np.ndarray, connections_3d: List[Tuple[int, int]], colmap_rec: pycolmap.Reconstruction, th: float = Config.PRUNE_DISTANCE
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
"""Prunes 3D vertices and edges that are too far from COLMAP points."""
if len(all_3d_vertices) == 0:
return all_3d_vertices, connections_3d
xyz_sfm = np.array([v.xyz for v in colmap_rec.points3D.values()])
if len(xyz_sfm) == 0:
return all_3d_vertices, connections_3d
distmat = cdist(all_3d_vertices, xyz_sfm)
mask = distmat.min(axis=1) <= th
if not np.any(mask):
return np.empty((0, 3)), []
old_to_new_idx = {old: new for new, old in enumerate(np.where(mask)[0])}
new_vertices = all_3d_vertices[mask]
new_connections = []
for u, v in connections_3d:
if u in old_to_new_idx and v in old_to_new_idx:
new_connections.append((old_to_new_idx[u], old_to_new_idx[v]))
return new_vertices, new_connections
def metric_aware_refine(
vertices: np.ndarray, edges: List[Tuple[int, int]]
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
"""
Refines vertices and edges using clustering and edge length constraints.
Args:
vertices: 3D vertex coordinates.
edges: List of edge connections.
Returns:
final_vertices: Refined 3D vertices.
final_edges: Refined edge connections.
"""
if len(vertices) < 2:
return vertices, edges
clustering = DBSCAN(
eps=Config.REFINEMENT_CLUSTER_EPS, min_samples=Config.REFINEMENT_MIN_SAMPLES
).fit(vertices)
labels = clustering.labels_
refined_vertices = vertices.copy()
refined_centers = {}
for label in set(labels):
if label == -1:
continue
refined_centers[label] = np.mean(vertices[labels == label], axis=0)
for i, label in enumerate(labels):
if label in refined_centers:
refined_vertices[i] = refined_centers[label]
final_edges_set = set()
for u, v in edges:
if u >= len(refined_vertices) or v >= len(refined_vertices):
continue
p1 = refined_vertices[u]
p2 = refined_vertices[v]
dist = np.linalg.norm(p1 - p2)
if Config.MIN_EDGE_LENGTH <= dist <= Config.MAX_EDGE_LENGTH:
if not np.allclose(p1, p2, atol=1e-3):
final_edges_set.add(tuple(sorted((u, v))))
if not final_edges_set:
return np.empty((0, 3)), []
used_idxs = set(u for u, v in final_edges_set) | set(v for u, v in final_edges_set)
sorted_used_idxs = sorted(list(used_idxs))
final_map = {old_id: new_id for new_id, old_id in enumerate(sorted_used_idxs)}
final_vertices = np.array([refined_vertices[old_id] for old_id in sorted_used_idxs])
final_edges = [(final_map[u], final_map[v]) for u, v in final_edges_set]
return final_vertices, final_edges
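# Example of the refinement: two vertices 0.1 apart (closer than
# REFINEMENT_CLUSTER_EPS) snap to their shared centroid; any edge whose
# post-snap length falls outside [MIN_EDGE_LENGTH, MAX_EDGE_LENGTH] is
# dropped, and the surviving vertices are re-indexed densely.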
def predict_wireframe(entry: Dict[str, Any]) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
"""
Main prediction function for the S23DR wireframe challenge.
Args:
entry: Input data entry.
Returns:
final_vertices: 3D vertices of the wireframe.
final_edges: Edge connections.
"""
try:
good_entry = convert_entry_to_human_readable(entry)
colmap_rec = good_entry['colmap_binary']
vert_edge_per_image = {}
for i, (gest_img, img_id, ade_seg, depth_img) in enumerate(
zip(
good_entry['gestalt'],
good_entry['image_ids'],
good_entry['ade'],
good_entry['depth'],
)
):
            # Resize the Gestalt segmentation to the depth-map resolution so that
            # vertex pixel coordinates index the depth maps correctly.
            depth_size = (depth_img.width, depth_img.height)
            gest_seg_np = np.array(gest_img.resize(depth_size)).astype(np.uint8)
vertices, connections = get_vertices_and_edges_from_segmentation(gest_seg_np)
if not vertices:
vert_edge_per_image[i] = ([], [], [])
else:
depth_fitted, depth_sparse, found_sparse, col_img = get_fitted_dense_depth(
depth_img, colmap_rec, img_id, ade_seg
)
if found_sparse and col_img is not None:
uv, depth_vert = get_uv_depth(vertices, depth_fitted, depth_sparse)
vertices_3d = project_vertices_to_3d(uv, depth_vert, col_img)
vert_edge_per_image[i] = (vertices, connections, vertices_3d)
else:
vert_edge_per_image[i] = (vertices, connections, [])
all_3d_vertices, connections_3d = merge_vertices_3d(vert_edge_per_image)
all_3d_vertices, connections_3d = prune_too_far(
all_3d_vertices, connections_3d, colmap_rec
)
final_vertices, final_edges = metric_aware_refine(
all_3d_vertices, connections_3d
)
if len(final_vertices) < 2 or len(final_edges) < 1:
return empty_solution()
return final_vertices, final_edges
    except Exception as e:
        print(f"An error occurred in the main prediction pipeline: {e}")
        traceback.print_exc()
        return empty_solution()
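

if __name__ == "__main__":
    # Minimal smoke test on synthetic data; not part of the challenge harness,
    # and all inputs below are fabricated for illustration only.
    assert np.isclose(
        point_to_segment_dist(np.array([1.0, 1.0]), np.array([0.0, 0.0]), np.array([2.0, 0.0])),
        1.0,
    )
    rng = np.random.default_rng(0)
    dense = rng.uniform(1.0, 10.0, size=(32, 32)).astype(np.float32)
    sparse = np.zeros_like(dense)
    idx = rng.choice(dense.size, size=50, replace=False)
    sparse.flat[idx] = 2.0 * dense.flat[idx]  # sparse depth exactly 2x dense
    alpha, _fitted = fit_scale_robust_median(dense, sparse)
    assert np.isclose(alpha, 2.0)
    print(f"smoke tests passed (alpha = {alpha:.2f})")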