import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.monoD.depth_anything.dpt import DPT_DINOv2
from models.monoD.depth_anything.util.transform import Resize


def build(config):
    """
    Build the Depth Anything model from the config.

    NOTE: the config should contain the following:
        - encoder: the encoder type of the model ('vits', 'vitb', or 'vitl')
        - load_from: the path to the pretrained checkpoint
        - localhub: whether to load the DINOv2 backbone from a local torch hub
    """
    args = config
    assert args.encoder in ['vits', 'vitb', 'vitl']
    if args.encoder == 'vits':
        depth_anything = DPT_DINOv2(encoder='vits', features=64,
                                    out_channels=[48, 96, 192, 384],
                                    localhub=args.localhub).cuda()
    elif args.encoder == 'vitb':
        depth_anything = DPT_DINOv2(encoder='vitb', features=128,
                                    out_channels=[96, 192, 384, 768],
                                    localhub=args.localhub).cuda()
    else:
        depth_anything = DPT_DINOv2(encoder='vitl', features=256,
                                    out_channels=[256, 512, 1024, 1024],
                                    localhub=args.localhub).cuda()

    depth_anything.load_state_dict(
        torch.load(args.load_from, map_location='cpu'), strict=True
    )

    total_params = sum(param.numel() for param in depth_anything.parameters())
    print('Total parameters: {:.2f}M'.format(total_params / 1e6))

    depth_anything.eval()
    return depth_anything


class DepthAnything(nn.Module):
    def __init__(self, args):
        super(DepthAnything, self).__init__()
        # build the chosen model
        self.dpAny = build(args)

    @torch.no_grad()
    def infer(self, rgbs):
        """
        Infer depth maps from the input RGB images.

        Args:
            rgbs: input RGB images, B x 3 x H x W (CUDA tensor, values in [0, 1])

        Returns:
            depth_map: relative inverse depth, B x 1 x H x W
        """
        assert rgbs.is_cuda and rgbs.dim() == 4
        B, C, H, W = rgbs.shape

        # NOTE: step 1 Resize. The Resize transform is only used to compute
        # the target size; the actual resampling is done with F.interpolate
        # so everything stays on the GPU.
        resizer = Resize(
            width=518,
            height=518,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method='lower_bound',
            image_interpolation_method=cv2.INTER_CUBIC,
        )
        # get_size expects (width, height) and returns (width, height)
        width, height = resizer.get_size(W, H)
        rgbs = F.interpolate(
            rgbs, (int(height), int(width)),
            mode='bicubic', align_corners=False
        )

        # NOTE: step 2 NormalizeImage, replicated on GPU tensors
        # (ImageNet mean/std)
        mean_ = torch.tensor([0.485, 0.456, 0.406],
                             device=rgbs.device).view(1, 3, 1, 1)
        std_ = torch.tensor([0.229, 0.224, 0.225],
                            device=rgbs.device).view(1, 3, 1, 1)
        rgbs = (rgbs - mean_) / std_

        # NOTE: step 3 PrepareForNet is a no-op here; the input is already
        # a contiguous float tensor in C x H x W layout.

        # predict inverse depth and resize it back to the input resolution
        disp = self.dpAny(rgbs)
        disp = F.interpolate(
            disp[:, None], (H, W),
            mode='bilinear', align_corners=False
        )

        # clamping the farthest depth to 100x of the nearest: disp is inverse
        # depth, so floor each sample's disparity at 1/100 of its maximum
        min_disp = disp.amax(dim=(1, 2, 3), keepdim=True) / 100.0
        depth_map = torch.maximum(disp, min_disp)
        return depth_map
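
if __name__ == '__main__':
    # Minimal usage sketch: build the model from a namespace config and run
    # inference on a dummy batch. Assumptions (not from the original module):
    # a CUDA device is available, and 'load_from' points to a valid Depth
    # Anything checkpoint -- the path below is a placeholder.
    from types import SimpleNamespace

    config = SimpleNamespace(
        encoder='vitb',                                      # 'vits' | 'vitb' | 'vitl'
        load_from='checkpoints/depth_anything_vitb14.pth',   # placeholder path
        localhub=False,                                      # fetch DINOv2 from the remote hub
    )
    model = DepthAnything(config)

    # two dummy 480x640 RGB frames with values in [0, 1]
    rgbs = torch.rand(2, 3, 480, 640, device='cuda')
    depth = model.infer(rgbs)
    print(depth.shape)  # torch.Size([2, 1, 480, 640])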