|
|
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def round_width(width, multiplier, min_width=1, divisor=1, verbose=False):
    """Scale a channel width by ``multiplier`` and snap it to a multiple of ``divisor``.

    Args:
        width: base width to scale.
        multiplier: scale factor; when falsy (``None``/``0``) ``width`` is
            returned unchanged (no rounding, no int conversion).
        min_width: lower bound for the result; a falsy value falls back to
            ``divisor``.
        divisor: the result is rounded to the nearest multiple of this value.
        verbose: print intermediate values for debugging.

    Returns:
        The rounded width as an ``int``. If rounding shrank the scaled width
        by more than 10%, one extra ``divisor`` is added back.
    """
    if not multiplier:
        return width

    scaled = width * multiplier
    floor_width = min_width or divisor
    # Round to the nearest multiple of divisor (half-up via the +divisor/2 offset).
    nearest = int(scaled + divisor / 2) // divisor * divisor

    if verbose:
        print(f"min width {floor_width}")
        print(f"width {scaled} divisor {divisor}")
        print(f"other {nearest}")

    result = max(floor_width, nearest)
    # Never lose more than 10% of the requested width to rounding.
    if result < 0.9 * scaled:
        result += divisor
    return int(result)
|
|
|
|
|
def validate_checkpoint_wrapper_import(checkpoint_wrapper):
    """Ensure the optional ``checkpoint_wrapper`` import succeeded.

    Args:
        checkpoint_wrapper: the imported object, or ``None`` when the import
            failed.

    Raises:
        ImportError: if ``checkpoint_wrapper`` is ``None``.
    """
    if checkpoint_wrapper is not None:
        return
    raise ImportError("Please install fairscale.")
|
|
|
|
|
def get_gkern(kernlen, std):
    """Return a normalized 2-D Gaussian kernel as a float32 tensor.

    Args:
        kernlen: number of taps per side; the kernel is ``(kernlen, kernlen)``.
        std: standard deviation of the Gaussian, in taps.

    Returns:
        A ``(kernlen, kernlen)`` tensor whose entries sum to 1, built as the
        outer product of a 1-D Gaussian centred on the middle tap.
    """
    # Tap offsets relative to the kernel centre, in units of std.
    offsets = torch.arange(kernlen, dtype=torch.float32)
    offsets = (offsets - offsets.mean()) / std
    profile = torch.exp(-0.5 * offsets**2)

    # Separable kernel: outer product of the 1-D profile with itself.
    kernel = profile[:, None] * profile[None, :]
    return kernel / kernel.sum()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
    """Build fixed sine/cosine position embeddings for a (time, height, width) grid.

    One quarter of the channels encodes the temporal index; the other three
    quarters encode the 2-D spatial position.

    Args:
        embed_dim: total embedding size. Must be divisible by 4. The nested
            sin/cos helpers additionally require the spatial and temporal
            sub-dimensions to split evenly, which in practice means a
            multiple of 16; otherwise one of their asserts fires.
        grid_size: int, height and width of the spatial grid.
        t_size: int, number of temporal positions.
        cls_token: if true, prepend one all-zero row for a class token.

    Returns:
        np.ndarray of shape ``[t_size*grid_size*grid_size, embed_dim]``, or
        ``[1 + t_size*grid_size*grid_size, embed_dim]`` with ``cls_token``.
        Rows are ordered time-major, then over the spatial grid.
    """
    assert embed_dim % 4 == 0
    dim_temporal = embed_dim // 4
    dim_spatial = dim_temporal * 3

    # Spatial part: one row per grid cell, shape (grid_size**2, dim_spatial).
    axis_coords = np.arange(grid_size, dtype=np.float32)
    grid = np.stack(np.meshgrid(axis_coords, axis_coords), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    spatial = get_2d_sincos_pos_embed_from_grid(dim_spatial, grid)

    # Temporal part: one row per time step, shape (t_size, dim_temporal).
    temporal = get_1d_sincos_pos_embed_from_grid(
        dim_temporal, np.arange(t_size, dtype=np.float32)
    )

    # Broadcast both parts onto the full (t_size, grid_size**2, ...) lattice.
    temporal = np.repeat(temporal[:, None, :], grid_size**2, axis=1)
    spatial = np.repeat(spatial[None, :, :], t_size, axis=0)

    pos_embed = np.concatenate([temporal, spatial], axis=-1).reshape([-1, embed_dim])

    if not cls_token:
        return pos_embed
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build fixed sine/cosine position embeddings for a square 2-D grid.

    Args:
        embed_dim: embedding size per position.
        grid_size: int, height and width of the grid.
        cls_token: if true, prepend one all-zero row for a class token.

    Returns:
        np.ndarray of shape ``[grid_size*grid_size, embed_dim]``, or
        ``[1 + grid_size*grid_size, embed_dim]`` with ``cls_token``.
    """
    axis_coords = np.arange(grid_size, dtype=np.float32)
    # Stack the two meshgrid coordinate planes into a (2, 1, H, W) array.
    grid = np.stack(np.meshgrid(axis_coords, axis_coords), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed

    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-D coordinate grid with 1-D sin/cos embeddings per axis.

    Args:
        embed_dim: total embedding size; must be even. Half the channels go
            to each coordinate plane.
        grid: array whose ``grid[0]`` and ``grid[1]`` hold the two coordinate
            planes. In this file's callers they come from ``np.meshgrid``,
            so ``grid[0]`` holds the column (x) coordinate.

    Returns:
        np.ndarray of shape ``(M, embed_dim)``, where M is the number of
        positions in one coordinate plane. Columns for ``grid[0]`` come first.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane]) for plane in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode 1-D positions with fixed sine/cosine features.

    Args:
        embed_dim: output dimension for each position; must be even.
        pos: positions to encode, an array or list of any shape. It is
            flattened to size (M,).

    Returns:
        np.ndarray of shape (M, embed_dim) and dtype float64. The first
        ``embed_dim // 2`` columns are ``sin(pos * omega_d)`` and the rest are
        ``cos(pos * omega_d)``, with
        ``omega_d = 1 / 10000 ** (d / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0
    # ``np.float`` (an alias for float64) was deprecated in NumPy 1.20 and
    # removed in 1.24, where it raises AttributeError; float64 is the same dtype.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    # asarray accepts plain lists as well as ndarrays; flatten to (M,).
    pos = np.asarray(pos).reshape(-1)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2) outer product

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    return np.concatenate([emb_sin, emb_cos], axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's ``pos_embed`` in place to match ``model``'s patch grid.

    Leading tokens (``model.pos_embed.shape[-2] - model.patch_embed.num_patches``
    of them) are copied through unchanged. The remaining patch tokens are
    treated as a square grid and bicubically resampled to the model's side
    length. Nothing happens if the checkpoint has no ``"pos_embed"`` key or
    the side lengths already agree.

    Args:
        model: object exposing ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model: mutable mapping (e.g. a state dict); its
            ``"pos_embed"`` entry is replaced when resizing is needed.
    """
    if "pos_embed" not in checkpoint_model:
        return

    ckpt_pos = checkpoint_model["pos_embed"]
    channels = ckpt_pos.shape[-1]
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches

    # Side lengths of the (assumed square) patch grids.
    old_side = int((ckpt_pos.shape[-2] - n_extra) ** 0.5)
    new_side = int(n_patches**0.5)
    if old_side == new_side:
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (old_side, old_side, new_side, new_side)
    )
    leading = ckpt_pos[:, :n_extra]

    # (B, N, C) -> (B, C, side, side) for spatial interpolation, then back.
    patch_grid = (
        ckpt_pos[:, n_extra:]
        .reshape(-1, old_side, old_side, channels)
        .permute(0, 3, 1, 2)
    )
    patch_grid = torch.nn.functional.interpolate(
        patch_grid,
        size=(new_side, new_side),
        mode="bicubic",
        align_corners=False,
    )
    patch_tokens = patch_grid.permute(0, 2, 3, 1).flatten(1, 2)

    checkpoint_model["pos_embed"] = torch.cat((leading, patch_tokens), dim=1)
