Spaces:
Runtime error
Runtime error
| import os | |
| from scipy.interpolate import griddata as interp_grid | |
| from tqdm import tqdm | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import torch | |
| from packaging import version as pver | |
| import torch.nn.functional as F | |
def trajectory_to_camera_poses_v1(traj, intrinsics, sample_n_frames, zc=1.0):
    '''
    Turn a 2D pixel trajectory into translation-only w2c camera poses.

    Every trajectory point is back-projected to camera space at its depth;
    the offset of point i from point 0 becomes the translation of pose i.
    The rotation part of every pose is the identity.

    traj: [n_frame, 2], array-like, pixel coordinates (x, y)
    intrinsics: 2D array, row 0 is read as (fx, fy, cx, cy)
    sample_n_frames: int, number of poses to build, at most n_frame
    zc: float or int (one depth shared by all points), or a per-point
        sequence (list / tuple / numpy array) of n_frame depths
    return: [max(sample_n_frames, 1), 4, 4], float32, pose 0 is the identity
    raises: TypeError if zc is neither a number nor a sequence,
            IndexError if sample_n_frames exceeds the trajectory length
    '''
    traj = np.asarray(traj)
    n_points = traj.shape[0]

    # Integer pixel coordinates must not force the depths to integers
    # (0.5 would silently become 0), so depths are always kept floating point.
    calc_dtype = traj.dtype if np.issubdtype(traj.dtype, np.floating) else np.float64
    if np.ndim(zc) == 0:
        if not isinstance(zc, (int, float, np.number, np.ndarray)):
            raise TypeError('zc should be a float or int or a list of float or int')
        depth = np.full(n_points, zc, dtype=calc_dtype)
    else:
        depth = np.asarray(zc, dtype=calc_dtype)

    fx, fy = intrinsics[0, 0], intrinsics[0, 1]
    cx, cy = intrinsics[0, 2], intrinsics[0, 3]

    # Pinhole back-projection of each pixel at its own depth.
    xc = (traj[:, 0] - cx) * depth / fx
    yc = (traj[:, 1] - cy) * depth / fy

    # Translation of every frame relative to the first trajectory point.
    offsets = np.stack((xc - xc[0], yc - yc[0], depth - depth[0]), axis=-1)

    n_poses = max(sample_n_frames, 1)  # the first (identity) pose is always returned
    if n_poses > n_points:
        raise IndexError(
            f'sample_n_frames ({sample_n_frames}) exceeds the trajectory length ({n_points})')

    traj_w2c = np.tile(np.eye(4, dtype=np.float32), (n_poses, 1, 1))
    traj_w2c[1:, :3, 3] = offsets[1:n_poses]
    return traj_w2c  # [n_frame, 4, 4]
def Unprojected(image_curr: np.ndarray,
                depth_curr: np.ndarray,
                RTs: np.ndarray,
                H: int = 320, W: int = 576,
                K: np.ndarray = None,
                dtype: np.dtype = np.float32):
    '''
    Warp the first frame into every later camera pose using its depth map.

    The pixels of image_curr are lifted to 3D with depth_curr and RTs[0],
    re-projected with each RTs[i] (i >= 1), and the scattered colors are
    resampled onto the regular H x W pixel grid by linear interpolation.
    Pixels that receive no color are filled with 0.

    image_curr: [H, W, c], float, 0-1
    depth_curr: [H, W], float32, in meters
    RTs: [num_frames, 3, 4], the camera poses, w2c
    H, W: int, image height and width in pixels
    K: [3, 3], camera intrinsics; required in practice, the None default
       fails inside np.linalg.inv
    dtype: dtype the poses, depth and colors are cast to
    return: list of num_frames - 1 arrays, each [H, W, c]; an empty list
            when RTs only holds the reference pose
    '''
    RTs = RTs.astype(dtype)
    depth_curr = depth_curr.astype(dtype)
    pts_colors = image_curr.reshape(H * W, -1).astype(dtype)  # [H*W, c], [0, 1]

    # Regular pixel grid: source pixel coordinates and interpolation targets.
    x, y = np.meshgrid(np.arange(W, dtype=dtype), np.arange(H, dtype=dtype), indexing='xy')  # pixels
    grid = np.stack((x, y), axis=-1).reshape(-1, 2)

    # Lift the reference pixels to camera space, then to world space.
    R0, T0 = RTs[0, :, :3], RTs[0, :, 3:4]
    R0_inv = np.linalg.inv(R0)
    pixels_scaled = np.stack((x * depth_curr, y * depth_curr, depth_curr), axis=0).reshape(3, -1)
    pts_coord_cam = np.matmul(np.linalg.inv(K), pixels_scaled)
    pts_coord_world = R0_inv.dot(pts_coord_cam) - R0_inv.dot(T0)

    target_frames = range(1, RTs.shape[0])
    if len(target_frames) == 0:
        # Only the reference pose was given: nothing to render, and no
        # first image whose shape could be reported.
        print('Total 0 images')
        return []

    images = []
    for i in tqdm(target_frames):
        R, T = RTs[i, :, :3], RTs[i, :, 3:4]

        ### Transform world to pixel
        pts_coord_cam2 = R.dot(pts_coord_world) + T
        pixel_coord_cam2 = np.matmul(K, pts_coord_cam2)

        # Perspective divide once; points with z == 0 give inf/nan here and
        # are rejected by the validity mask below.
        z = pixel_coord_cam2[2]
        with np.errstate(divide='ignore', invalid='ignore'):
            u = pixel_coord_cam2[0] / z
            v = pixel_coord_cam2[1] / z

        # Keep points in front of the camera that land inside the image.
        valid_idx = np.where((z > 0) & (u >= 0) & (u <= W - 1) & (v >= 0) & (v <= H - 1))[0]
        coords = np.stack((u[valid_idx], v[valid_idx]), axis=-1)  # [n_valid, 2]

        image2 = interp_grid(coords, pts_colors[valid_idx], grid,
                             method='linear', fill_value=0).reshape(H, W, -1)
        images.append(image2)

    print(f'Total {len(images)} images, each image shape: {images[0].shape}')
    return images
