import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from collections import OrderedDict
from copy import deepcopy
from functools import partial
from typing import Optional, Tuple, List, Any
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import PretrainedConfig
from transformers.file_utils import ModelOutput
from accelerate.logging import get_logger

from dust3r.utils.misc import (
    fill_default_args,
    freeze_all_params,
    is_symmetrized,
    interleave,
    transpose_to_landscape,
)
from dust3r.heads import head_factory
from dust3r.utils.camera import PoseEncoder
from dust3r.patch_embed import get_patch_embed
import dust3r.utils.path_to_croco  # noqa: F401
from models.croco import CroCoNet, CrocoConfig  # noqa
from dust3r.blocks import (
    Block,
    DecoderBlock,
    Mlp,
    Attention,
    CrossAttention,
    DropPath,
    CustomDecoderBlock,
)  # noqa

inf = float("inf")

printer = get_logger(__name__, log_level="DEBUG")


@dataclass
class ARCroco3DStereoOutput(ModelOutput):
    """Custom output class for ARCroco3DStereo."""

    ress: Optional[List[Any]] = None
    views: Optional[List[Any]] = None


def strip_module(state_dict):
    """
    Remove the 'module.' prefix (added by DataParallel/DDP wrappers) from the
    keys of a state_dict.

    Args:
        state_dict (dict): The original state_dict, possibly with 'module.' prefixes.

    Returns:
        OrderedDict: A new state_dict with 'module.' prefixes removed.
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith("module.") else k
        new_state_dict[name] = v
    return new_state_dict


def load_model(model_path, device, verbose=True):
    if verbose:
        print("... loading model from", model_path)
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    # ManyAR_PatchEmbed is only needed when aspect ratios are inconsistent;
    # fall back to the plain DUSt3R patch embedding here.
    args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if "landscape_only" not in args:
        args = args[:-2] + ", landscape_only=False))"
    else:
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    net = eval(args)  # the checkpoint stores the constructor call as a string
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)


class ARCroco3DStereoConfig(PretrainedConfig):
    model_type = "arcroco_3d_stereo"

    def __init__(
        self,
        output_mode="pts3d",
        head_type="linear",  # or "dpt"
        depth_mode=("exp", -float("inf"), float("inf")),
        conf_mode=("exp", 1, float("inf")),
        pose_mode=("exp", -float("inf"), float("inf")),
        freeze="none",
        landscape_only=True,
        patch_embed_cls="PatchEmbedDust3R",
        ray_enc_depth=2,
        state_size=324,
        local_mem_size=256,
        state_pe="2d",
        state_dec_num_heads=16,
        depth_head=False,
        rgb_head=False,
        pose_conf_head=False,
        pose_head=False,
        **croco_kwargs,
    ):
        super().__init__()
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.pose_mode = pose_mode
        self.freeze = freeze
        self.landscape_only = landscape_only
        self.patch_embed_cls = patch_embed_cls
        self.ray_enc_depth = ray_enc_depth
        self.state_size = state_size
        self.state_pe = state_pe
        self.state_dec_num_heads = state_dec_num_heads
        self.local_mem_size = local_mem_size
        self.depth_head = depth_head
        self.rgb_head = rgb_head
        self.pose_conf_head = pose_conf_head
        self.pose_head = pose_head
        self.croco_kwargs = croco_kwargs
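
# Illustrative sketch (not part of the original file): a quick self-check of
# `strip_module` on a toy state_dict. The key names below are made up for the
# example.
def _demo_strip_module():
    wrapped = OrderedDict(
        [("module.enc_norm.weight", torch.ones(3)), ("head.bias", torch.zeros(3))]
    )
    stripped = strip_module(wrapped)
    # 'module.' is removed where present; other keys pass through unchanged.
    assert list(stripped) == ["enc_norm.weight", "head.bias"]
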
class LocalMemory(nn.Module):
    def __init__(
        self,
        size,
        k_dim,
        v_dim,
        num_heads,
        depth=2,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_mem=True,
        rope=None,
    ) -> None:
        super().__init__()
        self.v_dim = v_dim
        self.proj_q = nn.Linear(k_dim, v_dim)
        self.masked_token = nn.Parameter(
            torch.randn(1, 1, v_dim) * 0.2, requires_grad=True
        )
        self.mem = nn.Parameter(
            torch.randn(1, size, 2 * v_dim) * 0.2, requires_grad=True
        )
        self.write_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    2 * v_dim,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    attn_drop=attn_drop,
                    drop=drop,
                    drop_path=drop_path,
                    act_layer=act_layer,
                    norm_mem=norm_mem,
                    rope=rope,
                )
                for _ in range(depth)
            ]
        )
        self.read_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    2 * v_dim,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    attn_drop=attn_drop,
                    drop=drop,
                    drop_path=drop_path,
                    act_layer=act_layer,
                    norm_mem=norm_mem,
                    rope=rope,
                )
                for _ in range(depth)
            ]
        )

    def update_mem(self, mem, feat_k, feat_v):
        """
        mem:    [B, size, 2*C]
        feat_k: [B, 1, k_dim], projected to [B, 1, C]
        feat_v: [B, 1, C]
        """
        feat_k = self.proj_q(feat_k)  # [B, 1, C]
        feat = torch.cat([feat_k, feat_v], dim=-1)
        for blk in self.write_blocks:
            mem, _ = blk(mem, feat, None, None)
        return mem

    def inquire(self, query, mem):
        x = self.proj_q(query)  # [B, 1, C]
        x = torch.cat([x, self.masked_token.expand(x.shape[0], -1, -1)], dim=-1)
        for blk in self.read_blocks:
            x, _ = blk(x, mem, None, None)
        return x[..., -self.v_dim :]
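
# Illustrative sketch (not part of the original file): a write/read round trip
# through LocalMemory with tiny, made-up dimensions. The memory is seeded from
# the learned `mem` parameter, written with one (key, value) pair, then queried.
def _demo_local_memory():
    torch.manual_seed(0)
    retriever = LocalMemory(size=8, k_dim=16, v_dim=32, num_heads=4)
    mem = retriever.mem.expand(2, -1, -1)  # [B=2, size=8, 2*v_dim=64]
    feat_k = torch.randn(2, 1, 16)  # image-level key
    feat_v = torch.randn(2, 1, 32)  # pose-token value
    mem = retriever.update_mem(mem, feat_k, feat_v)
    out = retriever.inquire(torch.randn(2, 1, 16), mem)
    # The read path returns only the value half of the concatenated tokens.
    assert out.shape == (2, 1, 32)
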
class ARCroco3DStereo(CroCoNet):
    config_class = ARCroco3DStereoConfig
    base_model_prefix = "arcroco3dstereo"
    supports_gradient_checkpointing = True

    def __init__(self, config: ARCroco3DStereoConfig):
        self.gradient_checkpointing = False
        self.fixed_input_length = True
        config.croco_kwargs = fill_default_args(
            config.croco_kwargs, CrocoConfig.__init__
        )
        self.config = config
        self.patch_embed_cls = config.patch_embed_cls
        self.croco_args = config.croco_kwargs
        croco_cfg = CrocoConfig(**self.croco_args)
        super().__init__(croco_cfg)
        self.enc_blocks_ray_map = nn.ModuleList(
            [
                Block(
                    self.enc_embed_dim,
                    16,
                    4,
                    qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6),
                    rope=self.rope,
                )
                for _ in range(config.ray_enc_depth)
            ]
        )
        self.enc_norm_ray_map = nn.LayerNorm(self.enc_embed_dim, eps=1e-6)
        self.dec_num_heads = self.croco_args["dec_num_heads"]
        self.pose_head_flag = config.pose_head
        if self.pose_head_flag:
            self.pose_token = nn.Parameter(
                torch.randn(1, 1, self.dec_embed_dim) * 0.02, requires_grad=True
            )
            self.pose_retriever = LocalMemory(
                size=config.local_mem_size,
                k_dim=self.enc_embed_dim,
                v_dim=self.dec_embed_dim,
                num_heads=self.dec_num_heads,
                mlp_ratio=4,
                qkv_bias=True,
                attn_drop=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                rope=None,
            )
        self.register_tokens = nn.Embedding(config.state_size, self.enc_embed_dim)
        self.state_size = config.state_size
        self.state_pe = config.state_pe
        self.masked_img_token = nn.Parameter(
            torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
        )
        self.masked_ray_map_token = nn.Parameter(
            torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
        )
        self._set_state_decoder(
            self.enc_embed_dim,
            self.dec_embed_dim,
            config.state_dec_num_heads,
            self.dec_depth,
            self.croco_args.get("mlp_ratio", None),
            self.croco_args.get("norm_layer", None),
            self.croco_args.get("norm_im2_in_dec", None),
        )
        self.set_downstream_head(
            config.output_mode,
            config.head_type,
            config.landscape_only,
            config.depth_mode,
            config.conf_mode,
            config.pose_mode,
            config.depth_head,
            config.rgb_head,
            config.pose_conf_head,
            config.pose_head,
            **self.croco_args,
        )
        self.set_freeze(config.freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device="cpu")
        else:
            try:
                model = super(ARCroco3DStereo, cls).from_pretrained(
                    pretrained_model_name_or_path, **kw
                )
            except TypeError as e:
                raise Exception(
                    f"tried to load {pretrained_model_name_or_path} "
                    f"from huggingface, but failed"
                ) from e
            return model

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        self.patch_embed = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3
        )
        self.patch_embed_ray_map = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=6
        )

    def _set_decoder(
        self,
        enc_embed_dim,
        dec_embed_dim,
        dec_num_heads,
        dec_depth,
        mlp_ratio,
        norm_layer,
        norm_im2_in_dec,
    ):
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    dec_embed_dim,
                    dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    norm_mem=norm_im2_in_dec,
                    rope=self.rope,
                )
                for _ in range(dec_depth)
            ]
        )
        self.dec_norm = norm_layer(dec_embed_dim)

    def _set_state_decoder(
        self,
        enc_embed_dim,
        dec_embed_dim,
        dec_num_heads,
        dec_depth,
        mlp_ratio,
        norm_layer,
        norm_im2_in_dec,
    ):
        self.dec_depth_state = dec_depth
        self.dec_embed_dim_state = dec_embed_dim
        self.decoder_embed_state = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks_state = nn.ModuleList(
            [
                DecoderBlock(
                    dec_embed_dim,
                    dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    norm_mem=norm_im2_in_dec,
                    rope=self.rope,
                )
                for _ in range(dec_depth)
            ]
        )
        self.dec_norm_state = norm_layer(dec_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        if all(k.startswith("module") for k in ckpt):
            ckpt = strip_module(ckpt)
        new_ckpt = dict(ckpt)
        # If the checkpoint predates the state decoder, seed it with the
        # image-decoder weights.
        if not any(k.startswith("dec_blocks_state") for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith("dec_blocks"):
                    new_ckpt[key.replace("dec_blocks", "dec_blocks_state")] = value
        try:
            return super().load_state_dict(new_ckpt, **kw)
        except Exception:
            try:
                new_new_ckpt = {
                    k: v
                    for k, v in new_ckpt.items()
                    if not k.startswith("dec_blocks")
                    and not k.startswith("dec_norm")
                    and not k.startswith("decoder_embed")
                }
                return super().load_state_dict(new_new_ckpt, **kw)
            except Exception:
                new_new_ckpt = {}
                for key in new_ckpt:
                    if key in self.state_dict():
                        if new_ckpt[key].size() == self.state_dict()[key].size():
                            new_new_ckpt[key] = new_ckpt[key]
                        else:
                            printer.info(
                                f"Skipping '{key}': size mismatch "
                                f"(ckpt: {new_ckpt[key].size()}, "
                                f"model: {self.state_dict()[key].size()})"
                            )
                    else:
                        printer.info(f"Skipping '{key}': not found in model")
                return super().load_state_dict(new_new_ckpt, **kw)
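
    # Illustrative note (added comment, not original code): with the key
    # duplication above, a legacy checkpoint entry such as
    #   "dec_blocks.0.attn.qkv.weight"
    # is copied to
    #   "dec_blocks_state.0.attn.qkv.weight"
    # so the newly added state decoder starts from the image decoder's weights
    # instead of random initialization.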
    def set_freeze(self, freeze):  # this is for use by downstream models
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "mask": [self.mask_token] if hasattr(self, "mask_token") else [],
            "encoder": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
            ],
            "encoder_and_head": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
                self.downstream_head,
            ],
            "encoder_and_decoder": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
                self.dec_blocks,
                self.dec_blocks_state,
                self.pose_retriever,
                self.pose_token,
                self.register_tokens,
                self.decoder_embed_state,
                self.decoder_embed,
                self.dec_norm,
                self.dec_norm_state,
            ],
            "decoder": [
                self.dec_blocks,
                self.dec_blocks_state,
                self.pose_retriever,
                self.pose_token,
            ],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head"""
        return

    def set_downstream_head(
        self,
        output_mode,
        head_type,
        landscape_only,
        depth_mode,
        conf_mode,
        pose_mode,
        depth_head,
        rgb_head,
        pose_conf_head,
        pose_head,
        patch_size,
        img_size,
        **kw,
    ):
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f"{img_size=} must be multiple of {patch_size=}"
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.pose_mode = pose_mode
        self.downstream_head = head_factory(
            head_type,
            output_mode,
            self,
            has_conf=bool(conf_mode),
            has_depth=bool(depth_head),
            has_rgb=bool(rgb_head),
            has_pose_conf=bool(pose_conf_head),
            has_pose=bool(pose_head),
        )
        self.head = transpose_to_landscape(
            self.downstream_head, activate=landscape_only
        )

    def _encode_image(self, image, true_shape):
        x, pos = self.patch_embed(image, true_shape=true_shape)
        assert self.enc_pos_embed is None
        for blk in self.enc_blocks:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(blk, x, pos, use_reentrant=False)
            else:
                x = blk(x, pos)
        x = self.enc_norm(x)
        return [x], pos, None

    def _encode_ray_map(self, ray_map, true_shape):
        x, pos = self.patch_embed_ray_map(ray_map, true_shape=true_shape)
        assert self.enc_pos_embed is None
        for blk in self.enc_blocks_ray_map:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(blk, x, pos, use_reentrant=False)
            else:
                x = blk(x, pos)
        x = self.enc_norm_ray_map(x)
        return [x], pos, None

    def _encode_state(self, image_tokens, image_pos):
        batch_size = image_tokens.shape[0]
        state_feat = self.register_tokens(
            torch.arange(self.state_size, device=image_pos.device)
        )
        if self.state_pe == "1d":
            state_pos = (
                torch.tensor(
                    [[i, i] for i in range(self.state_size)],
                    dtype=image_pos.dtype,
                    device=image_pos.device,
                )[None]
                .expand(batch_size, -1, -1)
                .contiguous()
            )
        elif self.state_pe == "2d":
            width = int(self.state_size**0.5)
            width = width + 1 if width % 2 == 1 else width
            state_pos = (
                torch.tensor(
                    [[i // width, i % width] for i in range(self.state_size)],
                    dtype=image_pos.dtype,
                    device=image_pos.device,
                )[None]
                .expand(batch_size, -1, -1)
                .contiguous()
            )
        elif self.state_pe == "none":
            state_pos = None
        state_feat = state_feat[None].expand(batch_size, -1, -1)
        return state_feat, state_pos, None
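
    # Illustrative note (added comment, not original code): with the default
    # state_size=324, the "2d" branch lays the register tokens on an even-width
    # grid: width = int(sqrt(324)) = 18 (already even), so token i gets the
    # 2-D position (i // 18, i % 18), e.g. token 0 -> (0, 0), token 19 -> (1, 1).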
    def _encode_views(self, views, img_mask=None, ray_mask=None):
        device = views[0]["img"].device
        batch_size = views[0]["img"].shape[0]
        given = True
        if img_mask is None and ray_mask is None:
            given = False
        if not given:
            img_mask = torch.stack(
                [view["img_mask"] for view in views], dim=0
            )  # Shape: (num_views, batch_size)
            ray_mask = torch.stack(
                [view["ray_mask"] for view in views], dim=0
            )  # Shape: (num_views, batch_size)
        imgs = torch.stack(
            [view["img"] for view in views], dim=0
        )  # Shape: (num_views, batch_size, C, H, W)
        ray_maps = torch.stack(
            [view["ray_map"] for view in views], dim=0
        )  # Shape: (num_views, batch_size, H, W, C)
        shapes = []
        for view in views:
            if "true_shape" in view:
                shapes.append(view["true_shape"])
            else:
                shape = torch.tensor(view["img"].shape[-2:], device=device)
                shapes.append(shape.unsqueeze(0).repeat(batch_size, 1))
        shapes = torch.stack(shapes, dim=0).to(
            imgs.device
        )  # Shape: (num_views, batch_size, 2)
        imgs = imgs.view(
            -1, *imgs.shape[2:]
        )  # Shape: (num_views * batch_size, C, H, W)
        ray_maps = ray_maps.view(
            -1, *ray_maps.shape[2:]
        )  # Shape: (num_views * batch_size, H, W, C)
        shapes = shapes.view(-1, 2)  # Shape: (num_views * batch_size, 2)
        img_masks_flat = img_mask.view(-1)  # Shape: (num_views * batch_size,)
        ray_masks_flat = ray_mask.view(-1)
        selected_imgs = imgs[img_masks_flat]
        selected_shapes = shapes[img_masks_flat]
        if selected_imgs.size(0) > 0:
            img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
        else:
            raise NotImplementedError
        full_out = [
            torch.zeros(
                len(views) * batch_size, *img_out[0].shape[1:], device=img_out[0].device
            )
            for _ in range(len(img_out))
        ]
        full_pos = torch.zeros(
            len(views) * batch_size,
            *img_pos.shape[1:],
            device=img_pos.device,
            dtype=img_pos.dtype,
        )
        for i in range(len(img_out)):
            full_out[i][img_masks_flat] += img_out[i]
            full_out[i][~img_masks_flat] += self.masked_img_token
        full_pos[img_masks_flat] += img_pos
        ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
        selected_ray_maps = ray_maps[ray_masks_flat]
        selected_shapes_ray = shapes[ray_masks_flat]
        if selected_ray_maps.size(0) > 0:
            ray_out, ray_pos, _ = self._encode_ray_map(
                selected_ray_maps, selected_shapes_ray
            )
            assert len(ray_out) == len(full_out), f"{len(ray_out)}, {len(full_out)}"
            for i in range(len(ray_out)):
                full_out[i][ray_masks_flat] += ray_out[i]
                full_out[i][~ray_masks_flat] += self.masked_ray_map_token
            full_pos[ray_masks_flat] += (
                ray_pos * (~img_masks_flat[ray_masks_flat][:, None, None]).long()
            )
        else:
            # No ray maps in this batch: run a zero-weighted dummy pass so the
            # ray-map branch still participates in the graph (e.g. for DDP).
            raymaps = torch.zeros(
                1, 6, imgs[0].shape[-2], imgs[0].shape[-1], device=img_out[0].device
            )
            ray_mask_flat = torch.zeros_like(img_masks_flat)
            ray_mask_flat[:1] = True
            ray_out, ray_pos, _ = self._encode_ray_map(raymaps, shapes[ray_mask_flat])
            for i in range(len(ray_out)):
                full_out[i][ray_mask_flat] += ray_out[i] * 0.0
                full_out[i][~ray_mask_flat] += self.masked_ray_map_token * 0.0
        return (
            shapes.chunk(len(views), dim=0),
            [out.chunk(len(views), dim=0) for out in full_out],
            full_pos.chunk(len(views), dim=0),
        )

    def _decoder(self, f_state, pos_state, f_img, pos_img, f_pose, pos_pose):
        final_output = [(f_state, f_img)]  # before projection
        assert f_state.shape[-1] == self.dec_embed_dim
        f_img = self.decoder_embed(f_img)
        if self.pose_head_flag:
            assert f_pose is not None and pos_pose is not None
            f_img = torch.cat([f_pose, f_img], dim=1)
            pos_img = torch.cat([pos_pose, pos_img], dim=1)
        final_output.append((f_state, f_img))
        for blk_state, blk_img in zip(self.dec_blocks_state, self.dec_blocks):
            if (
                self.gradient_checkpointing
                and self.training
                and torch.is_grad_enabled()
            ):
                f_state, _ = checkpoint(
                    blk_state,
                    *final_output[-1][::+1],
                    pos_state,
                    pos_img,
                    use_reentrant=not self.fixed_input_length,
                )
                f_img, _ = checkpoint(
                    blk_img,
                    *final_output[-1][::-1],
                    pos_img,
                    pos_state,
                    use_reentrant=not self.fixed_input_length,
                )
            else:
                f_state, _ = blk_state(*final_output[-1][::+1], pos_state, pos_img)
                f_img, _ = blk_img(*final_output[-1][::-1], pos_img, pos_state)
            final_output.append((f_state, f_img))
        del final_output[1]  # duplicate of final_output[0]
        final_output[-1] = (
            self.dec_norm_state(final_output[-1][0]),
            self.dec_norm(final_output[-1][1]),
        )
        return zip(*final_output)

    def _downstream_head(self, decout, img_shape, **kwargs):
        return self.head(decout, img_shape, **kwargs)
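
    # Illustrative note (added comment, not original code): each decoder step
    # runs two cross-attending blocks over the same pair of token sets,
    #   f_state <- blk_state(f_state, f_img, ...)   # state attends to image tokens
    #   f_img   <- blk_img(f_img, f_state, ...)     # image attends to state tokens
    # which is what the [::+1] / [::-1] orderings over final_output[-1] select.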
    def _init_state(self, image_tokens, image_pos):
        """
        Current version: use the first frame's image features and pose to
        initialize the state features and state positions.
        """
        state_feat, state_pos, _ = self._encode_state(image_tokens, image_pos)
        state_feat = self.decoder_embed_state(state_feat)
        return state_feat, state_pos

    def _recurrent_rollout(
        self,
        state_feat,
        state_pos,
        current_feat,
        current_pos,
        pose_feat,
        pose_pos,
        init_state_feat,
        img_mask=None,
        reset_mask=None,
        update=None,
    ):
        new_state_feat, dec = self._decoder(
            state_feat, state_pos, current_feat, current_pos, pose_feat, pose_pos
        )
        new_state_feat = new_state_feat[-1]
        return new_state_feat, dec

    def _get_img_level_feat(self, feat):
        return torch.mean(feat, dim=1, keepdim=True)

    def _forward_encoder(self, views):
        shape, feat_ls, pos = self._encode_views(views)
        feat = feat_ls[-1]
        state_feat, state_pos = self._init_state(feat[0], pos[0])
        mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
        init_state_feat = state_feat.clone()
        init_mem = mem.clone()
        return (feat, pos, shape), (
            init_state_feat,
            init_mem,
            state_feat,
            state_pos,
            mem,
        )

    def _forward_decoder_step(
        self,
        views,
        i,
        feat_i,
        pos_i,
        shape_i,
        init_state_feat,
        init_mem,
        state_feat,
        state_pos,
        mem,
    ):
        if self.pose_head_flag:
            global_img_feat_i = self._get_img_level_feat(feat_i)
            if i == 0:
                pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
            else:
                pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
            pose_pos_i = -torch.ones(
                feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
            )
        else:
            pose_feat_i = None
            pose_pos_i = None
        new_state_feat, dec = self._recurrent_rollout(
            state_feat,
            state_pos,
            feat_i,
            pos_i,
            pose_feat_i,
            pose_pos_i,
            init_state_feat,
            img_mask=views[i]["img_mask"],
            reset_mask=views[i]["reset"],
            update=views[i].get("update", None),
        )
        out_pose_feat_i = dec[-1][:, 0:1]
        new_mem = self.pose_retriever.update_mem(
            mem, global_img_feat_i, out_pose_feat_i
        )
        head_input = [
            dec[0].float(),
            dec[self.dec_depth * 2 // 4][:, 1:].float(),
            dec[self.dec_depth * 3 // 4][:, 1:].float(),
            dec[self.dec_depth].float(),
        ]
        res = self._downstream_head(head_input, shape_i, pos=pos_i)
        img_mask = views[i]["img_mask"]
        update = views[i].get("update", None)
        if update is not None:
            update_mask = img_mask & update
        else:
            update_mask = img_mask  # if update is unset, fall back to img_mask
        update_mask = update_mask[:, None, None].float()
        state_feat = new_state_feat * update_mask + state_feat * (
            1 - update_mask
        )  # update the global state
        mem = new_mem * update_mask + mem * (1 - update_mask)  # update the local memory
        reset_mask = views[i]["reset"]
        if reset_mask is not None:
            reset_mask = reset_mask[:, None, None].float()
            state_feat = init_state_feat * reset_mask + state_feat * (1 - reset_mask)
            mem = init_mem * reset_mask + mem * (1 - reset_mask)
        return res, (state_feat, mem)
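
    # Illustrative note (added comment, not original code): the state/memory
    # updates above are per-sample convex gates. With update_mask broadcast to
    # [B, 1, 1], a sample with mask 1.0 takes the new state wholesale and a
    # sample with mask 0.0 keeps its old state:
    #   state_feat = 1.0 * new_state_feat + 0.0 * state_feat   # updated sample
    #   state_feat = 0.0 * new_state_feat + 1.0 * state_feat   # frozen sample
    # reset_mask then overwrites selected samples with the initial state.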
    def _forward_impl(self, views, ret_state=False):
        shape, feat_ls, pos = self._encode_views(views)
        feat = feat_ls[-1]
        state_feat, state_pos = self._init_state(feat[0], pos[0])
        mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
        init_state_feat = state_feat.clone()
        init_mem = mem.clone()
        all_state_args = [(state_feat, state_pos, init_state_feat, mem, init_mem)]
        ress = []
        for i in range(len(views)):
            feat_i = feat[i]
            pos_i = pos[i]
            if self.pose_head_flag:
                global_img_feat_i = self._get_img_level_feat(feat_i)
                if i == 0:
                    pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
                else:
                    pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
                pose_pos_i = -torch.ones(
                    feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
                )
            else:
                pose_feat_i = None
                pose_pos_i = None
            new_state_feat, dec = self._recurrent_rollout(
                state_feat,
                state_pos,
                feat_i,
                pos_i,
                pose_feat_i,
                pose_pos_i,
                init_state_feat,
                img_mask=views[i]["img_mask"],
                reset_mask=views[i]["reset"],
                update=views[i].get("update", None),
            )
            out_pose_feat_i = dec[-1][:, 0:1]
            new_mem = self.pose_retriever.update_mem(
                mem, global_img_feat_i, out_pose_feat_i
            )
            assert len(dec) == self.dec_depth + 1
            head_input = [
                dec[0].float(),
                dec[self.dec_depth * 2 // 4][:, 1:].float(),
                dec[self.dec_depth * 3 // 4][:, 1:].float(),
                dec[self.dec_depth].float(),
            ]
            res = self._downstream_head(head_input, shape[i], pos=pos_i)
            ress.append(res)
            img_mask = views[i]["img_mask"]
            update = views[i].get("update", None)
            if update is not None:
                update_mask = img_mask & update
            else:
                update_mask = img_mask  # if update is unset, fall back to img_mask
            update_mask = update_mask[:, None, None].float()
            state_feat = new_state_feat * update_mask + state_feat * (
                1 - update_mask
            )  # update the global state
            mem = new_mem * update_mask + mem * (
                1 - update_mask
            )  # update the local memory
            reset_mask = views[i]["reset"]
            if reset_mask is not None:
                reset_mask = reset_mask[:, None, None].float()
                state_feat = init_state_feat * reset_mask + state_feat * (
                    1 - reset_mask
                )
                mem = init_mem * reset_mask + mem * (1 - reset_mask)
            all_state_args.append(
                (state_feat, state_pos, init_state_feat, mem, init_mem)
            )
        if ret_state:
            return ress, views, all_state_args
        return ress, views

    def forward(self, views, ret_state=False):
        if ret_state:
            ress, views, state_args = self._forward_impl(views, ret_state=ret_state)
            return ARCroco3DStereoOutput(ress=ress, views=views), state_args
        else:
            ress, views = self._forward_impl(views, ret_state=ret_state)
            return ARCroco3DStereoOutput(ress=ress, views=views)

    def inference_step(
        self, view, state_feat, state_pos, init_state_feat, mem, init_mem
    ):
        batch_size = view["img"].shape[0]
        raymaps = []
        shapes = []
        for j in range(batch_size):
            assert view["ray_mask"][j]
            raymap = view["ray_map"][[j]].permute(0, 3, 1, 2)
            raymaps.append(raymap)
            shapes.append(
                view.get(
                    "true_shape",
                    torch.tensor(view["ray_map"].shape[-2:])[None].repeat(
                        view["ray_map"].shape[0], 1
                    ),
                )[[j]]
            )
        raymaps = torch.cat(raymaps, dim=0)
        shape = torch.cat(shapes, dim=0).to(raymaps.device)
        # Pass the concatenated shape tensor, not the Python list `shapes`.
        feat_ls, pos, _ = self._encode_ray_map(raymaps, shape)
        feat_i = feat_ls[-1]
        pos_i = pos
        if self.pose_head_flag:
            global_img_feat_i = self._get_img_level_feat(feat_i)
            pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
            pose_pos_i = -torch.ones(
                feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
            )
        else:
            pose_feat_i = None
            pose_pos_i = None
        new_state_feat, dec = self._recurrent_rollout(
            state_feat,
            state_pos,
            feat_i,
            pos_i,
            pose_feat_i,
            pose_pos_i,
            init_state_feat,
            img_mask=view["img_mask"],
            reset_mask=view["reset"],
            update=view.get("update", None),
        )
        out_pose_feat_i = dec[-1][:, 0:1]
        new_mem = self.pose_retriever.update_mem(
            mem, global_img_feat_i, out_pose_feat_i
        )
        assert len(dec) == self.dec_depth + 1
        head_input = [
            dec[0].float(),
            dec[self.dec_depth * 2 // 4][:, 1:].float(),
            dec[self.dec_depth * 3 // 4][:, 1:].float(),
            dec[self.dec_depth].float(),
        ]
        res = self._downstream_head(head_input, shape, pos=pos_i)
        return res, view
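
    # Illustrative note (added comment, not original code): `dec` holds the
    # pre-projection image tokens plus one entry per decoder block, so with
    # dec_depth=12 the head consumes taps at indices 0, 6, 9, and 12
    # (0, 12*2//4, 12*3//4, 12); the [:, 1:] slices drop the pose token from
    # the intermediate taps.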
    def forward_recurrent(self, views, device, ret_state=False):
        ress = []
        all_state_args = []
        for i, view in enumerate(views):
            device = view["img"].device
            batch_size = view["img"].shape[0]
            img_mask = view["img_mask"].reshape(
                -1, batch_size
            )  # Shape: (1, batch_size)
            ray_mask = view["ray_mask"].reshape(
                -1, batch_size
            )  # Shape: (1, batch_size)
            imgs = view["img"].unsqueeze(0)  # Shape: (1, batch_size, C, H, W)
            ray_maps = view["ray_map"].unsqueeze(
                0
            )  # Shape: (1, batch_size, H, W, C)
            shapes = (
                view["true_shape"].unsqueeze(0)
                if "true_shape" in view
                else torch.tensor(view["img"].shape[-2:], device=device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
                .unsqueeze(0)
            )  # Shape: (1, batch_size, 2)
            imgs = imgs.view(-1, *imgs.shape[2:])  # Shape: (batch_size, C, H, W)
            ray_maps = ray_maps.view(
                -1, *ray_maps.shape[2:]
            )  # Shape: (batch_size, H, W, C)
            shapes = shapes.view(-1, 2).to(imgs.device)  # Shape: (batch_size, 2)
            img_masks_flat = img_mask.view(-1)  # Shape: (batch_size,)
            ray_masks_flat = ray_mask.view(-1)
            selected_imgs = imgs[img_masks_flat]
            selected_shapes = shapes[img_masks_flat]
            if selected_imgs.size(0) > 0:
                img_out, img_pos, _ = self._encode_image(
                    selected_imgs, selected_shapes
                )
            else:
                img_out, img_pos = None, None
            ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
            selected_ray_maps = ray_maps[ray_masks_flat]
            selected_shapes_ray = shapes[ray_masks_flat]
            if selected_ray_maps.size(0) > 0:
                ray_out, ray_pos, _ = self._encode_ray_map(
                    selected_ray_maps, selected_shapes_ray
                )
            else:
                ray_out, ray_pos = None, None
            shape = shapes
            if img_out is not None and ray_out is None:
                feat_i = img_out[-1]
                pos_i = img_pos
            elif img_out is None and ray_out is not None:
                feat_i = ray_out[-1]
                pos_i = ray_pos
            elif img_out is not None and ray_out is not None:
                feat_i = img_out[-1] + ray_out[-1]
                pos_i = img_pos
            else:
                raise NotImplementedError
            if i == 0:
                state_feat, state_pos = self._init_state(feat_i, pos_i)
                mem = self.pose_retriever.mem.expand(feat_i.shape[0], -1, -1)
                init_state_feat = state_feat.clone()
                init_mem = mem.clone()
                all_state_args.append(
                    (state_feat, state_pos, init_state_feat, mem, init_mem)
                )
            if self.pose_head_flag:
                global_img_feat_i = self._get_img_level_feat(feat_i)
                if i == 0:
                    pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
                else:
                    pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
                pose_pos_i = -torch.ones(
                    feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
                )
            else:
                pose_feat_i = None
                pose_pos_i = None
            new_state_feat, dec = self._recurrent_rollout(
                state_feat,
                state_pos,
                feat_i,
                pos_i,
                pose_feat_i,
                pose_pos_i,
                init_state_feat,
                img_mask=view["img_mask"],
                reset_mask=view["reset"],
                update=view.get("update", None),
            )
            out_pose_feat_i = dec[-1][:, 0:1]
            new_mem = self.pose_retriever.update_mem(
                mem, global_img_feat_i, out_pose_feat_i
            )
            assert len(dec) == self.dec_depth + 1
            head_input = [
                dec[0].float(),
                dec[self.dec_depth * 2 // 4][:, 1:].float(),
                dec[self.dec_depth * 3 // 4][:, 1:].float(),
                dec[self.dec_depth].float(),
            ]
            res = self._downstream_head(head_input, shape, pos=pos_i)
            ress.append(res)
            img_mask = view["img_mask"]
            update = view.get("update", None)
            if update is not None:
                update_mask = img_mask & update
            else:
                update_mask = img_mask  # if update is unset, fall back to img_mask
            update_mask = update_mask[:, None, None].float()
            state_feat = new_state_feat * update_mask + state_feat * (
                1 - update_mask
            )  # update the global state
            mem = new_mem * update_mask + mem * (
                1 - update_mask
            )  # update the local memory
            reset_mask = view["reset"]
            if reset_mask is not None:
                reset_mask = reset_mask[:, None, None].float()
                state_feat = init_state_feat * reset_mask + state_feat * (
                    1 - reset_mask
                )
                mem = init_mem * reset_mask + mem * (1 - reset_mask)
            all_state_args.append(
                (state_feat, state_pos, init_state_feat, mem, init_mem)
            )
        if ret_state:
            return ress, views, all_state_args
        return ress, views
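
# Illustrative sketch (not part of the original file): the per-view dict layout
# that `forward` / `forward_recurrent` consume, built with dummy tensors. Field
# names follow their usage in this file; the sizes here are arbitrary.
def _demo_dummy_views(num_views=2, batch_size=1, H=224, W=224):
    views = []
    for _ in range(num_views):
        views.append(
            {
                "img": torch.randn(batch_size, 3, H, W),
                "ray_map": torch.randn(batch_size, H, W, 6),
                "true_shape": torch.tensor([[H, W]]).repeat(batch_size, 1),
                "img_mask": torch.ones(batch_size, dtype=torch.bool),
                "ray_mask": torch.zeros(batch_size, dtype=torch.bool),
                "reset": torch.zeros(batch_size, dtype=torch.bool),
                "update": torch.ones(batch_size, dtype=torch.bool),
            }
        )
    return views
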
"__main__": print(ARCroco3DStereo.mro()) cfg = ARCroco3DStereoConfig( state_size=256, pos_embed="RoPE100", rgb_head=True, pose_head=True, img_size=(224, 224), head_type="linear", output_mode="pts3d+pose", depth_mode=("exp", -inf, inf), conf_mode=("exp", 1, inf), pose_mode=("exp", -inf, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, ) ARCroco3DStereo(cfg)