"""Evaluate AnySplat on novel view synthesis (NVS) with an LLFF-style holdout split."""
import argparse
import os
import sys
from pathlib import Path

import torch

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.evaluation.metrics import compute_lpips, compute_psnr, compute_ssim
from misc.image_io import save_image, save_interpolated_video
from src.utils.image import process_image
from src.model.model.anysplat import AnySplat
from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri


def setup_args():
    """Set up command-line arguments for the NVS evaluation script."""
    parser = argparse.ArgumentParser(description='Test AnySplat on NVS evaluation')
    parser.add_argument('--data_dir', type=str, required=True, help='Path to NVS dataset')
    parser.add_argument('--llffhold', type=int, default=8,
                        help='Hold out every Nth image as a target view (LLFF convention)')
    parser.add_argument('--output_path', type=str, default="outputs/nvs",
                        help='Path to output directory')
    return parser.parse_args()


def compute_metrics(pred_image, image):
    """Compute PSNR, SSIM, and LPIPS between predicted and ground-truth images."""
    psnr = compute_psnr(pred_image, image)
    ssim = compute_ssim(pred_image, image)
    lpips = compute_lpips(pred_image, image)
    return psnr, ssim, lpips


def evaluate(args: argparse.Namespace):
    model = AnySplat.from_pretrained("lhjiang/anysplat")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    os.makedirs(args.output_path, exist_ok=True)

    # Load images and split them into context/target views: every Nth image
    # (N = --llffhold) is held out as a target, the rest serve as context.
    image_folder = args.data_dir
    image_names = sorted(
        os.path.join(image_folder, f)
        for f in os.listdir(image_folder)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    images = [process_image(img_path) for img_path in image_names]
    ctx_indices = [idx for idx in range(len(images)) if idx % args.llffhold != 0]
    tgt_indices = [idx for idx in range(len(images)) if idx % args.llffhold == 0]
    ctx_images = torch.stack([images[i] for i in ctx_indices], dim=0).unsqueeze(0).to(device)
    tgt_images = torch.stack([images[i] for i in tgt_indices], dim=0).unsqueeze(0).to(device)
    # process_image returns tensors in [-1, 1]; map them to [0, 1].
    ctx_images = (ctx_images + 1) * 0.5
    tgt_images = (tgt_images + 1) * 0.5
    b, v, _, h, w = tgt_images.shape

    # Run the encoder on the context views to predict Gaussians and context poses.
    encoder_output = model.encoder(
        ctx_images,
        global_step=0,
        visualization_dump={},
    )
    gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose

    # Estimate camera poses for context and target views jointly with the
    # VGGT aggregator and camera head.
    num_context_view = ctx_images.shape[1]
    vggt_input_image = torch.cat((ctx_images, tgt_images), dim=1).to(torch.bfloat16)
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
        aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(
            vggt_input_image,
            intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx,
        )
        with torch.cuda.amp.autocast(enabled=False):
            # The camera head runs in fp32 for numerical stability.
            fp32_tokens = [token.float() for token in aggregated_tokens_list]
            pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
            pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(
                pred_all_pose_enc, vggt_input_image.shape[-2:]
            )

    # Pad the 3x4 extrinsics to 4x4 and invert (world-to-camera -> camera-to-world).
    extrinsic_padding = (
        torch.tensor([0, 0, 0, 1],
                     device=pred_all_extrinsic.device,
                     dtype=pred_all_extrinsic.dtype)
        .view(1, 1, 1, 4)
        .repeat(b, vggt_input_image.shape[1], 1, 1)
    )
    pred_all_extrinsic = torch.cat([pred_all_extrinsic, extrinsic_padding], dim=2).inverse()

    # Normalize the intrinsics to resolution-independent image coordinates.
    pred_all_intrinsic[:, :, 0] = pred_all_intrinsic[:, :, 0] / w
    pred_all_intrinsic[:, :, 1] = pred_all_intrinsic[:, :, 1] / h

    pred_all_context_extrinsic = pred_all_extrinsic[:, :num_context_view]
    pred_all_target_extrinsic = pred_all_extrinsic[:, num_context_view:]
    pred_all_context_intrinsic = pred_all_intrinsic[:, :num_context_view]
    pred_all_target_intrinsic = pred_all_intrinsic[:, num_context_view:]

    # Rescale the VGGT camera translations to match the encoder's pose scale.
    scale_factor = (
        pred_context_pose['extrinsic'][:, :, :3, 3].mean()
        / pred_all_context_extrinsic[:, :, :3, 3].mean()
    )
    pred_all_target_extrinsic[..., :3, 3] = pred_all_target_extrinsic[..., :3, 3] * scale_factor
    pred_all_context_extrinsic[..., :3, 3] = pred_all_context_extrinsic[..., :3, 3] * scale_factor
    print("scale_factor:", scale_factor)

    # Render the predicted Gaussians from the held-out target poses
    # (near plane 0.01, far plane 100).
    output = model.decoder.forward(
        gaussians,
        pred_all_target_extrinsic,
        pred_all_target_intrinsic.float(),
        torch.ones(1, v, device=device) * 0.01,
        torch.ones(1, v, device=device) * 100,
        (h, w),
    )

    save_interpolated_video(
        pred_all_context_extrinsic,
        pred_all_context_intrinsic,
        b, h, w,
        gaussians,
        args.output_path,
        model.decoder,
    )

    # Save ground-truth and rendered target images side by side.
    save_path = Path(args.output_path)
    for idx, (gt_image, pred_image) in enumerate(zip(tgt_images[0], output.color[0])):
        save_image(gt_image, save_path / "gt" / f"{idx:0>6}.jpg")
        save_image(pred_image, save_path / "pred" / f"{idx:0>6}.jpg")

    # Compute and report novel-view-synthesis metrics.
    psnr, ssim, lpips = compute_metrics(output.color[0], tgt_images[0])
    print(f"PSNR: {psnr.mean():.2f}, SSIM: {ssim.mean():.3f}, LPIPS: {lpips.mean():.3f}")


if __name__ == "__main__":
    args = setup_args()
    evaluate(args)
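# Example invocation, as a sketch: the script filename and dataset layout
# below are assumptions, not taken from the repo. --data_dir should point at
# a flat folder of .png/.jpg/.jpeg frames for a single scene.
#
#   python eval_nvs.py --data_dir /path/to/scene/images --llffhold 8 --output_path outputs/nvs
#
# With --llffhold 8, images at indices 0, 8, 16, ... are held out as target
# views; the remaining frames serve as context for reconstruction.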