import os
import sys
import json
import gzip
import argparse
import datetime

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from lpips import LPIPS

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model.model.anysplat import AnySplat
from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from src.model.encoder.vggt.utils.load_fn import load_and_preprocess_images
from src.utils.pose import align_to_first_camera, calculate_auc_np, convert_pt3d_RT_to_opencv, se3_to_relative_pose_error
from src.misc.cam_utils import camera_normalization, pose_auc, rotation_6d_to_matrix, update_pose, get_pnp_pose


def setup_args():
    """Set up command-line arguments for the CO3D evaluation script."""
    parser = argparse.ArgumentParser(description='Test AnySplat on CO3D dataset')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode (only test on a single category)')
    parser.add_argument('--use_ba', action='store_true', default=False, help='Enable bundle adjustment')
    parser.add_argument('--fast_eval', action='store_true', default=False, help='Only evaluate 10 sequences per category')
    parser.add_argument('--min_num_images', type=int, default=50, help='Minimum number of images for a sequence')
    parser.add_argument('--num_frames', type=int, default=10, help='Number of frames to use for testing')
    parser.add_argument('--co3d_dir', type=str, required=True, help='Path to CO3D dataset')
    parser.add_argument('--co3d_anno_dir', type=str, required=True, help='Path to CO3D annotations')
    parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility')
    return parser.parse_args()
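

# Example invocation (a sketch: the script filename and dataset paths below are
# placeholders, not taken from the repository):
#   python eval_co3d.py --co3d_dir /data/co3d --co3d_anno_dir /data/co3d_annotations --num_frames 10 --use_ba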


lpips = LPIPS(net="vgg")


def rendering_loss(pred_image, image):
    # LPIPS weights are created on CPU; move them onto the input's device before use.
    lpips.to(pred_image.device)
    lpips_loss = lpips.forward(rearrange(pred_image, "b v c h w -> (b v) c h w"), rearrange(image, "b v c h w -> (b v) c h w"), normalize=True)
    # `image` is in [-1, 1]; map it back to [0, 1] to match `pred_image` for the MSE term.
    delta = pred_image - (image + 1) / 2
    mse_loss = (delta**2).mean()
    return mse_loss + 0.05 * lpips_loss.mean()


def process_sequence(model, seq_name, seq_data, category, co3d_dir, min_num_images, num_frames, use_ba, device, dtype):
    """
    Process a single sequence and compute pose errors.

    Args:
        model: AnySplat model
        seq_name: Sequence name
        seq_data: Sequence data
        category: Category name
        co3d_dir: CO3D dataset directory
        min_num_images: Minimum number of images required
        num_frames: Number of frames to sample
        use_ba: Whether to use bundle adjustment
        device: Device to run on
        dtype: Data type for model inference

    Returns:
        rError: Rotation errors
        tError: Translation errors
    """
    if len(seq_data) < min_num_images:
        return None, None

    metadata = []
    for data in seq_data:
        # Skip sequences with degenerate, extremely large translations.
        if data["T"][0] + data["T"][1] + data["T"][2] > 1e5:
            return None, None

        extri_opencv = convert_pt3d_RT_to_opencv(data["R"], data["T"])
        metadata.append({
            "filepath": data["filepath"],
            "extri": extri_opencv,
        })

    # Sample a random subset of frames and gather their ground-truth extrinsics.
    ids = np.random.choice(len(metadata), num_frames, replace=False)
    image_names = [os.path.join(co3d_dir, metadata[i]["filepath"]) for i in ids]
    gt_extri = [np.array(metadata[i]["extri"]) for i in ids]
    gt_extri = np.stack(gt_extri, axis=0)

    # Skip low-resolution sequences.
    max_size = max(Image.open(image_names[0]).size)
    if max_size < 448:
        return None, None
    images = load_and_preprocess_images(image_names)[None].to(device)

    batch = {
        "context": {
            "image": images * 2.0 - 1,
            "image_names": image_names,
            "index": ids,
        },
        "scene": "co3d",
    }

    if use_ba:
        try:
            encoder_output = model.encoder(
                batch,
                global_step=0,
                visualization_dump={},
            )
            gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose
            pred_extrinsic = pred_context_pose['extrinsic']
            pred_intrinsic = pred_context_pose['intrinsic']

            b, v, _, h, w = images.shape
            with torch.set_grad_enabled(True), torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
                # Optimize small pose corrections (6D rotation + translation) on top of the predicted extrinsics.
                cam_rot_delta = nn.Parameter(torch.zeros([b, v, 6], requires_grad=True, device=pred_extrinsic.device, dtype=torch.float32))
                cam_trans_delta = nn.Parameter(torch.zeros([b, v, 3], requires_grad=True, device=pred_extrinsic.device, dtype=torch.float32))
                opt_params = []
                model.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=torch.float32).to(pred_extrinsic.device))
                opt_params.append(
                    {
                        "params": [cam_rot_delta],
                        "lr": 0.005,
                    }
                )
                opt_params.append(
                    {
                        "params": [cam_trans_delta],
                        "lr": 0.005,
                    }
                )
                pose_optimizer = torch.optim.Adam(opt_params)
                extrinsics = pred_extrinsic.clone().float()

                # save_image does not create directories, so make sure the output folder exists.
                os.makedirs("outputs/vis", exist_ok=True)

                for i in range(100):
                    pose_optimizer.zero_grad()
                    dx, drot = cam_trans_delta, cam_rot_delta
                    rot = rotation_6d_to_matrix(
                        drot + model.identity.expand(b, v, -1)
                    )

                    transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1))
                    transform[..., :3, :3] = rot
                    transform[..., :3, 3] = dx

                    new_extrinsics = torch.matmul(extrinsics, transform)

                    output = model.decoder.forward(
                        gaussians,
                        new_extrinsics,
                        pred_intrinsic.float(),
                        0.1,
                        100.0,
                        (h, w),
                    )

                    # Use a distinct name for the loss value so the rendering_loss() function is not shadowed.
                    loss = rendering_loss(output.color, images * 2.0 - 1)
                    torchvision.utils.save_image(output.color[0], f"outputs/vis/output_co3d_{i}.png")
                    print(f"Rendering loss: {loss.item()}")

                    loss.backward()
                    pose_optimizer.step()
                torchvision.utils.save_image(images[0], "outputs/vis/gt_co3d.png")
                pred_extrinsic = new_extrinsics.inverse()[0][:, :-1, :]

        except Exception as e:
            print(f"BA failed with error: {e}. Falling back to standard VGGT inference.")
            with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
                aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(images, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
                with torch.cuda.amp.autocast(dtype=torch.float32):
                    fp32_tokens = [token.float() for token in aggregated_tokens_list]
                    pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
                    pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])
                pred_extrinsic = pred_all_extrinsic[0]
    else:
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
            aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(images, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx)
            with torch.cuda.amp.autocast(dtype=torch.float32):
                fp32_tokens = [token.float() for token in aggregated_tokens_list]
                pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
                pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])
            pred_extrinsic = pred_all_extrinsic[0]

    with torch.cuda.amp.autocast(dtype=torch.float32):
        # Build 4x4 SE(3) matrices from the 3x4 extrinsics before computing relative-pose errors.
        gt_extrinsic = torch.from_numpy(gt_extri).to(device)
        add_row = torch.tensor([0, 0, 0, 1], device=device).expand(pred_extrinsic.size(0), 1, 4)

        pred_se3 = torch.cat((pred_extrinsic, add_row), dim=1)
        gt_se3 = torch.cat((gt_extrinsic, add_row), dim=1)

        # Align the ground-truth trajectory to its first camera before comparison.
        gt_se3 = align_to_first_camera(gt_se3)

        rel_rangle_deg, rel_tangle_deg = se3_to_relative_pose_error(pred_se3, gt_se3, num_frames)
        print(f"{category} sequence {seq_name} Rot Error: {rel_rangle_deg.mean().item():.4f}")
        print(f"{category} sequence {seq_name} Trans Error: {rel_tangle_deg.mean().item():.4f}")

    return rel_rangle_deg.cpu().numpy(), rel_tangle_deg.cpu().numpy()


def evaluate(args: argparse.Namespace):
    # Apply the --seed flag so frame sampling (np.random.choice) is reproducible.
    np.random.seed(args.seed)

    model = AnySplat.from_pretrained("lhjiang/anysplat")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    SEEN_CATEGORIES = [
        "apple", "backpack", "banana", "baseballbat", "baseballglove",
        "bench", "bicycle", "bottle", "bowl", "broccoli",
        "cake", "car", "carrot", "cellphone", "chair",
        "cup", "donut", "hairdryer", "handbag", "hydrant",
        "keyboard", "laptop", "microwave", "motorcycle", "mouse",
        "orange", "parkingmeter", "pizza", "plant", "stopsign",
        "teddybear", "toaster", "toilet", "toybus", "toyplane",
        "toytrain", "toytruck", "tv", "umbrella", "vase", "wineglass",
    ]

    if args.debug:
        SEEN_CATEGORIES = ["apple"]

    per_category_results = {}

    for category in SEEN_CATEGORIES:
        print(f"Loading annotation for {category} test set")
        annotation_file = os.path.join(args.co3d_anno_dir, f"{category}_test.jgz")

        try:
            with gzip.open(annotation_file, "r") as fin:
                annotation = json.loads(fin.read())
        except FileNotFoundError:
            print(f"Annotation file not found for {category}, skipping")
            continue

        rError = []
        tError = []

        for seq_name, seq_data in annotation.items():
            print("-" * 50)
            print(f"Processing {seq_name} for {category} test set")
            if args.debug and not os.path.exists(os.path.join(args.co3d_dir, category, seq_name)):
                print(f"Skipping {seq_name} (not found)")
                continue

            seq_rError, seq_tError = process_sequence(
                model, seq_name, seq_data, category, args.co3d_dir,
                args.min_num_images, args.num_frames, args.use_ba, device, torch.bfloat16
            )
            print("-" * 50)

            if seq_rError is not None and seq_tError is not None:
                rError.extend(seq_rError)
                tError.extend(seq_tError)

        if not rError:
            print(f"No valid sequences found for {category}, skipping")
            continue

        rError = np.array(rError)
        tError = np.array(tError)

        # Pose accuracy AUC at several angular thresholds (degrees).
        thresholds = [5, 10, 20, 30]
        Aucs = {}
        for threshold in thresholds:
            Auc, _ = calculate_auc_np(rError, tError, max_threshold=threshold)
            Aucs[threshold] = Auc

        print("=" * 80)
        print(f"AUC@30 of {category} test set: {Aucs[30]:.4f}")
        print("=" * 80)

        per_category_results[category] = {
            "rError": rError,
            "tError": tError,
            "Auc_5": Aucs[5],
            "Auc_10": Aucs[10],
            "Auc_20": Aucs[20],
            "Auc_30": Aucs[30],
        }

    print("\nSummary of AUC results:")
    print("-" * 50)
    for category in sorted(per_category_results.keys()):
        print(f"{category:<15} AUC_5: {per_category_results[category]['Auc_5']:.4f}")
        print(f"{category:<15} AUC_10: {per_category_results[category]['Auc_10']:.4f}")
        print(f"{category:<15} AUC_20: {per_category_results[category]['Auc_20']:.4f}")
        print(f"{category:<15} AUC_30: {per_category_results[category]['Auc_30']:.4f}")

    if per_category_results:
        mean_AUC_5 = np.mean([per_category_results[category]["Auc_5"] for category in per_category_results])
        mean_AUC_10 = np.mean([per_category_results[category]["Auc_10"] for category in per_category_results])
        mean_AUC_20 = np.mean([per_category_results[category]["Auc_20"] for category in per_category_results])
        mean_AUC_30 = np.mean([per_category_results[category]["Auc_30"] for category in per_category_results])
        print("-" * 50)
        print(f"Mean AUC_5: {mean_AUC_5:.4f}")
        print(f"Mean AUC_10: {mean_AUC_10:.4f}")
        print(f"Mean AUC_20: {mean_AUC_20:.4f}")
        print(f"Mean AUC_30: {mean_AUC_30:.4f}")

    # Write per-category and mean AUCs to a timestamped results file.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"co3d_results_{timestamp}.txt"

    with open(results_file, "w") as f:
        f.write("CO3D Evaluation Results\n")
        f.write("=" * 50 + "\n\n")

        f.write("Per-category results:\n")
        f.write("-" * 50 + "\n")
        for category in sorted(per_category_results.keys()):
            f.write(f"{category:<15} AUC_5: {per_category_results[category]['Auc_5']:.4f}\n")
            f.write(f"{category:<15} AUC_10: {per_category_results[category]['Auc_10']:.4f}\n")
            f.write(f"{category:<15} AUC_20: {per_category_results[category]['Auc_20']:.4f}\n")
            f.write(f"{category:<15} AUC_30: {per_category_results[category]['Auc_30']:.4f}\n")
            f.write("\n")

        if per_category_results:
            f.write("-" * 50 + "\n")
            f.write(f"Mean AUC_5: {mean_AUC_5:.4f}\n")
            f.write(f"Mean AUC_10: {mean_AUC_10:.4f}\n")
            f.write(f"Mean AUC_20: {mean_AUC_20:.4f}\n")
            f.write(f"Mean AUC_30: {mean_AUC_30:.4f}\n")
        f.write("\n" + "=" * 50 + "\n")

    print(f"Results saved to {results_file}")


if __name__ == "__main__":
    args = setup_args()
    evaluate(args)