File size: 1,562 Bytes
59d751c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .utils import instantiate_from_config, get_camera_flow_generator_input, warp_image

import pdb

class CameraFlowGenerator(nn.Module):
    """Produces camera-induced optical flow via a depth-based warping module.

    The wrapped ``depth_warping_module`` (built with ``instantiate_from_config``)
    is run under ``torch.no_grad()``: it is treated as a frozen flow/feature
    producer for conditioning, and only its forward flow plus a dict of
    warping diagnostics are returned.
    """

    def __init__(
        self,
        depth_estimator_kwargs,
        use_observed_mask=False,
        cycle_th=3.,
    ):
        """
        Args:
            depth_estimator_kwargs: config dict consumed by
                ``instantiate_from_config`` to build the depth-warping module.
            use_observed_mask: flag kept for interface compatibility;
                observed-mask filtering is not currently implemented here.
            cycle_th: forward/backward cycle-consistency threshold; only
                meaningful once observed-mask filtering is implemented.
        """
        super().__init__()

        self.depth_warping_module = instantiate_from_config(depth_estimator_kwargs)
        self.use_observed_mask = use_observed_mask
        self.cycle_th = cycle_th

    def forward(self, condition_image, camera_flow_generator_input):
        """Run the depth-warping module and collect its outputs.

        Args:
            condition_image: conditioning image tensor. Currently unused in
                the computation; the parameter is kept so existing callers
                keep working. (A previous revision built a broadcast copy of
                it with ``einops.repeat`` that was never consumed — that dead
                computation has been removed.)
            camera_flow_generator_input: dict of network inputs — expected
                keys are ``image``, ``intrinsics`` and ``extrinsics``
                (consumed by the wrapped module).

        Returns:
            tuple: ``(flow_f, log_dict)`` where ``flow_f`` is the forward
            flow from the warping module and ``log_dict`` holds
            ``depth_warped_frames``, ``depth_ctxt`` and ``depth_trgt`` for
            logging/inspection.
        """
        # Frozen w.r.t. gradients: the warping module only provides
        # conditioning signals, never receives parameter updates from here.
        with torch.no_grad():
            flow_f, flow_b, depth_warped_frames, depth_ctxt, depth_trgt = \
                self.depth_warping_module(camera_flow_generator_input)

        # NOTE(review): flow_b (backward flow) is unpacked but unused — only
        # the forward flow is returned. Kept in the unpacking to document the
        # wrapped module's 5-tuple output shape.
        log_dict = {
            'depth_warped_frames': depth_warped_frames,
            'depth_ctxt': depth_ctxt,
            'depth_trgt': depth_trgt,
        }

        return flow_f, log_dict