Spaces:
Running
Running
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from scipy import interpolate | |
def load_ckpt(model, path): | |
"""Load checkpoint""" | |
state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)[ | |
"state_dict" | |
] | |
model.load_state_dict(state_dict, strict=True) | |
def load_ckpt_submission(model, path): | |
"""Load checkpoint""" | |
state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)[ | |
"state_dict" | |
] | |
state_dict = {k[6:]: v for k, v in state_dict.items()} | |
model.load_state_dict(state_dict, strict=True) | |
def resize_data(img1, img2, flow, factor=1.0): | |
_, _, h, w = img1.shape | |
h = int(h * factor) | |
w = int(w * factor) | |
img1 = F.interpolate(img1, (h, w), mode="area") | |
img2 = F.interpolate(img2, (h, w), mode="area") | |
flow = F.interpolate(flow, (h, w), mode="area") * factor | |
return img1, img2, flow | |
class InputPadder: | |
"""Pads images such that dimensions are divisible by 8""" | |
def __init__(self, dims, mode="sintel"): | |
self.ht, self.wd = dims[-2:] | |
pad_ht = (((self.ht // 16) + 1) * 16 - self.ht) % 16 | |
pad_wd = (((self.wd // 16) + 1) * 16 - self.wd) % 16 | |
self.mode = mode | |
if mode == "sintel": | |
self._pad = [ | |
pad_wd // 2, | |
pad_wd - pad_wd // 2, | |
pad_ht // 2, | |
pad_ht - pad_ht // 2, | |
0, | |
0, | |
] | |
elif mode == "downzero": | |
self._pad = [0, pad_wd, 0, pad_ht, 0, 0] | |
else: | |
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht, 0, 0] | |
def pad(self, input): | |
if self.mode == "downzero": | |
return F.pad(input, self._pad) | |
else: | |
return F.pad(input, self._pad, mode="replicate") | |
def unpad(self, x): | |
ht, wd = x.shape[-2:] | |
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] | |
return x[..., c[0] : c[1], c[2] : c[3]] | |
# Intended to be used before converting to torch.Tensor | |
def merge_flows(flow1, valid1, flow2, valid2, method="nearest"): | |
flow1 = np.transpose(flow1, axes=[2, 0, 1]) | |
_, ht, wd = flow1.shape | |
x1, y1 = np.meshgrid(np.arange(wd), np.arange(ht)) | |
x1_f = x1 + flow1[0] | |
y1_f = y1 + flow1[1] | |
x1 = x1.reshape(-1) | |
y1 = y1.reshape(-1) | |
x1_f = x1_f.reshape(-1) | |
y1_f = y1_f.reshape(-1) | |
valid1 = valid1.reshape(-1) | |
mask1 = ( | |
(valid1 > 0.5) & (x1_f >= 0) & (x1_f <= wd - 1) & (y1_f >= 0) & (y1_f <= ht - 1) | |
) | |
x1 = x1[mask1] | |
y1 = y1[mask1] | |
x1_f = x1_f[mask1] | |
y1_f = y1_f[mask1] | |
valid1 = valid1[mask1] | |
# STEP 1: interpolate valid values | |
new_valid1 = interpolate.interpn( | |
(np.arange(ht), np.arange(wd)), | |
valid2, | |
(y1_f, x1_f), | |
method=method, | |
bounds_error=False, | |
fill_value=0, | |
) | |
valid1 = new_valid1.round() | |
mask1 = valid1 > 0.5 | |
x1 = x1[mask1] | |
y1 = y1[mask1] | |
x1_f = x1_f[mask1] | |
y1_f = y1_f[mask1] | |
valid1 = valid1[mask1] | |
flow2_filled = fill_invalid(flow2, valid2) | |
# STEP 2: interpolate flow values | |
flow_x = interpolate.interpn( | |
(np.arange(ht), np.arange(wd)), | |
flow2_filled[:, :, 0], | |
(y1_f, x1_f), | |
method=method, | |
bounds_error=False, | |
fill_value=0, | |
) | |
flow_y = interpolate.interpn( | |
(np.arange(ht), np.arange(wd)), | |
flow2_filled[:, :, 1], | |
(y1_f, x1_f), | |
method=method, | |
bounds_error=False, | |
fill_value=0, | |
) | |
new_flow_x = np.zeros_like(flow1[0]) | |
new_flow_y = np.zeros_like(flow1[1]) | |
new_flow_x[(y1, x1)] = flow_x + x1_f - x1 | |
new_flow_y[(y1, x1)] = flow_y + y1_f - y1 | |
new_flow = np.stack([new_flow_x, new_flow_y], axis=0) | |
new_valid = np.zeros_like(flow1[0]) | |
new_valid[(y1, x1)] = valid1 | |
new_flow = np.transpose(new_flow, axes=[1, 2, 0]) | |
return new_flow, new_valid | |
def fill_invalid(flow, valid): | |
return fill_invalid_slow(flow, valid) | |
# Intended to be used before converting to torch.Tensor, slightly modification of forward_interpolate | |
def fill_invalid_slow(flow, valid): | |
flow = np.transpose(flow, axes=[2, 0, 1]) | |
dx, dy = flow[0], flow[1] | |
ht, wd = dx.shape | |
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) | |
x1 = x0.copy() | |
y1 = y0.copy() | |
x1 = x1.reshape(-1) | |
y1 = y1.reshape(-1) | |
dx = dx.reshape(-1) | |
dy = dy.reshape(-1) | |
valid_flat = valid.reshape(-1) | |
mask = valid_flat > 0.5 | |
x1 = x1[mask] | |
y1 = y1[mask] | |
dx = dx[mask] | |
dy = dy[mask] | |
flow_x = interpolate.griddata( | |
(x1, y1), dx, (x0, y0), method="nearest", fill_value=0 | |
) | |
flow_y = interpolate.griddata( | |
(x1, y1), dy, (x0, y0), method="nearest", fill_value=0 | |
) | |
flow = np.stack([flow_x, flow_y], axis=0) | |
flow = np.transpose(flow, axes=[1, 2, 0]) | |
return flow | |
def forward_interpolate(flow): | |
flow = flow.detach().cpu().numpy() | |
dx, dy = flow[0], flow[1] | |
ht, wd = dx.shape | |
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) | |
x1 = x0 + dx | |
y1 = y0 + dy | |
x1 = x1.reshape(-1) | |
y1 = y1.reshape(-1) | |
dx = dx.reshape(-1) | |
dy = dy.reshape(-1) | |
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) | |
x1 = x1[valid] | |
y1 = y1[valid] | |
dx = dx[valid] | |
dy = dy[valid] | |
flow_x = interpolate.griddata( | |
(x1, y1), dx, (x0, y0), method="nearest", fill_value=0 | |
) | |
flow_y = interpolate.griddata( | |
(x1, y1), dy, (x0, y0), method="nearest", fill_value=0 | |
) | |
flow = np.stack([flow_x, flow_y], axis=0) | |
return torch.from_numpy(flow).float() | |
def bilinear_sampler(img, coords, mode="bilinear", mask=False): | |
"""Wrapper for grid_sample, uses pixel coordinates""" | |
H, W = img.shape[-2:] | |
xgrid, ygrid = coords.split([1, 1], dim=-1) | |
xgrid = 2 * xgrid / (W - 1) - 1 | |
ygrid = 2 * ygrid / (H - 1) - 1 | |
grid = torch.cat([xgrid, ygrid], dim=-1) | |
img = F.grid_sample(img, grid, align_corners=True) | |
if mask: | |
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) | |
return img, mask.float() | |
return img | |
def coords_grid(batch, ht, wd, device): | |
coords = torch.meshgrid( | |
torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij" | |
) | |
coords = torch.stack(coords[::-1], dim=0).float() | |
return coords[None].repeat(batch, 1, 1, 1) | |
def upflow8(flow, mode="bilinear"): | |
new_size = (8 * flow.shape[2], 8 * flow.shape[3]) | |
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) | |
def transform(T, p): | |
assert T.shape == (4, 4) | |
return np.einsum("H W j, i j -> H W i", p, T[:3, :3]) + T[:3, 3] | |
def from_homog(x): | |
return x[..., :-1] / x[..., [-1]] | |
def reproject(depth1, pose1, pose2, K1, K2): | |
H, W = depth1.shape | |
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy") | |
img_1_coords = np.stack((x, y, np.ones_like(x)), axis=-1).astype(np.float64) | |
cam1_coords = np.einsum( | |
"H W, H W j, i j -> H W i", depth1, img_1_coords, np.linalg.inv(K1) | |
) | |
rel_pose = np.linalg.inv(pose2) @ pose1 | |
cam2_coords = transform(rel_pose, cam1_coords) | |
return from_homog(np.einsum("H W j, i j -> H W i", cam2_coords, K2)) | |
def induced_flow(depth0, depth1, data): | |
H, W = depth0.shape | |
coords1 = reproject(depth0, data["T0"], data["T1"], data["K0"], data["K1"]) | |
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy") | |
coords0 = np.stack([x, y], axis=-1) | |
flow_01 = coords1 - coords0 | |
H, W = depth1.shape | |
coords1 = reproject(depth1, data["T1"], data["T0"], data["K1"], data["K0"]) | |
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy") | |
coords0 = np.stack([x, y], axis=-1) | |
flow_10 = coords1 - coords0 | |
return flow_01, flow_10 | |
def check_cycle_consistency(flow_01, flow_10): | |
flow_01 = torch.from_numpy(flow_01).permute(2, 0, 1)[None] | |
flow_10 = torch.from_numpy(flow_10).permute(2, 0, 1)[None] | |
H, W = flow_01.shape[-2:] | |
coords = coords_grid(1, H, W, flow_01.device) | |
coords1 = coords + flow_01 | |
flow_reprojected = bilinear_sampler(flow_10, coords1.permute(0, 2, 3, 1)) | |
cycle = flow_reprojected + flow_01 | |
cycle = torch.norm(cycle, dim=1) | |
mask = (cycle < 0.1 * min(H, W)).float() | |
return mask[0].numpy() | |