class Camera(object):
    '''
    Camera parameters parsed from a flat entry.

    entry layout, as read by the constructor:
        entry[1:5] -> fx, fy, cx, cy
        entry[7:]  -> 12 values, reshaped to the 3x4 w2c matrix
    The remaining positions of entry are not read.

    Attributes: fx, fy, cx, cy, w2c_mat ([4, 4]), c2w_mat ([4, 4], inverse of w2c_mat)
    '''

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]

        # Promote the 3x4 extrinsics to a homogeneous 4x4 matrix.
        extrinsic = np.eye(4)
        extrinsic[:3, :] = np.array(entry[7:]).reshape(3, 4)

        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)
def get_relative_pose(cam_params, zero_t_first_frame):
    '''
    Re-express camera-to-world poses relative to the first camera.

    cam_params: sequence of objects exposing w2c_mat and c2w_mat ([4, 4])
    zero_t_first_frame: if True the first pose is the identity; otherwise the
        first pose is translated along -y by the distance of the first camera
        from the world origin
    return: [n_frame, 4, 4], float32; pose 0 is the target pose, pose k is
            target @ w2c_0 @ c2w_k
    '''
    cams = list(cam_params)
    all_w2c = [cam.w2c_mat for cam in cams]
    all_c2w = [cam.c2w_mat for cam in cams]

    cam_to_origin = 0 if zero_t_first_frame else np.linalg.norm(all_c2w[0][:3, 3])

    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ])

    # Maps absolute c2w poses into the frame of the (re-positioned) first camera.
    abs2rel = target_cam_c2w @ all_w2c[0]

    ret_poses = [target_cam_c2w]
    for later_c2w in all_c2w[1:]:
        ret_poses.append(abs2rel @ later_c2w)
    return np.array(ret_poses, dtype=np.float32)
def custom_meshgrid(*args):
    '''
    torch.meshgrid wrapper that requests 'ij' indexing explicitly on
    torch >= 1.10 and calls torch.meshgrid without the keyword on older
    versions.

    args: tensors forwarded to torch.meshgrid
    return: whatever torch.meshgrid returns for those tensors
    '''
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    has_indexing_kwarg = pver.parse(torch.__version__) >= pver.parse('1.10')
    if has_indexing_kwarg:
        return torch.meshgrid(*args, indexing='ij')
    return torch.meshgrid(*args)
