import torch

from .croco import CroCoNet


def croco_args_from_ckpt(ckpt):
    """Recover the CroCoNet constructor arguments stored in a checkpoint."""
    if "croco_kwargs" in ckpt:
        # the constructor kwargs are stored directly in the checkpoint
        return ckpt["croco_kwargs"]
    elif "args" in ckpt and hasattr(ckpt["args"], "model"):
        # the checkpoint stores the model as a constructor string,
        # e.g. "CroCoNet(enc_embed_dim=..., enc_depth=...)"
        s = ckpt["args"].model
        assert s.startswith("CroCoNet(")
        # turn "CroCoNet(...)" into "dict(...)" and evaluate it to get the kwargs
        return eval("dict" + s[len("CroCoNet") :])
    else:
        # no constructor information in the checkpoint: use the default arguments
        return dict()
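
# Illustrative usage sketch (the checkpoint path and the "model" state-dict key
# below are assumptions; adapt them to the actual checkpoint layout):
#
#   ckpt = torch.load("path/to/croco_checkpoint.pth", map_location="cpu")
#   croco_kwargs = croco_args_from_ckpt(ckpt)  # e.g. dict(enc_embed_dim=..., enc_depth=..., ...)
#   model = CroCoNet(**croco_kwargs)
#   model.load_state_dict(ckpt["model"], strict=False)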


class CroCoDownstreamMonocularEncoder(CroCoNet):

    def __init__(self, head, **kwargs):
        """Build a network for a monocular downstream task, using only the encoder.
        It takes an extra argument `head`, which is called with the features
        and a dictionary `img_info` containing 'width' and 'height' keys.
        The head is set up with the CroCoNet arguments in this init function.
        NOTE: this works by calling super().__init__() with the setters redefined below.
        """
        super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """No mask generator"""
        return

    def _set_mask_token(self, *args, **kwargs):
        """No mask token"""
        self.mask_token = None
        return

    def _set_decoder(self, *args, **kwargs):
        """No decoder"""
        return

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head for downstream tasks."""
        return

    def forward(self, img):
        """
        img is of size batch_size x 3 x h x w
        """
        B, C, H, W = img.size()
        img_info = {"height": H, "width": W}
        need_all_layers = (
            hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks
        )
        out, _, _ = self._encode_image(
            img, do_mask=False, return_all_blocks=need_all_layers
        )
        return self.head(out, img_info)
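
# Illustrative sketch (an assumption, not one of the released downstream heads):
# a minimal head showing the interface expected above, i.e. a setup(croconet)
# method plus a call with (features, img_info). It assumes CroCoNet exposes an
# `enc_embed_dim` attribute and simply average-pools the encoder tokens.
class _ExamplePoolingHead(torch.nn.Module):
    return_all_blocks = False  # the encoder then returns only the last block

    def __init__(self, num_outputs=1):
        super().__init__()
        self.num_outputs = num_outputs
        self.proj = None

    def setup(self, croconet):
        # called by the downstream wrapper so the head can size itself
        # from the encoder embedding dimension (assumed attribute)
        self.proj = torch.nn.Linear(croconet.enc_embed_dim, self.num_outputs)

    def forward(self, features, img_info):
        # features: B x N x enc_embed_dim encoder tokens; img_info: {'height', 'width'}
        return self.proj(features.mean(dim=1))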


class CroCoDownstreamBinocular(CroCoNet):

    def __init__(self, head, **kwargs):
        """Build a network for a binocular downstream task.
        It takes an extra argument `head`, which is called with the features
        and a dictionary `img_info` containing 'width' and 'height' keys.
        The head is set up with the CroCoNet arguments in this init function.
        """
        super(CroCoDownstreamBinocular, self).__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """No mask generator"""
        return

    def _set_mask_token(self, *args, **kwargs):
        """No mask token"""
        self.mask_token = None
        return

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head for downstream tasks, define your own head"""
        return

    def encode_image_pairs(self, img1, img2, return_all_blocks=False):
        """Run the encoder on a pair of images.
        It is actually ~5% faster to concatenate the two images along the batch
        dimension and encode them in a single pass than to encode them separately.
        """
        out, pos, _ = self._encode_image(
            torch.cat((img1, img2), dim=0),
            do_mask=False,
            return_all_blocks=return_all_blocks,
        )
        if return_all_blocks:
            # split each block's output back into the two images, giving one list of
            # per-block features for img1 and one for img2
            out, out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out])))
            # only the last block of the second image is needed by the decoder
            out2 = out2[-1]
        else:
            out, out2 = out.chunk(2, dim=0)
        pos, pos2 = pos.chunk(2, dim=0)
        return out, out2, pos, pos2

    def forward(self, img1, img2):
        B, C, H, W = img1.size()
        img_info = {"height": H, "width": W}
        return_all_blocks = (
            hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks
        )
        out, out2, pos, pos2 = self.encode_image_pairs(
            img1, img2, return_all_blocks=return_all_blocks
        )
        if return_all_blocks:
            decout = self._decoder(
                out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks
            )
            # give the head both the encoder blocks and the decoder blocks
            decout = out + decout
        else:
            decout = self._decoder(
                out, pos, None, out2, pos2, return_all_blocks=return_all_blocks
            )
        return self.head(decout, img_info)
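
# Illustrative usage sketch (assumptions: the checkpoint path, the "model"
# state-dict key, and a `head` object following the setup()/(features, img_info)
# interface described above):
#
#   ckpt = torch.load("path/to/croco_checkpoint.pth", map_location="cpu")
#   model = CroCoDownstreamBinocular(head, **croco_args_from_ckpt(ckpt))
#   model.load_state_dict(ckpt["model"], strict=False)
#   pred = model(img1, img2)  # img1, img2: tensors of size B x 3 x H x W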