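# NOTE: the next three lines are non-code page text that preceded the module
# (appears to be an uploader / commit banner); kept as comments so the file
# parses as Python.
# roll-ai's picture
# Upload 177 files
# 59d751c verified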
from transformers import pipeline
from PIL import Image
import requests
import torchvision
import os
from .camera.WarperPytorch import Warper
import numpy as np
from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from .depth_anything_v2.dpt import DepthAnythingV2
import pdb
def to_pil_image(x):
    """Convert an image tensor to a PIL image.

    Args:
        x: tensor laid out as (c, h, w) with values nominally in [-1, 1]
            (per the original author's note).

    Returns:
        The ``PIL.Image`` built from the (h, w, c) uint8 array.
    """
    # Map [-1, 1] onto [0, 255]. Clamp before the uint8 cast: values that
    # overshoot the nominal range (e.g. 1.2 -> 280.5) are out of range for
    # uint8 and would otherwise wrap / produce platform-dependent pixels.
    scaled = ((x + 1) / 2 * 255).clamp(0, 255)
    # Channels-last layout is what Image.fromarray receives here.
    x_np = scaled.permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    return Image.fromarray(x_np)
def to_npy(x):
    """Return ``x`` rescaled by ``(x + 1) / 2 * 255`` as a channels-last NumPy array.

    The first tensor dimension is moved to the end (``permute(1, 2, 0)``), the
    result is detached and copied to the CPU. No clamping or integer cast is
    applied, so the array keeps the tensor's floating dtype.
    """
    rescaled = (x + 1) / 2 * 255
    channels_last = rescaled.permute(1, 2, 0)
    return channels_last.detach().cpu().numpy()
def unnormalize_intrinsic(x, size):
    """Return a detached copy of ``x`` with two slices along dim 1 rescaled.

    Index 0 along dim 1 is multiplied by the width and index 1 by the height;
    everything else is copied unchanged. The input tensor is not modified.

    Args:
        x: tensor with at least two dimensions; presumably a batch of
            normalized 3x3 camera intrinsics -- TODO confirm against callers.
        size: ``(h, w)`` pair used as the scale factors.

    Returns:
        A new tensor, detached from the autograd graph, with the same shape as ``x``.
    """
    height, width = size
    detached = x.detach()
    scaled = detached.clone()
    # The products are fresh tensors, so the source slices need no extra copy.
    scaled[:, 0:1] = detached[:, 0:1] * width
    scaled[:, 1:2] = detached[:, 1:2] * height
    return scaled
class DepthWarping_wrapper(nn.Module):
    """Depth-based warping between a context camera and target cameras.

    Holds a ``DepthAnythingV2`` network (weights loaded from a checkpoint and
    switched to eval mode) and a ``Warper``. ``forward`` predicts depth for the
    context image, warps it to the target views and then warps the result back.
    """

    def __init__(self,
                 model_config,
                 ckpt_path,):
        """Build the depth network and the warper.

        Args:
            model_config: mapping expanded as keyword arguments into
                ``DepthAnythingV2``.
            ckpt_path: path handed to ``torch.load``; the loaded object is
                used as the depth network's state dict.
        """
        super().__init__()
        depth_model = DepthAnythingV2(**model_config)
        # NOTE(review): torch.load unpickles the file -- only use trusted
        # checkpoints (consider weights_only=True where the torch version
        # supports it).
        state_dict = torch.load(ckpt_path, map_location='cpu')
        depth_model.load_state_dict(state_dict)
        self.depth_model = depth_model.eval()
        self.warper = Warper()

    def get_input(self, batch):
        """Gather the tensors that ``forward`` feeds to the warper.

        Reads ``batch["context"]`` (``image``, ``extrinsics``, ``intrinsics``),
        ``batch["target"]`` (``extrinsics``, ``intrinsics``) and
        ``batch['variable_intrinsic']``.

        Returns:
            Tuple ``(image_ctxt, c2w_ctxt, c2w_trgt, intrinsics_ctxt,
            intrinsics_trgt, depth_ctxt, variable_intrinsic)``. Context
            tensors are repeated once per target view; ``depth_ctxt`` is not
            repeated here.
        """
        target = batch["target"]
        _, num_views = target["intrinsics"].shape[:2]
        context = batch["context"]
        height, width = context["image"].shape[-2:]
        image_size = (height, width)

        # Rescale the context image (maps [0, 1] onto [-1, 1]; the input
        # range itself is not visible here -- TODO confirm).
        image = context["image"] * 2 - 1
        image_ctxt = repeat(image, "b c h w -> (b v) c h w", v=num_views)

        # Original author's note: the context extrinsics are an identity
        # matrix, so no inverse is applied. Target extrinsics are likewise
        # used as given.
        c2w_ctxt = repeat(context["extrinsics"], "b t h w -> (b v t) h w", v=num_views)
        c2w_trgt = rearrange(target["extrinsics"], "b t h w -> (b t) h w")

        # Intrinsics are rescaled with the image height and width.
        intrinsics_ctxt = unnormalize_intrinsic(
            repeat(context["intrinsics"], "b t h w -> (b v t) h w", v=num_views),
            size=image_size,
        )
        intrinsics_trgt = unnormalize_intrinsic(
            rearrange(target["intrinsics"], "b t h w -> (b t) h w"),
            size=image_size,
        )

        # One depth prediction per (un-repeated) context image; a singleton
        # dim is inserted at position 1 (presumably giving B 1 H W, per the
        # original author's note).
        per_image_depth = [self.depth_model.infer_image(to_npy(img)) for img in image]
        depth_ctxt = torch.stack(per_image_depth, dim=0).to(image.device).unsqueeze(1)

        return (
            image_ctxt,
            c2w_ctxt,
            c2w_trgt,
            intrinsics_ctxt,
            intrinsics_trgt,
            depth_ctxt,
            batch['variable_intrinsic'],
        )

    def forward(self, batch):
        """Warp the context image to the target views and back again.

        Returns:
            Tuple ``(flow_f, flow_b, warped_trgt, depth_ctxt,
            warped_depth_trgt)``: the fourth output of the context->target
            warp, the fourth output of the target->context warp, the warped
            image and warped depth from the first warp, and the un-repeated
            context depth.
        """
        (image_ctxt, c2w_ctxt, c2w_trgt, intrinsics_ctxt, intrinsics_trgt,
         depth_ctxt, variable_intrinsic) = self.get_input(batch)

        # Both warps run with CUDA autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            _, num_views = batch["target"]["intrinsics"].shape[:2]
            # Repeat the context depth once per target view.
            depth_per_view = repeat(depth_ctxt, "b c h w -> (b t) c h w", t=num_views)
            # Target intrinsics are only passed when the batch flags them as variable.
            forward_intrinsic2 = intrinsics_trgt if variable_intrinsic else None

            # Context -> target.
            warped_trgt, _, warped_depth_trgt, flow_f = self.warper.forward_warp(
                frame1=image_ctxt,
                mask1=None,
                depth1=depth_per_view,
                transformation1=c2w_ctxt,
                transformation2=c2w_trgt,
                intrinsic1=intrinsics_ctxt,
                intrinsic2=forward_intrinsic2,
            )

            # Target -> context, using the warped image and depth from above.
            # NOTE(review): intrinsic2 is always None here, even when
            # variable_intrinsic is set -- confirm that Warper's fallback for a
            # missing intrinsic2 is the intended behaviour in that case.
            _, _, _, flow_b = self.warper.forward_warp(
                frame1=warped_trgt,
                mask1=None,
                depth1=warped_depth_trgt,
                transformation1=c2w_trgt,
                transformation2=c2w_ctxt,
                intrinsic1=intrinsics_trgt,
                intrinsic2=None,
            )

        return flow_f, flow_b, warped_trgt, depth_ctxt, warped_depth_trgt