import torch
import utils.basic
import torch.nn.functional as F

def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor using bilinear interpolation

    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
    coordinates :attr:`coords` using bilinear interpolation. It is the same
    as `torch.nn.functional.grid_sample()` but with a different coordinate
    convention.

    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
    :math:`B` is the batch size, :math:`C` is the number of channels,
    :math:`H` is the height of the image, and :math:`W` is the width of the
    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.

    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
    that in this case the order of the components is slightly different
    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.

    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
    left-most image pixel and :math:`W-1` to the center of the right-most
    pixel.

    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
    the left-most pixel and :math:`W` to the right edge of the right-most
    pixel.

    Similar conventions apply to the :math:`y` coordinate for the ranges
    :math:`[0,H-1]` and :math:`[0,H]`, and to :math:`t` for the ranges
    :math:`[0,T-1]` and :math:`[0,T]`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    sizes = input.shape[2:]
    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    if align_corners:
        # scale to [-1, 1] with the extrema at pixel centers
        coords = coords * torch.tensor(
            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
        )
    else:
        # scale to [-1, 1] with the extrema at pixel edges
        coords = coords * torch.tensor(
            [2 / size for size in reversed(sizes)], device=coords.device
        )

    coords -= 1

    return F.grid_sample(
        input, coords, align_corners=align_corners, padding_mode=padding_mode
    )
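

# Illustrative sketch (not part of the original module): a quick check of the
# coordinate convention described above. With align_corners=True, integer
# coordinates address pixel centers directly, so sampling at (x=1, y=0) should
# recover the pixel at row 0, column 1. The helper name below is hypothetical.
def _demo_bilinear_sampler():
    input = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
    # one query point per batch element, laid out as (B, H_o=1, W_o=1, 2) in (x, y) order
    coords = torch.tensor([[[[1.0, 0.0]]], [[[2.0, 3.0]]]])
    sampled = bilinear_sampler(input, coords)  # (B, C, 1, 1)
    assert torch.allclose(sampled[0, :, 0, 0], input[0, :, 0, 1])
    assert torch.allclose(sampled[1, :, 0, 0], input[1, :, 3, 2])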

def sample_features4d(input, coords):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.

    The field is sampled at coordinates :attr:`coords` using bilinear
    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
    same convention as :func:`bilinear_sampler` with `align_corners=True`.

    The output tensor has one feature per point, and has shape :math:`(B,
    R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features.
    """
    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2
    coords = coords.unsqueeze(2)

    # B C R 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 1, 3).view(
        B, -1, feats.shape[1] * feats.shape[3]
    )  # B C R 1 -> B R C
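

# Illustrative sketch (not part of the original module): sample_features4d takes
# per-point (x, y) queries of shape (B, R, 2) and returns one C-dimensional
# feature per point, shape (B, R, C). The helper name below is hypothetical.
def _demo_sample_features4d():
    feature_map = torch.randn(2, 16, 32, 32)        # B C H W
    points = torch.rand(2, 100, 2) * 31.0           # B R 2, (x, y) in pixel units
    feats = sample_features4d(feature_map, points)  # B R C
    assert feats.shape == (2, 100, 16)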

def sample_features5d(input, coords):
    r"""Sample spatio-temporal features

    `sample_features5d(input, coords)` works in the same way as
    :func:`sample_features4d` but for spatio-temporal features and points:
    :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal points :math:`(t_i,
    x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.

    Args:
        input (Tensor): spatio-temporal features.
        coords (Tensor): spatio-temporal points.

    Returns:
        Tensor: sampled features.
    """
    B, T, _, _, _ = input.shape

    # B T C H W -> B C T H W
    input = input.permute(0, 2, 1, 3, 4)

    # B R1 R2 3 -> B R1 R2 1 3
    coords = coords.unsqueeze(3)

    # B C R1 R2 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 3, 1, 4).view(
        B, feats.shape[2], feats.shape[3], feats.shape[1]
    )  # B C R1 R2 1 -> B R1 R2 C
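

# Illustrative sketch (not part of the original module): sample_features5d takes
# (t, x, y) queries of shape (B, R1, R2, 3) against a (B, T, C, H, W) feature
# volume and returns (B, R1, R2, C). The helper name below is hypothetical.
def _demo_sample_features5d():
    feature_volume = torch.randn(2, 8, 16, 32, 32)      # B T C H W
    t = torch.randint(0, 8, (2, 4, 100, 1)).float()     # time indices in [0, T-1]
    xy = torch.rand(2, 4, 100, 2) * 31.0                # spatial coordinates in pixel units
    queries = torch.cat([t, xy], dim=-1)                # B R1 R2 3, ordered (t, x, y)
    feats = sample_features5d(feature_volume, queries)  # B R1 R2 C
    assert feats.shape == (2, 4, 100, 16)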

def bilinear_sample2d(im, x, y, return_inbounds=False):
    # im is B x C x H x W; x and y are each B x N
    # output is B x C x N
    B, C, H, W = list(im.shape)
    N = list(x.shape)[1]

    x = x.float()
    y = y.float()
    H_f = torch.tensor(H, dtype=torch.float32)
    W_f = torch.tensor(W, dtype=torch.float32)

    # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()

    max_y = (H_f - 1).int()
    max_x = (W_f - 1).int()

    # integer corners surrounding each sample location
    x0 = torch.floor(x).int()
    x1 = x0 + 1
    y0 = torch.floor(y).int()
    y1 = y0 + 1

    x0_clip = torch.clamp(x0, 0, max_x)
    x1_clip = torch.clamp(x1, 0, max_x)
    y0_clip = torch.clamp(y0, 0, max_y)
    y1_clip = torch.clamp(y1, 0, max_y)
    dim2 = W
    dim1 = W * H

    # flat indices of the four corners for every (batch, point) pair
    base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
    base = torch.reshape(base, [B, 1]).repeat([1, N])

    base_y0 = base + y0_clip * dim2
    base_y1 = base + y1_clip * dim2

    idx_y0_x0 = base_y0 + x0_clip
    idx_y0_x1 = base_y0 + x1_clip
    idx_y1_x0 = base_y1 + x0_clip
    idx_y1_x1 = base_y1 + x1_clip

    # use the indices to look up pixels in the flat image
    # im is B x C x H x W; move C out to the last dim
    im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
    i_y0_x0 = im_flat[idx_y0_x0.long()]
    i_y0_x1 = im_flat[idx_y0_x1.long()]
    i_y1_x0 = im_flat[idx_y1_x0.long()]
    i_y1_x1 = im_flat[idx_y1_x1.long()]
    # finally, calculate the interpolated values with bilinear weights
    x0_f = x0.float()
    x1_f = x1.float()
    y0_f = y0.float()
    y1_f = y1.float()

    w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
    w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
    w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
    w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)

    output = (
        w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
    )
    # output is B x N x C
    output = output.view(B, -1, C)
    output = output.permute(0, 2, 1)
    # output is B x C x N

    if return_inbounds:
        x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
        y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
        inbounds = (x_valid & y_valid).float()
        inbounds = inbounds.reshape(B, N)
        return output, inbounds

    return output  # B, C, N
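

# Illustrative sketch (not part of the original module): bilinear_sample2d is the
# per-point counterpart of bilinear_sampler, taking x and y as separate (B, N)
# tensors and returning (B, C, N) features, optionally with a (B, N) in-bounds
# mask. The helper name below is hypothetical.
def _demo_bilinear_sample2d():
    im = torch.randn(2, 3, 16, 16)     # B C H W
    x = torch.rand(2, 50) * 15.0       # B N, pixel-space x
    y = torch.rand(2, 50) * 15.0       # B N, pixel-space y
    feats, inbounds = bilinear_sample2d(im, x, y, return_inbounds=True)
    assert feats.shape == (2, 3, 50)
    assert inbounds.shape == (2, 50)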