def ray_condition(K, c2w, H, W, device, flip_flag=None):
    '''
    Build the per-pixel Plucker ray embedding of every camera view.

    K: [B, V, 4], intrinsics (fx, fy, cx, cy) in pixels
    c2w: [B, V, 4, 4], camera-to-world poses
    H, W: int, image height and width
    device: device on which the pixel grid is created
    flip_flag: optional bool tensor [V]; flagged views use a horizontally
        mirrored pixel grid
    return: plucker [B, V, H, W, 6] holding (rays_o x rays_d, rays_d),
            rays_o [B, V, H*W, 3], rays_d [B, V, H*W, 3]
    '''
    B, V = K.shape[:2]

    rows = torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype)
    cols = torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype)

    # Row-major ('ij') pixel grid, built by broadcasting so it does not depend
    # on the torch.meshgrid signature; +0.5 moves to the pixel centers.
    j = rows[:, None].expand(H, W).reshape(1, 1, H * W).repeat(B, V, 1) + 0.5  # [B, V, HxW]
    i = cols[None, :].expand(H, W).reshape(1, 1, H * W).repeat(B, V, 1) + 0.5  # [B, V, HxW]

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        # Mirrored views only differ in their column coordinates.
        cols_mirrored = torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
        i_flip = cols_mirrored[None, :].expand(H, W).reshape(1, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B, V, 1

    # Camera-space directions through each pixel center at depth 1.
    xs = (i - cx) / fx
    ys = (j - cy) / fy
    zs = torch.ones_like(ys)  # [B, V, HxW]

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    # Rotate the directions into world space; every ray starts at the camera center.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3

    # dim must be explicit: without it torch.cross uses the first dimension of
    # size 3, which is the wrong axis whenever B, V or H*W equals 3.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker, rays_o, rays_d
def RT2Plucker(RT, num_frames, sample_size, fx, fy, cx, cy):
    '''
    Convert w2c camera poses into a Plucker embedding volume.

    RT: [num_frames, 3, 4]
    num_frames: int, number of poses read from RT
    sample_size: (H, W) of the embedding
    fx, fy, cx, cy: intrinsics shared by all frames; fx/cx are multiplied by W
        and fy/cy by H here, so they are presumably normalized — TODO confirm
    return: plucker_embedding [6, num_frames, H, W], plus rays_o and rays_d
            exactly as returned by ray_condition
    '''
    height, width = sample_size[0], sample_size[1]

    cam_params = [
        Camera([0, fx, fy, cx, cy, 0, 0, RT[frame_idx].reshape(-1)])
        for frame_idx in range(num_frames)
    ]
    print(cam_params[0].c2w_mat.shape)

    per_frame_intrinsics = [
        [cam.fx * width, cam.fy * height, cam.cx * width, cam.cy * height]
        for cam in cam_params
    ]
    intrinsics = torch.as_tensor(np.asarray(per_frame_intrinsics, dtype=np.float32))[None]
    print(intrinsics.shape)

    # Fixed configuration: poses relative to the first frame, which sits at the origin.
    relative_pose = True
    zero_t_first_frame = True
    if relative_pose:
        c2w_poses = get_relative_pose(cam_params, zero_t_first_frame)
    else:
        c2w_poses = np.array([cam.c2w_mat for cam in cam_params], dtype=np.float32)
    c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]

    # No view is mirrored.
    flip_flag = torch.zeros(num_frames, dtype=torch.bool, device=c2w.device)
    plucker_embedding, rays_o, rays_d = ray_condition(
        intrinsics, c2w, height, width, device='cpu', flip_flag=flip_flag)

    # [V, H, W, 6] -> [6, V, H, W]
    plucker_embedding = plucker_embedding[0].permute(3, 0, 1, 2).contiguous()
    return plucker_embedding, rays_o, rays_d
def _to_capped_pixel(point, W, H):
    # Round an (x, y) point to integer pixels, capping each coordinate at the last column/row.
    return min(int(round(point[0])), W - 1), min(int(round(point[1])), H - 1)


