import json
import os

import numpy as np
import torch
import torchvision
from pytorch3d.renderer import PerspectiveCameras
from tqdm.auto import tqdm

from diffusionsfm.dataset.co3d_v2 import Co3dDataset, full_scene_scale
from diffusionsfm.inference.load_model import load_model
from diffusionsfm.inference.predict import predict_cameras
from diffusionsfm.utils.geometry import (
    compute_angular_error_batch,
    get_error,
    n_to_np_rotations,
)
from diffusionsfm.utils.rays import cameras_to_rays, normalize_cameras_batch
from diffusionsfm.utils.slurm import init_slurm_signals_if_slurm
from diffusionsfm.utils.visualization import filter_and_align_point_clouds


@torch.no_grad()
def evaluate(
    cfg,
    model,
    dataset,
    num_images,
    device,
    use_pbar=True,
    calculate_intrinsics=True,
    additional_timesteps=(),
    num_evaluate=None,
    max_num_images=None,
    mode=None,
    metrics=True,
    load_depth=True,
):
    if cfg.training.get("dpt_head", False):
        # DPT head: inputs are 224x224; outputs are a full-resolution ray grid.
        H_in = W_in = 224
        H_out = W_out = cfg.training.full_num_patches_y
    else:
        H_in = H_out = cfg.model.num_patches_x
        W_in = W_out = cfg.model.num_patches_y

    results = {}
    if num_evaluate is None:
        instances = np.arange(0, len(dataset))
    else:
        # Subsample num_evaluate instances, evenly spaced over the dataset.
        instances = np.linspace(
            0, len(dataset) - 1, num_evaluate, endpoint=True, dtype=int
        )
    instances = tqdm(instances) if use_pbar else instances

    for counter, idx in enumerate(instances):
        batch = dataset[idx]
        instance = batch["model_id"]
        images = batch["image"].to(device)
        focal_length = batch["focal_length"].to(device)[:num_images]
        R = batch["R"].to(device)[:num_images]
        T = batch["T"].to(device)[:num_images]
        crop_parameters = batch["crop_parameters"].to(device)[:num_images]

        if load_depth:
            depths = batch["depth"].to(device)[:num_images]
            depth_masks = batch["depth_masks"].to(device)[:num_images]
            try:
                object_masks = batch["object_masks"].to(device)[:num_images]
            except KeyError:
                # Fall back to the depth-validity masks when the sequence has
                # no object masks.
                object_masks = depth_masks.clone()

            # Normalize cameras and scale depths for output resolution
            cameras_gt = PerspectiveCameras(
                R=R, T=T, focal_length=focal_length, device=device
            )
            cameras_gt, _, _ = normalize_cameras_batch(
                [cameras_gt],
                first_cam_mediod=cfg.training.first_cam_mediod,
                normalize_first_camera=cfg.training.normalize_first_camera,
                depths=depths.unsqueeze(0),
                crop_parameters=crop_parameters.unsqueeze(0),
                num_patches_x=H_in,
                num_patches_y=W_in,
                return_scales=True,
            )
            cameras_gt = cameras_gt[0]

            gt_rays = cameras_to_rays(
                cameras=cameras_gt,
                num_patches_x=H_in,
                num_patches_y=W_in,
                crop_parameters=crop_parameters,
                depths=depths,
                mode=mode,
            )
            gt_points = gt_rays.get_segments().view(num_images, -1, 3)

            resize = torchvision.transforms.Resize(
                224,
                antialias=False,
                interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT,
            )
        else:
            cameras_gt = PerspectiveCameras(
                R=R, T=T, focal_length=focal_length, device=device
            )

        pred_cameras, additional_cams = predict_cameras(
            model,
            images,
            device,
            crop_parameters=crop_parameters,
            num_patches_x=H_out,
            num_patches_y=W_out,
            max_num_images=max_num_images,
            additional_timesteps=additional_timesteps,
            calculate_intrinsics=calculate_intrinsics,
            mode=mode,
            return_rays=True,
            use_homogeneous=cfg.model.get("use_homogeneous", False),
        )
        cameras_to_evaluate = additional_cams + [pred_cameras]

        # The full sequence (all cameras, no images) determines the scene scale
        # used to normalize camera-center errors.
        all_cams_batch = dataset.get_data(
            sequence_name=instance, ids=np.arange(0, batch["n"]), no_images=True
        )
        gt_scene_scale = full_scene_scale(all_cams_batch)
        R_gt = R
        T_gt = T

        errors = []
        for camera, pred_rays in cameras_to_evaluate:
            R_pred = camera.R
            T_pred = camera.T
            f_pred = camera.focal_length
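            # Note on the metrics below (inferred from the helper names):
            # rotation error is the angular error between pairwise relative
            # rotations, so it is invariant to the global frame, and CC_error
            # compares camera centers normalized by the ground-truth scene scale.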
            R_pred_rel = n_to_np_rotations(num_images, R_pred).cpu().numpy()
            R_gt_rel = n_to_np_rotations(num_images, batch["R"]).cpu().numpy()
            R_error = compute_angular_error_batch(R_pred_rel, R_gt_rel)
            CC_error, _ = get_error(True, R_pred, T_pred, R_gt, T_gt, gt_scene_scale)

            if load_depth and metrics:
                # Evaluate outputs at the same resolution as DUSt3R
                pred_points = pred_rays.get_segments().view(
                    num_images, H_out, W_out, 3
                )
                pred_points = pred_points.permute(0, 3, 1, 2)
                pred_points = (
                    resize(pred_points)
                    .permute(0, 2, 3, 1)
                    .view(num_images, H_in * W_in, 3)
                )

                # Point-cloud metrics over all valid depths...
                _, _, _, _, metric_values = filter_and_align_point_clouds(
                    num_images,
                    gt_points,
                    pred_points,
                    depth_masks,
                    depth_masks,
                    images,
                    metrics=metrics,
                    num_patches_x=H_in,
                )
                # ...and restricted to the object foreground.
                _, _, _, _, object_metric_values = filter_and_align_point_clouds(
                    num_images,
                    gt_points,
                    pred_points,
                    depth_masks * object_masks,
                    depth_masks * object_masks,
                    images,
                    metrics=metrics,
                    num_patches_x=H_in,
                )

            result = {
                "R_pred": R_pred.detach().cpu().numpy().tolist(),
                "T_pred": T_pred.detach().cpu().numpy().tolist(),
                "f_pred": f_pred.detach().cpu().numpy().tolist(),
                "R_gt": R_gt.detach().cpu().numpy().tolist(),
                "T_gt": T_gt.detach().cpu().numpy().tolist(),
                "f_gt": focal_length.detach().cpu().numpy().tolist(),
                "scene_scale": gt_scene_scale,
                "R_error": R_error.tolist(),
                "CC_error": CC_error,
            }
            if load_depth and metrics:
                result["CD"] = metric_values[1]
                result["CD_Object"] = object_metric_values[1]
            else:
                result["CD"] = 0
                result["CD_Object"] = 0
            errors.append(result)

        results[instance] = errors
        if counter == len(dataset) - 1:
            break
    return results


def save_results(
    output_dir,
    checkpoint=800_000,
    category="hydrant",
    num_images=None,
    calculate_additional_timesteps=True,
    calculate_intrinsics=True,
    split="test",
    force=False,
    sample_num=1,
    max_num_images=None,
    dataset="co3d",
):
    init_slurm_signals_if_slurm()
    os.umask(0o000)  # Default to 777 permissions
    eval_path = os.path.join(
        output_dir,
        f"eval_{dataset}",
        f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}.json",
    )
    if os.path.exists(eval_path) and not force:
        print(f"File {eval_path} already exists. Skipping.")
        return

    if num_images is not None and num_images > 8:
        # Evaluating with more views than the model was trained on: override
        # the model's image count and ignore the checkpoint's positional table.
        custom_keys = {"model.num_images": num_images}
        ignore_keys = ["pos_table"]
    else:
        custom_keys = None
        ignore_keys = []

    device = torch.device("cuda")
    model, cfg = load_model(
        output_dir,
        checkpoint=checkpoint,
        device=device,
        custom_keys=custom_keys,
        ignore_keys=ignore_keys,
    )
    if num_images is None:
        num_images = cfg.dataset.num_images

    if cfg.training.dpt_head:
        # Evaluate outputs at the same resolution as DUSt3R
        depth_size = 224
    else:
        depth_size = cfg.model.num_patches_x

    dataset = Co3dDataset(
        category=category,
        split=split,
        num_images=num_images,
        apply_augmentation=False,
        sample_num=None if split == "train" else sample_num,
        use_global_intrinsics=cfg.dataset.use_global_intrinsics,
        load_depths=True,
        center_crop=True,
        depth_size=depth_size,
        mask_holes=not cfg.training.regression,
        img_size=256 if cfg.model.unet_diffuser else 224,
    )
    print(f"Category {category}: {len(dataset)} instances")

    if calculate_additional_timesteps:
        additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
    else:
        additional_timesteps = []

    results = evaluate(
        cfg=cfg,
        model=model,
        dataset=dataset,
        num_images=num_images,
        device=device,
        calculate_intrinsics=calculate_intrinsics,
        additional_timesteps=additional_timesteps,
        max_num_images=max_num_images,
        mode="segment",
    )

    os.makedirs(os.path.dirname(eval_path), exist_ok=True)
    with open(eval_path, "w") as f:
        json.dump(results, f)
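

if __name__ == "__main__":
    # Minimal CLI sketch for launching an evaluation. The flag names below are
    # not part of the original script's interface; they simply mirror
    # save_results' keyword arguments, and output_dir must point at a trained
    # run that load_model can restore.
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a DiffusionSfM checkpoint on CO3D."
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--checkpoint", type=int, default=800_000)
    parser.add_argument("--category", type=str, default="hydrant")
    parser.add_argument("--num_images", type=int, default=None)
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    save_results(
        output_dir=args.output_dir,
        checkpoint=args.checkpoint,
        category=args.category,
        num_images=args.num_images,
        force=args.force,
    )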