egorchistov's picture
Initial release
ac59957
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
def load_ckpt(model, path):
"""Load checkpoint"""
state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)[
"state_dict"
]
model.load_state_dict(state_dict, strict=True)
def load_ckpt_submission(model, path):
"""Load checkpoint"""
state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)[
"state_dict"
]
state_dict = {k[6:]: v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=True)
def resize_data(img1, img2, flow, factor=1.0):
_, _, h, w = img1.shape
h = int(h * factor)
w = int(w * factor)
img1 = F.interpolate(img1, (h, w), mode="area")
img2 = F.interpolate(img2, (h, w), mode="area")
flow = F.interpolate(flow, (h, w), mode="area") * factor
return img1, img2, flow
class InputPadder:
"""Pads images such that dimensions are divisible by 8"""
def __init__(self, dims, mode="sintel"):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 16) + 1) * 16 - self.ht) % 16
pad_wd = (((self.wd // 16) + 1) * 16 - self.wd) % 16
self.mode = mode
if mode == "sintel":
self._pad = [
pad_wd // 2,
pad_wd - pad_wd // 2,
pad_ht // 2,
pad_ht - pad_ht // 2,
0,
0,
]
elif mode == "downzero":
self._pad = [0, pad_wd, 0, pad_ht, 0, 0]
else:
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht, 0, 0]
def pad(self, input):
if self.mode == "downzero":
return F.pad(input, self._pad)
else:
return F.pad(input, self._pad, mode="replicate")
def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
# Intended to be used before converting to torch.Tensor
def merge_flows(flow1, valid1, flow2, valid2, method="nearest"):
flow1 = np.transpose(flow1, axes=[2, 0, 1])
_, ht, wd = flow1.shape
x1, y1 = np.meshgrid(np.arange(wd), np.arange(ht))
x1_f = x1 + flow1[0]
y1_f = y1 + flow1[1]
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
x1_f = x1_f.reshape(-1)
y1_f = y1_f.reshape(-1)
valid1 = valid1.reshape(-1)
mask1 = (
(valid1 > 0.5) & (x1_f >= 0) & (x1_f <= wd - 1) & (y1_f >= 0) & (y1_f <= ht - 1)
)
x1 = x1[mask1]
y1 = y1[mask1]
x1_f = x1_f[mask1]
y1_f = y1_f[mask1]
valid1 = valid1[mask1]
# STEP 1: interpolate valid values
new_valid1 = interpolate.interpn(
(np.arange(ht), np.arange(wd)),
valid2,
(y1_f, x1_f),
method=method,
bounds_error=False,
fill_value=0,
)
valid1 = new_valid1.round()
mask1 = valid1 > 0.5
x1 = x1[mask1]
y1 = y1[mask1]
x1_f = x1_f[mask1]
y1_f = y1_f[mask1]
valid1 = valid1[mask1]
flow2_filled = fill_invalid(flow2, valid2)
# STEP 2: interpolate flow values
flow_x = interpolate.interpn(
(np.arange(ht), np.arange(wd)),
flow2_filled[:, :, 0],
(y1_f, x1_f),
method=method,
bounds_error=False,
fill_value=0,
)
flow_y = interpolate.interpn(
(np.arange(ht), np.arange(wd)),
flow2_filled[:, :, 1],
(y1_f, x1_f),
method=method,
bounds_error=False,
fill_value=0,
)
new_flow_x = np.zeros_like(flow1[0])
new_flow_y = np.zeros_like(flow1[1])
new_flow_x[(y1, x1)] = flow_x + x1_f - x1
new_flow_y[(y1, x1)] = flow_y + y1_f - y1
new_flow = np.stack([new_flow_x, new_flow_y], axis=0)
new_valid = np.zeros_like(flow1[0])
new_valid[(y1, x1)] = valid1
new_flow = np.transpose(new_flow, axes=[1, 2, 0])
return new_flow, new_valid
def fill_invalid(flow, valid):
return fill_invalid_slow(flow, valid)
# Intended to be used before converting to torch.Tensor, slightly modification of forward_interpolate
def fill_invalid_slow(flow, valid):
flow = np.transpose(flow, axes=[2, 0, 1])
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
x1 = x0.copy()
y1 = y0.copy()
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid_flat = valid.reshape(-1)
mask = valid_flat > 0.5
x1 = x1[mask]
y1 = y1[mask]
dx = dx[mask]
dy = dy[mask]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method="nearest", fill_value=0
)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method="nearest", fill_value=0
)
flow = np.stack([flow_x, flow_y], axis=0)
flow = np.transpose(flow, axes=[1, 2, 0])
return flow
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
x1 = x0 + dx
y1 = y0 + dy
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
x1 = x1[valid]
y1 = y1[valid]
dx = dx[valid]
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method="nearest", fill_value=0
)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method="nearest", fill_value=0
)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(
torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
)
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode="bilinear"):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
def transform(T, p):
assert T.shape == (4, 4)
return np.einsum("H W j, i j -> H W i", p, T[:3, :3]) + T[:3, 3]
def from_homog(x):
return x[..., :-1] / x[..., [-1]]
def reproject(depth1, pose1, pose2, K1, K2):
H, W = depth1.shape
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
img_1_coords = np.stack((x, y, np.ones_like(x)), axis=-1).astype(np.float64)
cam1_coords = np.einsum(
"H W, H W j, i j -> H W i", depth1, img_1_coords, np.linalg.inv(K1)
)
rel_pose = np.linalg.inv(pose2) @ pose1
cam2_coords = transform(rel_pose, cam1_coords)
return from_homog(np.einsum("H W j, i j -> H W i", cam2_coords, K2))
def induced_flow(depth0, depth1, data):
H, W = depth0.shape
coords1 = reproject(depth0, data["T0"], data["T1"], data["K0"], data["K1"])
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
coords0 = np.stack([x, y], axis=-1)
flow_01 = coords1 - coords0
H, W = depth1.shape
coords1 = reproject(depth1, data["T1"], data["T0"], data["K1"], data["K0"])
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
coords0 = np.stack([x, y], axis=-1)
flow_10 = coords1 - coords0
return flow_01, flow_10
def check_cycle_consistency(flow_01, flow_10):
flow_01 = torch.from_numpy(flow_01).permute(2, 0, 1)[None]
flow_10 = torch.from_numpy(flow_10).permute(2, 0, 1)[None]
H, W = flow_01.shape[-2:]
coords = coords_grid(1, H, W, flow_01.device)
coords1 = coords + flow_01
flow_reprojected = bilinear_sampler(flow_10, coords1.permute(0, 2, 3, 1))
cycle = flow_reprojected + flow_01
cycle = torch.norm(cycle, dim=1)
mask = (cycle < 0.1 * min(H, W)).float()
return mask[0].numpy()