def visualized_trajectories(images, trajectories, save_path, save_each_frame=False):
    '''
    Draw growing trajectories on top of a frame sequence and save it as a GIF.

    On frame i every trajectory is drawn up to its point i: past segments as
    lines, the newest segment as an arrow, and the trajectory's own start
    point as a small marker.

    images: [n_frame, H, W, 3], numpy, 0-255
    trajectories: [[n_frame, 2]], list[numpy], x, y
    save_path: str, end with .gif
    save_each_frame: if True, every frame is also written as a PNG into a
        directory named after save_path without its extension
    '''
    H, W = images.shape[1], images.shape[2]
    n_frame = images.shape[0]

    # Integer pixel coordinates do not depend on the frame: compute them once.
    pixel_trajs = [[_to_capped_pixel(point, W, H) for point in traj] for traj in trajectories]

    pil_image = []
    for i in range(n_frame):
        image = cv2.UMat(images[i].astype(np.uint8))

        for pixels in pixel_trajs:
            line_data = pixels[:i + 1]
            if not line_data:
                continue  # empty trajectory: nothing to draw

            newest = len(line_data) - 1
            for j in range(1, len(line_data)):
                x0, y0 = line_data[j - 1]
                x1, y1 = line_data[j]
                if x0 == x1 and y0 == y1:
                    continue  # zero-length segment
                if j == newest:
                    # Keep the arrow head about 10 px long whatever the segment length.
                    line_length = np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
                    arrow_head_length = 10
                    tip_length = arrow_head_length / line_length
                    cv2.arrowedLine(image, (x0, y0), (x1, y1), (255, 0, 0), 6, tipLength=tip_length)
                else:
                    cv2.line(image, (x0, y0), (x1, y1), (255, 0, 0), 6)

            # Start marker of THIS trajectory, drawn last so it stays on top.
            cv2.circle(image, line_data[0], 1, (0, 255, 0), 8)

        pil_image.append(Image.fromarray(image.get()))

    pil_image[0].save(save_path, save_all=True, append_images=pil_image[1:], loop=0, duration=100)

    # save each frame
    if save_each_frame:
        # Strip only the extension; fall back to a suffix if there is none so
        # the directory cannot collide with the file written above.
        img_save_root = os.path.splitext(save_path)[0]
        if img_save_root == save_path:
            img_save_root = save_path + '_frames'
        os.makedirs(img_save_root, exist_ok=True)
        for frame_idx, img in enumerate(pil_image):
            img.save(os.path.join(img_save_root, f'{frame_idx:05d}.png'))
def roll_with_ignore_multidim(arr, shifts):
    '''
    Apply roll_with_ignore along several leading axes in turn.

    arr: numpy array
    shifts: iterable of shifts; entry k is the shift applied along axis k
    return: shifted array; arr itself is copied first and never modified
    '''
    shifted = np.copy(arr)
    for axis_index, offset in enumerate(shifts):
        shifted = roll_with_ignore(shifted, offset, axis_index)
    return shifted
def roll_with_ignore(arr, shift, axis):
    '''
    Shift arr along one axis, filling the vacated positions with zeros
    instead of wrapping around.

    arr: numpy array
    shift: int, positive moves content towards higher indices, negative
        towards lower indices
    axis: int, non-negative axis index
    return: a new array of the same shape and dtype; for shift == 0 (neither
            positive nor negative) arr itself is returned, not a copy
    '''
    def along_axis(axis_slice):
        # Full index tuple: axis_slice on the shifted axis, everything elsewhere.
        return tuple(axis_slice if dim == axis else slice(None) for dim in range(arr.ndim))

    if shift > 0:
        destination, source = slice(shift, None), slice(None, -shift)
    elif shift < 0:
        destination, source = slice(None, shift), slice(-shift, None)
    else:
        return arr

    shifted = np.zeros_like(arr)
    shifted[along_axis(destination)] = arr[along_axis(source)]
    return shifted
def dilate_mask_pytorch(mask, kernel_size=2):
    '''
    Binary dilation of a mask with a square all-ones kernel.

    mask: torch.Tensor, shape [b, 1, h, w] (the kernel has a single input channel)
    kernel_size: int, side length of the square kernel
    return: tensor with the dtype and device of mask, 1 where the kernel
            window covers any positive mass and 0 elsewhere. The padding is
            kernel_size // 2, so odd sizes keep [h, w] while even sizes
            (including the default 2) produce [h + 1, w + 1].
    '''
    # All-ones structuring element, created with the mask's dtype and device.
    structuring = mask.new_ones((1, 1, kernel_size, kernel_size))

    # The window sum is positive exactly where the window touches the mask.
    coverage = F.conv2d(mask, structuring, padding=kernel_size // 2)

    # Threshold back to a 0/1 mask.
    return (coverage > 0).to(mask.dtype).to(mask.device)
def smooth_mask(mask, kernel_size=3):
    '''
    Smooth a binary mask with a box (uniform average) filter and re-binarize it.

    mask: torch.Tensor, shape [b, 1, h, w] (the kernel has a single input channel)
    kernel_size: int, side length of the square averaging window
    return: tensor with the dtype and device of mask, 1 where the local mean
            exceeds 0.5 and 0 elsewhere. The padding is kernel_size // 2, so
            odd sizes keep [h, w]; even sizes produce [h + 1, w + 1].
    '''
    window_area = kernel_size * kernel_size

    # Uniform averaging kernel (every weight is 1 / window_area).
    box = mask.new_ones((1, 1, kernel_size, kernel_size)) / window_area

    # Local mean over each window; zero padding counts as background.
    local_mean = F.conv2d(mask, box, padding=kernel_size // 2)

    # Majority vote: keep pixels whose neighbourhood is more than half foreground.
    return (local_mean > 0.5).to(mask.dtype).to(mask.device)