DepthAnything-AC / util /visualize_utils.py
ghost233lism's picture
upload models
7f0f123 verified
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
def _geo_prior_grid_shape(num_tokens, grid_size=None):
    """Return the (H, W) token grid for a prior whose rows/cols number ``num_tokens``."""
    if grid_size is not None:
        grid_h, grid_w = (int(v) for v in grid_size)
        if grid_h * grid_w != num_tokens:
            raise ValueError(
                f"grid_size ({grid_h}, {grid_w}) does not match {num_tokens} tokens"
            )
        return grid_h, grid_w
    # Default: assume a square grid. round() guards against sqrt float error.
    side = int(round(np.sqrt(num_tokens)))
    if side * side != num_tokens:
        raise ValueError(
            f"geo_prior has {num_tokens} tokens, which is not a perfect square; "
            "pass grid_size=(H, W) explicitly"
        )
    return side, side


def _geo_prior_output_path(save_path, suffix):
    """Build a sibling PNG path, e.g. ``out.png`` + ``_overlay`` -> ``out_overlay.png``."""
    root, _ = os.path.splitext(save_path)
    return f"{root}{suffix}.png"


def _geo_prior_to_bgr(img_chw):
    """Undo mean/std normalization of a [C,H,W] RGB tensor; return a uint8 [H,W,3] BGR array."""
    rgb = np.transpose(img_chw.detach().cpu().numpy(), (1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    rgb = np.clip((std * rgb + mean) * 255, 0, 255).astype(np.uint8)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def visualize_geo_prior(img, geo_prior, save_path, batch_idx=0, point_coords=None,
                        normalize=True, alpha=0.6, grid_size=None):
    """
    Visualize one row of a geometric prior matrix and overlay it on the original image.

    The row belonging to a reference token is reshaped to the token grid,
    colour-mapped, and written as four files:
    ``<root>_overlay.png``, ``<root>_heatmap.png`` and ``<root>_original.png``
    (via OpenCV, where ``<root>`` is ``save_path`` without its extension), plus a
    matplotlib figure with a colorbar saved to ``save_path`` itself.

    Args:
        img: Original image tensor [B,C,H,W], normalized with the mean/std constants
            used below (RGB channel order).
        geo_prior: Geometric prior tensor with shape [B,HW,HW].
        save_path: Path for the matplotlib figure; sibling PNGs are derived from it.
        batch_idx: Batch index to visualize.
        point_coords: Reference token coordinates (h, w) on the token grid.
            If None, the grid centre is used.
        normalize: Whether to min-max normalize the relation map before display.
        alpha: Heatmap opacity, 0.0 means fully transparent, 1.0 fully opaque.
        grid_size: Optional (H, W) of the token grid. If None, a square grid is assumed.

    Returns:
        The selected relation row reshaped to (H, W), as a tensor on its original device.

    Raises:
        ValueError: if the token count does not match the grid, or point_coords
            lies outside the grid.
    """
    _, num_tokens, _ = geo_prior.shape
    H, W = _geo_prior_grid_shape(num_tokens, grid_size)

    if point_coords is None:
        h, w = H // 2, W // 2
    else:
        h, w = point_coords
        h, w = int(h), int(w)
        # Without this check, w >= W would silently select a token in the next
        # row, and negative indices would wrap around.
        if not (0 <= h < H and 0 <= w < W):
            raise ValueError(f"point_coords {point_coords} outside the {H}x{W} token grid")

    relation_map = geo_prior[batch_idx][h * W + w].reshape(H, W)
    relation_np = relation_map.detach().cpu().numpy()
    if normalize:
        relation_np = (relation_np - relation_np.min()) / (relation_np.max() - relation_np.min() + 1e-6)
    # Clip before the uint8 cast so unnormalized values saturate instead of wrapping.
    heatmap = cv2.applyColorMap(
        (np.clip(relation_np, 0.0, 1.0) * 255).astype(np.uint8), cv2.COLORMAP_RAINBOW
    )

    orig_img = _geo_prior_to_bgr(img[batch_idx])
    orig_h, orig_w = orig_img.shape[:2]
    heatmap_up = cv2.resize(heatmap, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    overlay = cv2.addWeighted(orig_img, 1 - alpha, heatmap_up, alpha, 0)
    # Place the cross at the centre of the reference cell, not its top-left corner,
    # so it lines up with the star drawn by imshow below.
    marker_xy = (int((w + 0.5) * orig_w / W), int((h + 0.5) * orig_h / H))
    cv2.drawMarker(overlay, marker_xy, (255, 255, 255), cv2.MARKER_CROSS, 20, 2)

    cv2.imwrite(_geo_prior_output_path(save_path, "_overlay"), overlay)
    cv2.imwrite(_geo_prior_output_path(save_path, "_heatmap"), heatmap)
    cv2.imwrite(_geo_prior_output_path(save_path, "_original"), orig_img)

    ref_label = "center" if point_coords is None else f"({point_coords[0]}, {point_coords[1]})"
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        im = ax.imshow(relation_np, cmap='rainbow')
        fig.colorbar(im, ax=ax, label='Geometric Prior Strength')
        ax.plot(w, h, 'w*', markersize=10)
        ax.set_title(f'Geometric Prior Visualization (Ref Point: {ref_label})')
        fig.savefig(save_path)
    finally:
        # Always release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    return relation_map
def save_feature_visualization(feature_map, save_path, size=518):
    """
    Save a channel-averaged feature map as a colour image.

    All channels are averaged into a single 2-D map. The map is min-max
    normalized to [0, 255], resized to ``size`` x ``size`` and colour-mapped
    with VIRIDIS.

    Args:
        feature_map: feature map tensor with shape [C,H,W] or [B,C,H,W]; for
            batched input only the first sample is visualized.
        save_path: output image path (OpenCV picks the format from the extension).
        size: side length of the square output image in pixels (default 518).
    """
    if feature_map.dim() == 4:
        # Select the first sample explicitly. squeeze(0) is a no-op when B > 1,
        # and the mean below would then run over the batch axis.
        feature_map = feature_map[0]
    mean_feature = torch.mean(feature_map, dim=0).detach().cpu().numpy()
    # The epsilon avoids division by zero for a constant map.
    mean_feature = (mean_feature - mean_feature.min()) / (mean_feature.max() - mean_feature.min() + 1e-6)
    mean_feature = (mean_feature * 255).astype(np.uint8)
    mean_feature = cv2.resize(mean_feature, (size, size), interpolation=cv2.INTER_LINEAR)
    colored_feature = cv2.applyColorMap(mean_feature, cv2.COLORMAP_VIRIDIS)
    cv2.imwrite(save_path, colored_feature)
def save_depth_visualization(depth_map, filename):
    """
    Save a depth map as a colour image using the INFERNO colormap.

    Values are min-max normalized to [0, 255], so the image shows relative
    depth only. A constant map is rendered as all zeros.

    Args:
        depth_map (torch.Tensor | np.ndarray): depth map with shape [H, W],
            [B, H, W] or [B, 1, H, W]; for batched input only the first map is saved.
        filename (str): Output file path for the visualization.
    """
    # .double() also makes bfloat16/half tensors convertible to numpy.
    depth = torch.as_tensor(depth_map).detach().cpu().double()
    # cv2.applyColorMap needs a single-channel 2-D image, so take the first
    # element of every leading batch/channel axis.
    while depth.dim() > 2:
        depth = depth[0]
    depth = depth.numpy()

    d_min, d_max = depth.min(), depth.max()
    span = d_max - d_min
    if span > 0:
        depth_norm = (depth - d_min) / span * 255.0
    else:
        # Constant map: avoid 0/0 producing NaNs before the uint8 cast.
        depth_norm = np.zeros_like(depth)
    colored_depth = cv2.applyColorMap(depth_norm.astype(np.uint8), cv2.COLORMAP_INFERNO)
    cv2.imwrite(filename, colored_depth)
def save_image(img_tensor, filename):
    """
    Undo mean/std normalization on an RGB image tensor and save it as BGR.

    Args:
        img_tensor (torch.Tensor): normalized RGB image with shape [C, H, W] or
            [B, C, H, W]; for batched input only the first sample is saved. A
            channel-last [H, W, 3] input is passed through without transposing.
        filename (str): Output file path for the image.
    """
    img = img_tensor.detach().cpu().numpy()
    if img.ndim == 4:
        # Select the first sample. Otherwise a batch of exactly three images
        # would pass the shape[0] == 3 test below and be transposed wrongly.
        img = img[0]
    if img.shape[0] == 3:
        img = np.transpose(img, (1, 2, 0))  # CHW -> HWC
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    # OpenCV writes images in BGR channel order.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, img)