# SceneDINO: scenedino/common/array_operations.py
# Source commit 9e15541 ("scenedino init") by jev-aleks.
from functools import partial
import torch
import torch.nn.functional as F
# Callable that inserts a new leading axis: unsqueezer(x) == torch.unsqueeze(x, dim=0).
unsqueezer = partial(torch.unsqueeze, dim=0)
def map_fn(batch, fn):
    """Apply ``fn`` to every leaf of a nested dict/list structure.

    Lists are rebuilt as new lists. Dicts are updated IN PLACE (each value is
    replaced by its mapped version) and the same dict object is returned.
    Anything that is neither a dict nor a list (including tuples) is treated
    as a leaf and passed to ``fn`` directly.

    Args:
        batch: a leaf value, or an arbitrarily nested dict/list of them.
        fn: callable applied to each leaf.

    Returns:
        The mapped structure (the original dict object when ``batch`` is a dict).
    """
    if isinstance(batch, list):
        return [map_fn(item, fn) for item in batch]
    if not isinstance(batch, dict):
        return fn(batch)
    # Overwriting existing keys does not resize the dict, so iterating items() is safe.
    for key, value in batch.items():
        batch[key] = map_fn(value, fn)
    return batch
def to(data, device, non_blocking=True):
    """Recursively move every leaf of a nested container to ``device``.

    Dicts and lists are rebuilt (new containers are returned). Tuples are
    rebuilt as well, and namedtuples keep their type. Previously any tuple
    crashed with ``'tuple' object has no attribute 'to'``. Tuple subclasses
    that define their own ``.to`` (e.g. ``PackedSequence``) are delegated to
    that method, as before. Every other leaf must provide
    ``.to(device, non_blocking=...)``.

    Args:
        data: a tensor-like leaf, or a nested dict/list/tuple of them.
        device: target device, forwarded to ``.to``.
        non_blocking: forwarded to ``.to``.

    Returns:
        A structure of the same shape with every leaf moved to ``device``.
    """
    if isinstance(data, dict):
        return {k: to(v, device, non_blocking=non_blocking) for k, v in data.items()}
    if isinstance(data, list):
        return [to(v, device, non_blocking=non_blocking) for v in data]
    if isinstance(data, tuple) and not hasattr(data, "to"):
        moved = [to(v, device, non_blocking=non_blocking) for v in data]
        # namedtuples are built from positional fields; plain tuples from an iterable
        return type(data)(*moved) if hasattr(data, "_fields") else tuple(moved)
    return data.to(device, non_blocking=non_blocking)
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on all parameters of one or more modules.

    Args:
        nets: a single object with ``.parameters()``, or a list of them.
            ``None`` entries are skipped.
        requires_grad: value written to every parameter's ``requires_grad``
            flag (default ``False``, i.e. freeze).
    """
    modules = nets if isinstance(nets, list) else [nets]
    for module in (candidate for candidate in modules if candidate is not None):
        for parameter in module.parameters():
            parameter.requires_grad = requires_grad
def mask_mean(t: torch.Tensor, m: torch.Tensor, dim=None, keepdim=False):
    """Mean of ``t`` over ``dim``, ignoring the entries where ``m`` is set.

    Masked entries are zeroed in a copy of ``t``, so the caller's tensor is
    not modified. The divisor is the number of elements in the reduced dims
    minus the number of masked ones.

    Args:
        t: values to average.
        m: mask of entries to EXCLUDE. It is used both for ``t[m] = 0`` and
            summed for the count, so it presumably should be a boolean tensor
            with the same shape as ``t``. NOTE(review): confirm that callers
            never pass a broadcast-shaped mask, since the count would then be
            wrong.
        dim: int, or sequence of ints, of dimensions to reduce. ``None`` or an
            empty sequence reduces over all dimensions. A bare int previously
            raised ``TypeError`` at ``len(dim)``.
        keepdim: forwarded to ``torch.sum``.

    Returns:
        The masked mean. Positions where every reduced entry is masked give
        0/0, i.e. NaN.
    """
    t = t.clone()
    t[m] = 0
    if isinstance(dim, int):
        dim = [dim]
    elif dim is None or len(dim) == 0:
        dim = list(range(t.dim()))
    els = 1
    for d in dim:
        els *= t.shape[d]
    return torch.sum(t, dim=dim, keepdim=keepdim) / (els - torch.sum(m.to(torch.float), dim=dim, keepdim=keepdim))
def apply_crop(array, crop):
    """Crop the first two axes of ``array``.

    Args:
        array: anything supporting 2-axis slice indexing (e.g. a tensor).
        crop: indexable ``(start0, start1, size0, size1)``. Axis 0 is sliced
            ``[crop[0], crop[0] + crop[2])`` and axis 1 is sliced
            ``[crop[1], crop[1] + crop[3])``.

    Returns:
        ``array`` indexed with the two slices (a view for tensors).
    """
    start0, start1 = crop[0], crop[1]
    rows = slice(start0, start0 + crop[2])
    cols = slice(start1, start1 + crop[3])
    return array[rows, cols]
def shrink_mask(mask, shrink=3):
    """Erode a binary mask with a ``shrink`` x ``shrink`` box, keeping its size.

    An output pixel is 1 only if every pixel in the ``shrink`` x ``shrink``
    window around it is 1. Pixels outside the image count as 0, so a band
    along the image border is always cleared.

    Args:
        mask: tensor accepted by ``F.avg_pool2d`` (3D or 4D, spatial dims
            last). Any dtype; it is converted to float32. Assumed to contain
            only 0/1 values (anything other than a full window of exact 1.0
            produces 0).
        shrink: window side length (>= 1). Odd values give a centred window.
            Even values use one extra row/column of padding after the image
            rather than before it.

    Returns:
        float32 tensor with the same shape as ``mask``, containing 0.0/1.0.
    """
    mask = mask.to(torch.float32)
    # Zero-pad by shrink - 1 in total per spatial axis, so the stride-1 pooling
    # output keeps the input's H and W. The former symmetric
    # padding=shrink // 2 is identical for odd shrink, but for even shrink it
    # grew each spatial axis by one pixel.
    before = (shrink - 1) // 2
    after = shrink // 2
    padded = F.pad(mask, (before, after, before, after), value=0.0)
    pooled = F.avg_pool2d(padded, kernel_size=shrink, stride=1)
    # A window of exact 1.0s averages to exactly 1.0 (integer sum / integer count).
    return (pooled == 1.).to(torch.float32)
def get_mask(size, border=5, device=None):
    """Build an all-ones float32 mask of shape ``size`` and erode it.

    The ones tensor is passed through ``shrink_mask(..., border)``. Because
    ``shrink_mask`` treats out-of-image pixels as 0, this clears a band along
    the spatial border.

    Args:
        size: tensor shape passed to ``torch.ones``. It must be a shape that
            ``shrink_mask`` accepts.
        border: window size forwarded to ``shrink_mask``.
        device: if not ``None``, the result is moved there with ``.to``.

    Returns:
        float32 mask tensor.
    """
    eroded = shrink_mask(torch.ones(size, dtype=torch.float32), border)
    return eroded if device is None else eroded.to(device)
def get_grid(H, W, normalize=True):
    """Return an ``(H, W, 2)`` float32 grid of pixel coordinates.

    ``grid[i, j] == (x_j, y_i)``: the last axis is ordered (x, y), where x
    varies along W and y varies along H.

    Args:
        H: number of rows.
        W: number of columns.
        normalize: if True, coordinates are ``torch.linspace(-1, 1, n)`` along
            each axis. Otherwise they are integer pixel indices ``0..n-1``,
            cast to float.

    Returns:
        float32 tensor of shape ``(H, W, 2)``.
    """
    if normalize:
        h_range = torch.linspace(-1, 1, H)
        w_range = torch.linspace(-1, 1, W)
    else:
        h_range = torch.arange(0, H)
        w_range = torch.arange(0, W)
    # Pin matrix ("ij") indexing explicitly. Calling meshgrid without
    # `indexing` warns on torch >= 1.10, and the default is documented to
    # change to "xy", which would silently transpose the grid.
    try:
        yy, xx = torch.meshgrid(h_range, w_range, indexing="ij")
    except TypeError:  # torch < 1.10: no `indexing` kwarg; the default there is "ij"
        yy, xx = torch.meshgrid(h_range, w_range)
    # Stack as (x, y); this equals stacking (h, w) and flipping the last axis.
    grid = torch.stack([xx, yy], dim=-1).float()
    return grid
def detach(t):
    """Return ``t`` with every tensor detached from the autograd graph.

    Tuples are rebuilt as plain tuples (as before). Lists and dicts, and
    arbitrary nesting of tuples/lists/dicts, are now handled recursively,
    consistent with ``to`` and ``map_fn``. Previously a list, a dict or a
    nested tuple raised ``AttributeError``. Any other value must provide
    ``.detach()``.

    Args:
        t: a tensor, or a nested tuple/list/dict of tensors.

    Returns:
        The same structure with each leaf replaced by ``leaf.detach()``.
    """
    if isinstance(t, tuple):
        return tuple(detach(t_) for t_ in t)
    if isinstance(t, list):
        return [detach(t_) for t_ in t]
    if isinstance(t, dict):
        return {k: detach(v) for k, v in t.items()}
    return t.detach()