import logging
import math
from typing import Dict, List

import torch
import torch.nn.functional as F
from torch import nn

_logger = logging.getLogger(__name__)


def resample_patch_embed(
        patch_embed,
        new_size: List[int],
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
"""Resample the weights of the patch embedding kernel to target resolution. |
|
We resample the patch embedding kernel by approximately inverting the effect |
|
of patch resizing. |
|
|
|
Code based on: |
|
https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py |
|
|
|
With this resizing, we can for example load a B/8 filter into a B/16 model |
|
and, on 2x larger input image, the result will match. |
|
|
|
Args: |
|
patch_embed: original parameter to be resized. |
|
new_size (tuple(int, int): target shape (height, width)-only. |
|
interpolation (str): interpolation for resize |
|
antialias (bool): use anti-aliasing filter in resize |
|
verbose (bool): log operation |
|
Returns: |
|
Resized patch embedding kernel. |
|
""" |
|
    import numpy as np
    try:
        # Prefer the standalone functorch package, fall back to torch.vmap (torch >= 1.13).
        import functorch
        vmap = functorch.vmap
    except ImportError:
        if hasattr(torch, 'vmap'):
            vmap = torch.vmap
        else:
            assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."

    assert len(patch_embed.shape) == 4, "Four dimensions expected"
    assert len(new_size) == 2, "New shape should only be hw"
    old_size = patch_embed.shape[-2:]
    if tuple(old_size) == tuple(new_size):
        return patch_embed

    if verbose:
        _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")

    def resize(x_np, _new_size):
        # Round-trip through torch for the interpolation, input/output are numpy.
        x_tf = torch.Tensor(x_np)[None, None, ...]
        x_upsampled = F.interpolate(
            x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
        return x_upsampled

    def get_resize_mat(_old_size, _new_size):
        # Build the linear map of the resize op by pushing each basis vector
        # of the old patch grid through the interpolation.
        mat = []
        for i in range(np.prod(_old_size)):
            basis_vec = np.zeros(_old_size)
            basis_vec[np.unravel_index(i, _old_size)] = 1.
            mat.append(resize(basis_vec, _new_size).reshape(-1))
        return np.stack(mat).T

    resize_mat = get_resize_mat(old_size, new_size)
    # The pseudo-inverse of the resize matrix approximately inverts the effect of
    # patch resizing, so the resampled kernel reproduces the original outputs.
    resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)

    def resample_kernel(kernel):
        resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
        return resampled_kernel.reshape(new_size)

    # vmap over the output and input channel dims of the (O, I, H, W) kernel.
    v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
    orig_dtype = patch_embed.dtype
    patch_embed = patch_embed.float()
    patch_embed = v_resample_kernel(patch_embed)
    patch_embed = patch_embed.to(orig_dtype)
    return patch_embed
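
# A minimal usage sketch (hypothetical shapes): resampling a 16x16 ViT-B patch
# projection kernel down to 8x8. `proj_weight` stands in for a real nn.Conv2d
# weight of shape (embed_dim, in_chans, patch_h, patch_w):
#
#   proj_weight = torch.randn(768, 3, 16, 16)
#   resized = resample_patch_embed(proj_weight, [8, 8], verbose=True)
#   assert resized.shape == (768, 3, 8, 8)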


def adapt_input_conv(in_chans, conv_weight):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations
            # of the original RGB input layer weights that'd work better for specific cases.
            repeat = int(math.ceil(in_chans / 3))
            conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv_weight *= (3 / float(in_chans))
    conv_weight = conv_weight.to(conv_type)
    return conv_weight
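
# A minimal usage sketch (hypothetical shapes): adapting a pretrained RGB stem weight
# to 1-channel and 6-channel inputs. `w` stands in for a real first-conv weight:
#
#   w = torch.randn(64, 3, 7, 7)
#   w1 = adapt_input_conv(1, w)  # (64, 1, 7, 7), RGB weights summed
#   w6 = adapt_input_conv(6, w)  # (64, 6, 7, 7), RGB weights tiled, rescaled by 3/6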


def adapt_head_conv(conv_weight):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()
    O, I, J, K = conv_weight.shape

    # Synthesize 6 extra input channels: split the original input dim into 6 chunks,
    # average the weights within each chunk, and append the scaled result.
    chunks = torch.chunk(conv_weight, 6, dim=1)
    conv_weight_new = [chunk.mean(dim=1, keepdim=True) for chunk in chunks]
    conv_weight_new = torch.cat(conv_weight_new, dim=1) * 0.5
    conv_weight = torch.cat([conv_weight, conv_weight_new], dim=1)
    conv_weight = conv_weight.to(conv_type)
    return conv_weight
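
# A minimal usage sketch (hypothetical shapes): the input dim is split into 6 chunks,
# each contributing one averaged, 0.5-scaled extra channel:
#
#   w = torch.randn(256, 768, 1, 1)
#   assert adapt_head_conv(w).shape == (256, 774, 1, 1)  # 768 original + 6 synthesized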


def adapt_linear(conv_weight):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()
    O, I = conv_weight.shape

    # Synthesize 81 extra input features: split the original input dim into 81 parts
    # and average the weights within each part.
    splits = torch.tensor_split(conv_weight, 81, dim=1)
    conv_weight_new = [split.mean(dim=1, keepdim=True) for split in splits]
    conv_weight_new = torch.cat(conv_weight_new, dim=1)

    # Scale both the original and the synthesized weights by 0.5 before concatenating.
    conv_weight = torch.cat([conv_weight * 0.5, conv_weight_new * 0.5], dim=1)
    conv_weight = conv_weight.to(conv_type)
    return conv_weight
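
# A minimal usage sketch (hypothetical shapes): the input dim gains 81 averaged extra
# features, with original and synthesized weights both scaled by 0.5:
#
#   w = torch.randn(512, 1024)
#   assert adapt_linear(w).shape == (512, 1105)  # 1024 original + 81 synthesized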


def checkpoint_filter_fn(
        state_dict: Dict[str, torch.Tensor],
        model: nn.Module,
        interpolation: str = 'bicubic',
        antialias: bool = True,
) -> Dict[str, torch.Tensor]:
    """Convert patch embedding weight from manual patchify + linear proj to conv."""
    out_dict = {}

    # Optional hook: filter on & strip a prefix from checkpoint keys.
    prefix = ''
    if prefix:
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            O, I, H, W = model.backbone.patch_embed.proj.weight.shape
            if len(v.shape) < 4:
                # Checkpoint stored the patchify step as a flat linear projection,
                # reshape it into a conv kernel.
                v = v.reshape(O, -1, H, W)
            if v.shape[-1] != W or v.shape[-2] != H:
                v = resample_patch_embed(
                    v,
                    (H, W),
                    interpolation=interpolation,
                    antialias=antialias,
                    verbose=True,
                )
            if v.shape[1] != I:
                v = adapt_input_conv(I, v)
        elif 'decoder_embed.weight' in k:
            O, I = model.backbone.decoder_embed.weight.shape
            if v.shape[1] != I:
                v = adapt_linear(v)
        out_dict[k] = v

    # Remap all keys except the downstream heads onto the wrapped backbone.
    prefix = 'backbone.'
    out_dict = {prefix + k if 'downstream_head' not in k else k: v for k, v in out_dict.items()}

    # Keep only the first 3 output channels of the DPT head convolutions.
    out_dict['downstream_head1.dpt.head.4.weight'] = out_dict['downstream_head1.dpt.head.4.weight'][0:3]
    out_dict['downstream_head1.dpt.head.4.bias'] = out_dict['downstream_head1.dpt.head.4.bias'][0:3]
    out_dict['downstream_head2.dpt.head.4.weight'] = out_dict['downstream_head2.dpt.head.4.weight'][0:3]
    out_dict['downstream_head2.dpt.head.4.bias'] = out_dict['downstream_head2.dpt.head.4.bias'][0:3]

    return out_dict
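
# A minimal usage sketch, assuming a wrapper model exposing `backbone` and
# `downstream_head{1,2}` attributes and a checkpoint path `ckpt_path` (both
# hypothetical names here):
#
#   state_dict = torch.load(ckpt_path, map_location='cpu')
#   state_dict = checkpoint_filter_fn(state_dict, model)
#   model.load_state_dict(state_dict, strict=False)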