|
|
|
|
|
def calc_mvit_feature_geometry(cfg):
    """Compute per-layer feature map size and cumulative stride for an MViT model.

    Every layer starts from the patch-embedding geometry. If
    ``cfg.MVIT.PATCH_STRIDE`` has only two entries, the temporal size and
    stride are both 1. Each entry of ``cfg.MVIT.POOL_Q_STRIDE`` is
    ``[first_layer, s_t, s_h, s_w]``. Every layer with index >= ``first_layer``
    has its size floor-divided by, and its stride multiplied by, those
    factors.

    Args:
        cfg: config exposing ``DATA.NUM_FRAMES``, ``DATA.TRAIN_CROP_SIZE``,
            ``MVIT.PATCH_STRIDE``, ``MVIT.DEPTH`` and ``MVIT.POOL_Q_STRIDE``.

    Returns:
        ``(feat_size, feat_stride)``: two lists with ``cfg.MVIT.DEPTH`` rows,
        each row a ``[t, h, w]`` list.
    """
    feat_size = []
    feat_stride = []
    for _ in range(cfg.MVIT.DEPTH):
        patch_stride = cfg.MVIT.PATCH_STRIDE
        temporal = len(patch_stride) > 2
        crop = cfg.DATA.TRAIN_CROP_SIZE
        feat_size.append(
            [
                cfg.DATA.NUM_FRAMES // patch_stride[0] if temporal else 1,
                crop // patch_stride[-2],
                crop // patch_stride[-1],
            ]
        )
        feat_stride.append(
            [
                patch_stride[0] if temporal else 1,
                patch_stride[-2],
                patch_stride[-1],
            ]
        )

    for stage in cfg.MVIT.POOL_Q_STRIDE:
        for layer in range(cfg.MVIT.DEPTH):
            if layer < stage[0]:
                continue  # pooling only applies from stage[0] onwards
            size_row = feat_size[layer]
            stride_row = feat_stride[layer]
            for axis in range(len(size_row)):
                size_row[axis] //= stage[axis + 1]
                stride_row[axis] *= stage[axis + 1]

    return feat_size, feat_stride