Spaces:
Configuration error
Configuration error
| import trimesh | |
| import torch | |
| import numpy as np | |
| import os | |
| import math | |
| import torchvision | |
| from tqdm import tqdm | |
| import cv2 # Assuming OpenCV is used for image saving | |
| from PIL import Image | |
| import pytorch3d | |
| import random | |
| from PIL import ImageGrab | |
| torchvision | |
| from torchvision.utils import save_image | |
| from pytorch3d.renderer import ( | |
| PointsRasterizationSettings, | |
| PointsRenderer, | |
| PointsRasterizer, | |
| AlphaCompositor, | |
| PerspectiveCameras, | |
| ) | |
| import imageio | |
| import torch.nn.functional as F | |
| from torchvision.transforms import ToPILImage | |
| import copy | |
| from scipy.interpolate import interp1d | |
| from scipy.interpolate import UnivariateSpline | |
| from scipy.spatial.transform import Rotation as R | |
| from scipy.spatial.transform import Slerp | |
| import sys | |
| sys.path.append('./extern/dust3r') | |
| from dust3r.utils.device import to_numpy | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as mcolors | |
def save_video(data, images_path, folder=None, fps=8):
    """Encode a sequence of frames as an H.264 video with torchvision.

    Args:
        data: One of
            * ``np.ndarray`` or ``torch.Tensor`` of float frames. Values are
              multiplied by 255 and cast to uint8, so they are presumably in
              [0, 1]. Out-of-range values are clamped instead of wrapping
              around. ``torchvision.io.write_video`` expects (T, H, W, C).
            * ``list`` of image file names, read with PIL and stacked.
        images_path: Output video path.
        folder: Directory the names in a list ``data`` are relative to.
            When ``None`` the names are used as given.
        fps: Frame rate of the written video (default 8).

    Raises:
        TypeError: If ``data`` is not an ndarray, tensor or list.
    """
    if isinstance(data, (np.ndarray, torch.Tensor)):
        if isinstance(data, np.ndarray):
            frames = torch.from_numpy(data)
        else:
            frames = data.detach().cpu()
        # Clamp before the cast: values slightly outside [0, 1] would
        # otherwise wrap around in uint8 (e.g. 256 -> 0).
        tensor_data = (frames * 255).clamp(0, 255).to(torch.uint8)
    elif isinstance(data, list):
        paths = [name if folder is None else os.path.join(folder, name) for name in data]
        images = []
        for path in paths:
            # Context manager closes the file handle once pixels are loaded.
            with Image.open(path) as img:
                images.append(np.array(img))
        tensor_data = torch.from_numpy(np.stack(images, axis=0)).to(torch.uint8)
    else:
        raise TypeError(f"save_video: unsupported data type {type(data).__name__}")
    torchvision.io.write_video(images_path, tensor_data, fps=fps, video_codec='h264', options={'crf': '10'})
def get_input_dict(img_tensor, idx, dtype=torch.float32, size=(288, 512)):
    """Build the per-view input dict (img / true_shape / idx / instance / img_ori).

    The key layout presumably matches what the DUSt3R inference code imported
    by this file expects. TODO confirm against the caller.

    Args:
        img_tensor: Image batch accepted by ``F.interpolate`` in bilinear mode,
            i.e. 4-D, presumably (1, C, H, W). TODO confirm.
        idx: View index, stored as-is under 'idx' and as a string under 'instance'.
        dtype: Floating dtype both image entries are cast to.
        size: (height, width) the 'img' entry is resized to. Defaults to the
            previously hard-coded (288, 512).

    Returns:
        dict with the resized image, its shape as an int32 (1, 2) array,
        the index and the un-resized (but cast) image.
    """
    img = img_tensor.to(dtype)  # cast once, shared by both image entries
    height, width = size
    return {
        'img': F.interpolate(img, size=(height, width), mode='bilinear', align_corners=False),
        'true_shape': np.array([[height, width]], dtype=np.int32),
        'idx': idx,
        'instance': str(idx),
        'img_ori': img,
    }
def rotate_theta(c2ws_input, theta, phi, r, device):
    """Rotate camera-to-world poses by ``phi`` degrees about a tilted axis.

    The axis is v = (0, cos(theta), sin(theta)). It lies in the world y-z
    plane at ``theta`` degrees from the y axis (the author describes theta as
    the image tilt). The rotation is performed about the world point
    (0, 0, -r). Poses are shifted by +r along z, rotated about the origin
    with Rodrigues' formula, and shifted back by -r.

    Args:
        c2ws_input: Batch of 4x4 camera-to-world matrices, presumably
            (N, 4, 4). It is deep-copied, not modified.
        theta: Axis tilt in degrees.
        phi: Rotation angle in degrees.
        r: z offset defining the rotation centre (0, 0, -r).
        device: Device the rotation matrix is built on. It must match the
            device of ``c2ws_input`` for the matmul.

    Returns:
        Rotated poses with the same shape and dtype as ``c2ws_input``.
    """
    c2ws = copy.deepcopy(c2ws_input)
    dtype = c2ws.dtype  # build the rotation in the poses' dtype so float64 works too
    c2ws[:, 2, 3] = c2ws[:, 2, 3] + r

    theta = torch.deg2rad(torch.as_tensor(theta, dtype=dtype, device=device))
    phi = torch.deg2rad(torch.as_tensor(phi, dtype=dtype, device=device))
    zero = torch.zeros((), dtype=dtype, device=device)
    axis = torch.stack([zero, torch.cos(theta), torch.sin(theta)])  # unit length

    # Skew-symmetric cross-product matrix K of the axis.
    skew = torch.stack([
        torch.stack([zero, -axis[2], axis[1]]),
        torch.stack([axis[2], zero, -axis[0]]),
        torch.stack([-axis[1], axis[0], zero]),
    ])
    # Rodrigues: R = I + sin(phi) K + (1 - cos(phi)) K^2
    rot = (torch.eye(3, dtype=dtype, device=device)
           + torch.sin(phi) * skew
           + (1 - torch.cos(phi)) * (skew @ skew))

    rot_h = torch.eye(4, dtype=dtype, device=device)
    rot_h[:3, :3] = rot
    c2ws = torch.matmul(rot_h, c2ws)  # (4, 4) @ (N, 4, 4) broadcasts over the batch
    c2ws[:, 2, 3] = c2ws[:, 2, 3] - r  # undo the shift: rotation centre is (0, 0, -r)
    return c2ws
def sphere2pose(c2ws_input, theta, phi, r, device):
    """Orbit camera-to-world poses: shift along world z, then rotate about x and y.

    Computes ``R_y(phi) @ R_x(theta) @ T_z(r) @ c2ws``. ``T_z(r)`` adds ``r``
    to the z translation, following the author's note: translate along the
    world z axis first, then rotate.

    Args:
        c2ws_input: Batch of 4x4 camera-to-world matrices, presumably
            (N, 4, 4). It is deep-copied, not modified.
        theta: Rotation about the world x axis, in degrees.
        phi: Rotation about the world y axis, in degrees.
        r: Offset added to the z translation before rotating. ``None`` means
            no offset.
        device: Device the rotation matrices are built on. It must match the
            device of ``c2ws_input`` for the matmul.

    Returns:
        Transformed poses with the same shape and dtype as ``c2ws_input``.
    """
    c2ws = copy.deepcopy(c2ws_input)
    if r is not None:
        c2ws[:, 2, 3] += r

    dtype = c2ws.dtype  # rotations follow the poses' dtype (float64 poses work too)
    zero = torch.zeros((), dtype=dtype, device=device)
    one = torch.ones((), dtype=dtype, device=device)

    theta = torch.deg2rad(torch.as_tensor(theta, dtype=dtype, device=device))
    sin_x, cos_x = torch.sin(theta), torch.cos(theta)
    rot_mat_x = torch.stack([
        torch.stack([one, zero, zero, zero]),
        torch.stack([zero, cos_x, -sin_x, zero]),
        torch.stack([zero, sin_x, cos_x, zero]),
        torch.stack([zero, zero, zero, one]),
    ])

    phi = torch.deg2rad(torch.as_tensor(phi, dtype=dtype, device=device))
    sin_y, cos_y = torch.sin(phi), torch.cos(phi)
    rot_mat_y = torch.stack([
        torch.stack([cos_y, zero, sin_y, zero]),
        torch.stack([zero, one, zero, zero]),
        torch.stack([-sin_y, zero, cos_y, zero]),
        torch.stack([zero, zero, zero, one]),
    ])

    # A (4, 4) matrix broadcasts over the (N, 4, 4) batch.
    c2ws = torch.matmul(rot_mat_x, c2ws)
    c2ws = torch.matmul(rot_mat_y, c2ws)
    return c2ws
def generate_candidate_poses(c2ws_anchor, H, W, fs, c, theta, phi, num_candidates, device):
    """Build PyTorch3D cameras for the next-best-view (NBV) candidate poses.

    Every candidate is the anchor orbited by ``phi`` degrees about y (see
    ``sphere2pose``). The x-rotations differ: 2 candidates use
    [0, -theta]; 3 candidates use [0, -theta, theta / 2]. Per the author's
    note, the halved angle avoids too many downward views.

    The camera coordinate system in COLMAP (and the DUSt3R poses used here)
    is right-down-forward; PyTorch3D's is left-up-forward.

    Args:
        c2ws_anchor: Anchor camera-to-world pose(s), presumably (1, 4, 4).
        H, W: Image height and width in pixels.
        fs, c: Focal length(s) and principal point, in pixels (in_ndc=False).
        theta, phi: Angles in degrees.
        num_candidates: 2 or 3.
        device: Torch device for the cameras.

    Returns:
        (cameras, thetas, phis): a batched ``PerspectiveCameras`` and the
        numpy angle arrays used for each candidate.

    Raises:
        ValueError: If ``num_candidates`` is not 2 or 3.
    """
    if num_candidates == 2:
        thetas = np.array([0, -theta])
        phis = np.array([phi, phi])
    elif num_candidates == 3:
        thetas = np.array([0, -theta, theta / 2.])  # avoid too many downward
        phis = np.array([phi, phi, phi])
    else:
        raise ValueError("NBV mode only supports 2 or 3 candidates per iteration.")

    # r=0.0 means "no radial offset". Passing None would reach `+= None`
    # inside sphere2pose and raise TypeError.
    c2ws = torch.cat(
        [sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), r=0.0, device=device)
         for th, ph in zip(thetas, phis)],
        dim=0,
    )

    rot, trans = c2ws[:, :3, :3], c2ws[:, :3, 3:]
    # DUSt3R (RDF) -> PyTorch3D (LUF): negate the camera x and y axes.
    rot = torch.stack([-rot[:, :, 0], -rot[:, :, 1], rot[:, :, 2]], 2)
    new_c2w = torch.cat([rot, trans], 2)
    bottom = new_c2w.new_tensor([[[0, 0, 0, 1]]]).repeat(new_c2w.shape[0], 1, 1)
    w2c = torch.linalg.inv(torch.cat((new_c2w, bottom), 1))
    # PyTorch3D uses the row-vector convention, hence the transpose of R.
    R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
    image_size = ((H, W),)  # (h, w)
    cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False,
                                 image_size=image_size, R=R_new, T=T_new, device=device)
    return cameras, thetas, phis
def generate_traj_specified(c2ws_anchor, H, W, fs, c, theta, phi, d_r, frame, device):
    """Build PyTorch3D cameras along a linear orbit away from the anchor pose.

    ``frame`` poses are produced with theta in linspace(0, theta), phi in
    linspace(0, phi), and a z offset in linspace(0, d_r * anchor_z), where
    anchor_z is the z translation of ``c2ws_anchor[0]``. Each pose comes from
    ``sphere2pose``.

    The camera coordinate system in COLMAP (and the DUSt3R poses used here)
    is right-down-forward; PyTorch3D's is left-up-forward.

    Args:
        c2ws_anchor: Anchor camera-to-world pose(s), presumably (1, 4, 4).
            It is moved to ``device`` (as ``generate_traj_txt`` does).
        H, W: Image height and width in pixels.
        fs, c: Focal length(s) and principal point, in pixels (in_ndc=False).
        theta, phi: Final angles in degrees.
        d_r: Final radial offset as a multiple of the anchor's z translation.
        frame: Number of poses.
        device: Torch device.

    Returns:
        (cameras, num_views)
    """
    c2ws_anchor = c2ws_anchor.to(device)  # sphere2pose builds its rotations on `device`
    thetas = np.linspace(0, theta, frame)
    phis = np.linspace(0, phi, frame)
    rs = np.linspace(0, d_r * c2ws_anchor[0, 2, 3].item(), frame)

    c2ws = torch.cat(
        [sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), np.float32(rad), device)
         for th, ph, rad in zip(thetas, phis, rs)],
        dim=0,
    )
    num_views = c2ws.shape[0]

    rot, trans = c2ws[:, :3, :3], c2ws[:, :3, 3:]
    # DUSt3R (RDF) -> PyTorch3D (LUF): negate the camera x and y axes.
    rot = torch.stack([-rot[:, :, 0], -rot[:, :, 1], rot[:, :, 2]], 2)
    new_c2w = torch.cat([rot, trans], 2)
    bottom = new_c2w.new_tensor([[[0, 0, 0, 1]]]).repeat(new_c2w.shape[0], 1, 1)
    w2c = torch.linalg.inv(torch.cat((new_c2w, bottom), 1))
    # PyTorch3D uses the row-vector convention, hence the transpose of R.
    R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
    image_size = ((H, W),)  # (h, w)
    cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False,
                                 image_size=image_size, R=R_new, T=T_new, device=device)
    return cameras, num_views
def _interp_keyframes(keyframes, frame):
    """Resample a keyframe list to ``frame`` values (helper for generate_traj_txt).

    More than three keyframes go through the 'smooth' (cubic
    UnivariateSpline) mode of ``txt_interpolation``, which needs at least
    four points. That spline uses its default smoothing factor, so it is not
    guaranteed to hit the data; the first and last values are therefore
    pinned back to the given endpoints. Otherwise linear interpolation is used.
    """
    if len(keyframes) > 3:
        values = txt_interpolation(keyframes, frame, mode='smooth')
        values[0] = keyframes[0]
        values[-1] = keyframes[-1]
        return values
    return txt_interpolation(keyframes, frame, mode='linear')


def generate_traj_txt(c2ws_anchor, H, W, fs, c, phi, theta, r, frame, device, viz_traj=False, save_dir=None):
    """Build PyTorch3D cameras along a trajectory given as keyframe lists.

    ``phi``, ``theta`` and ``r`` are keyframe lists, each resampled to
    ``frame`` values. ``r`` is scaled by the anchor camera's z translation.
    Each pose comes from ``sphere2pose``.

    The camera coordinate system in COLMAP (and the DUSt3R poses used here)
    is right-down-forward; PyTorch3D's is left-up-forward.

    Args:
        c2ws_anchor: Anchor camera-to-world pose(s), presumably (1, 4, 4).
        H, W: Image height and width in pixels.
        fs, c: Focal length(s) and principal point, in pixels (in_ndc=False).
        phi, theta: Angle keyframes in degrees.
        r: Radial-offset keyframes, as multiples of the anchor's z translation.
        frame: Number of output poses.
        device: Torch device.
        viz_traj: If True, render the trajectory to ``save_dir/viz_traj.mp4``.
        save_dir: Output directory. It is required when ``viz_traj`` is True.

    Returns:
        (cameras, num_views)

    Raises:
        ValueError: If ``viz_traj`` is True but ``save_dir`` is None.
    """
    if viz_traj and save_dir is None:
        raise ValueError("generate_traj_txt: save_dir is required when viz_traj=True")
    c2ws_anchor = c2ws_anchor.to(device)

    phis = _interp_keyframes(phi, frame)
    thetas = _interp_keyframes(theta, frame)
    rs = _interp_keyframes(r, frame) * c2ws_anchor[0, 2, 3].cpu().numpy()

    c2ws = torch.cat(
        [sphere2pose(c2ws_anchor, np.float32(th), np.float32(ph), np.float32(rad), device)
         for th, ph, rad in zip(thetas, phis, rs)],
        dim=0,
    )

    if viz_traj:
        poses = c2ws.cpu().numpy()
        frames = [visualizer_frame(poses, i) for i in range(len(poses))]
        save_video(np.array(frames) / 255., os.path.join(save_dir, 'viz_traj.mp4'))

    num_views = c2ws.shape[0]
    rot, trans = c2ws[:, :3, :3], c2ws[:, :3, 3:]
    # DUSt3R (RDF) -> PyTorch3D (LUF): negate the camera x and y axes.
    rot = torch.stack([-rot[:, :, 0], -rot[:, :, 1], rot[:, :, 2]], 2)
    new_c2w = torch.cat([rot, trans], 2)
    bottom = new_c2w.new_tensor([[[0, 0, 0, 1]]]).repeat(new_c2w.shape[0], 1, 1)
    w2c = torch.linalg.inv(torch.cat((new_c2w, bottom), 1))
    # PyTorch3D uses the row-vector convention, hence the transpose of R.
    R_new, T_new = w2c[:, :3, :3].permute(0, 2, 1), w2c[:, :3, 3]
    image_size = ((H, W),)  # (h, w)
    cameras = PerspectiveCameras(focal_length=fs, principal_point=c, in_ndc=False,
                                 image_size=image_size, R=R_new, T=T_new, device=device)
    return cameras, num_views
def setup_renderer(cameras, image_size, *, radius=0.01, points_per_pixel=10, bin_size=0):
    """Create a PyTorch3D point-cloud renderer (rasterizer + alpha compositor).

    The keyword defaults reproduce the previously hard-coded settings.

    Args:
        cameras: PyTorch3D cameras to render from.
        image_size: Output image size passed to ``PointsRasterizationSettings``.
        radius: Point splat radius (PyTorch3D measures it in NDC units).
        points_per_pixel: Number of nearest points kept per pixel for compositing.
        bin_size: Rasterizer bin size; 0 selects PyTorch3D's naive
            (non-binned) rasterization.

    Returns:
        dict with keys 'cameras', 'raster_settings' and 'renderer'.
    """
    raster_settings = PointsRasterizationSettings(
        image_size=image_size,
        radius=radius,
        points_per_pixel=points_per_pixel,
        bin_size=bin_size,
    )
    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
        compositor=AlphaCompositor(),
    )
    return {'cameras': cameras, 'raster_settings': raster_settings, 'renderer': renderer}
def interpolate_sequence(sequence, k, device=None):
    """Linearly upsample the rows of a 2-D tensor.

    Each consecutive pair of rows (a, b) becomes the k rows
    (1 - j/k) * a + (j/k) * b for j = 0..k-1. The final row is appended
    once, giving (N - 1) * k + 1 rows. Knots are no longer duplicated:
    t = 1 of one segment equals t = 0 of the next, so it is not emitted.

    Args:
        sequence: (N, M) tensor.
        k: Output rows per segment (>= 1); k == 1 returns the input rows.
        device: Device for the interpolation weights. Defaults to
            ``sequence.device``.

    Returns:
        ((N - 1) * k + 1, M) tensor.
    """
    if device is None:
        device = sequence.device
    _, num_cols = sequence.shape
    # Drop t=1: it would duplicate the next segment's first row.
    weights = torch.linspace(0, 1, k + 1, device=device)[:-1].view(1, -1, 1)
    left = sequence[:-1].unsqueeze(1)    # (N-1, 1, M)
    right = sequence[1:].unsqueeze(1)    # (N-1, 1, M)
    segments = (1 - weights) * left + weights * right  # (N-1, k, M) by broadcasting
    return torch.cat([segments.reshape(-1, num_cols), sequence[-1:]], dim=0)
def focus_point_fn(c2ws: torch.Tensor) -> torch.Tensor:
    """Return the least-squares point nearest to every camera's optical axis.

    Camera i's axis is the line through its origin ``c2ws[i, :3, 3]`` along
    its third rotation column ``c2ws[i, :3, 2]``. The directions are assumed
    to be unit length. The result solves
    mean(P_i^T P_i) x = mean(P_i^T P_i o_i), with P_i = I - d_i d_i^T.
    """
    origins = c2ws[:, :3, 3:4]           # (N, 3, 1)
    axes = c2ws[:, :3, 2:3]              # (N, 3, 1)
    identity = torch.eye(3).to(c2ws.device)
    # Outer product d d^T; projector onto the plane orthogonal to each axis.
    projectors = identity - axes @ axes.transpose(1, 2)
    normal_mats = projectors.transpose(1, 2) @ projectors
    lhs = normal_mats.mean(0)
    rhs = (normal_mats @ origins).mean(0)[:, 0]
    return torch.inverse(lhs) @ rhs
def generate_camera_path(c2ws: torch.Tensor, n_inserts: int = 15, device='cuda') -> torch.Tensor:
    """Densify a pose sequence by interpolating between consecutive poses.

    Each neighbouring pair is interpolated with ``interpolate_poses``, arcing
    around the pair's common focus point (``focus_point_fn``). A segment keeps
    its start pose plus ``n_inserts`` in-between poses; its end pose is dropped
    because it is the next segment's start. The final input pose is appended
    once at the end.

    Returns:
        ((N - 1) * (n_inserts + 1) + 1, 4, 4) tensor for N input poses.
    """
    segments = []
    for start_pose, end_pose in zip(c2ws[:-1], c2ws[1:]):
        focus = focus_point_fn(torch.stack([start_pose, end_pose]))
        arc = interpolate_poses(start_pose, end_pose, focus, n_inserts, device)
        segments.append(arc[:-1])  # end pose == next segment's start
    segments.append(c2ws[-1:])
    return torch.cat(segments, dim=0)
def interpolate_poses(start_pose: torch.Tensor, end_pose: torch.Tensor, focus_point: torch.Tensor, n_inserts: int = 15, device='cuda') -> torch.Tensor:
    """Interpolate between two camera-to-world poses, arcing around a focus point.

    Rotations are slerped (scipy ``Slerp``). Positions are linearly
    interpolated, then pushed radially from ``focus_point`` to a distance
    linearly interpolated between the start and end distances.

    Args:
        start_pose, end_pose: 4x4 camera-to-world matrices.
        focus_point: 3-vector the path arcs around.
        n_inserts: Number of poses strictly between the endpoints.
        device: Device of the returned poses.

    Returns:
        (n_inserts + 2, 4, 4) tensor in ``start_pose.dtype``. Both endpoints
        are included.
    """
    dtype = start_pose.dtype
    start_distance = torch.linalg.norm(start_pose[:3, 3] - focus_point[:3])
    end_distance = torch.linalg.norm(end_pose[:3, 3] - focus_point[:3])

    key_rots = R.from_matrix(np.stack([
        start_pose[:3, :3].detach().cpu().numpy(),
        end_pose[:3, :3].detach().cpu().numpy(),
    ]))
    slerp_obj = Slerp([0, 1], key_rots)

    inserted_c2ws = []
    # Includes t=0 and t=1, so the endpoints are part of the returned path.
    for t in torch.linspace(0., 1., n_inserts + 2, dtype=dtype):
        interpolated_rot = slerp_obj(float(t)).as_matrix()
        lerp_translation = (1 - t) * start_pose[:3, 3] + t * end_pose[:3, 3]
        distance = (1 - t) * start_distance + t * end_distance
        offset = lerp_translation - focus_point
        offset_norm = torch.linalg.norm(offset)
        if offset_norm > 0:
            translation = focus_point + offset / offset_norm * distance
        else:
            # The straight path crosses the focus point, so the radial
            # direction is undefined. Keep the linear position instead of
            # dividing by zero (NaN).
            translation = lerp_translation
        inserted_pose = torch.eye(4, dtype=dtype, device=device)
        inserted_pose[:3, :3] = torch.from_numpy(interpolated_rot).to(device)
        inserted_pose[:3, 3] = translation
        inserted_c2ws.append(inserted_pose)
    return torch.stack(inserted_c2ws)
def inv(mat):
    """Return the inverse of a torch or numpy matrix, in the same library.

    Torch tensors use ``torch.linalg.inv`` and numpy arrays use
    ``np.linalg.inv``.

    Raises:
        ValueError: For any other input type.
    """
    inverters = (
        (torch.Tensor, torch.linalg.inv),
        (np.ndarray, np.linalg.inv),
    )
    for matrix_type, invert in inverters:
        if isinstance(mat, matrix_type):
            return invert(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
def save_pointcloud_with_normals(imgs, pts3d, msk, save_path, mask_pc, reduce_pc):
    """Write the colored point cloud from ``get_pc`` to an ASCII PLY file.

    Every vertex gets the same placeholder normal (0, 1, 0). Only the first
    three channels of ``PointCloud.colors`` are written, as uchar RGB.
    """
    cloud = get_pc(imgs, pts3d, msk, mask_pc, reduce_pc)
    vertices = cloud.vertices
    colors = cloud.colors
    placeholder_normal = [0, 1, 0]
    normals = np.tile(placeholder_normal, (vertices.shape[0], 1))

    header_lines = (
        "ply",
        "format ascii 1.0",
        "element vertex {}".format(len(vertices)),
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "property float nx",
        "property float ny",
        "property float nz",
        "end_header",
    )
    header = "\n".join(header_lines) + "\n"
    row_template = '{} {} {} {} {} {} {} {} {}\n'

    with open(save_path, 'w') as ply_file:
        ply_file.write(header)
        ply_file.writelines(
            row_template.format(
                vert[0], vert[1], vert[2],
                int(rgb[0]), int(rgb[1]), int(rgb[2]),
                nrm[0], nrm[1], nrm[2],
            )
            for vert, rgb, nrm in zip(vertices, colors, normals)
        )
def get_pc(imgs, pts3d, mask, mask_pc=False, reduce_pc=False):
    """Flatten per-view point maps and colors into a ``trimesh.PointCloud``.

    Args:
        imgs, pts3d, mask: Per-view colors, 3D points and masks. Each is
            converted with dust3r's ``to_numpy``. Presumably (H, W, 3) colors
            and points and (H, W) boolean masks per view. TODO confirm.
        mask_pc: If True, keep only the points where the mask is set.
        reduce_pc: If True, keep every third point (and color).

    Returns:
        The point cloud. A placeholder normal array (all (0, 1, 0)) is
        attached as the ad-hoc attribute ``vertices_normal``; nothing in this
        file reads it.
    """
    imgs, pts3d, mask = to_numpy(imgs), to_numpy(pts3d), to_numpy(mask)

    if mask_pc:
        point_chunks = [pts[m] for pts, m in zip(pts3d, mask)]
        color_chunks = [rgb[m] for rgb, m in zip(imgs, mask)]
    else:
        point_chunks = list(pts3d)
        color_chunks = list(imgs)

    stride = 3 if reduce_pc else 1
    points = np.concatenate(point_chunks).reshape(-1, 3)[::stride]
    colors = np.concatenate(color_chunks).reshape(-1, 3)[::stride]

    cloud = trimesh.PointCloud(points, colors=colors)
    cloud.vertices_normal = np.tile([0, 1, 0], (points.shape[0], 1))
    return cloud
def world_to_kth(poses, k):
    """Re-express all poses in the camera frame of ``poses[k]``.

    Computes ``new[i] = inv(poses[k]) @ poses[i]``, so ``new[k]`` is the
    identity.

    Args:
        poses: (N, 4, 4) batch of pose matrices.
        k: Index of the reference pose.
    """
    reference_inverse = torch.inverse(poses[k])
    batched_inverse = reference_inverse.expand(poses.shape[0], -1, -1)
    return torch.bmm(batched_inverse, poses)
def world_point_to_kth(poses, points, k, device):
    """Express poses and a point map in the camera frame of pose ``k``.

    Both are left-multiplied by ``inv(poses[k])``, the k-th world-to-camera
    transform, so ``poses[k]`` becomes the identity.

    Args:
        poses: (N, 4, 4) camera-to-world matrices.
        points: Points with xyz on the last axis (the original callers pass
            (N, ·, ·, 3)). Non-contiguous tensors are accepted.
        k: Index of the reference pose.
        device: Unused; kept for backward compatibility. Results follow the
            inputs' device.

    Returns:
        (new_poses, new_points) with the same shapes as the inputs.
    """
    inv_kth_pose = torch.inverse(poses[k])
    new_poses = torch.bmm(inv_kth_pose.unsqueeze(0).expand_as(poses), poses)
    # The xyz part of the homogeneous product is x' = R x + t. Applying it
    # directly on the trailing axis avoids .view() (which fails on
    # non-contiguous input) and the per-point (4, 1) homogeneous temporaries.
    rot, trans = inv_kth_pose[:3, :3], inv_kth_pose[:3, 3]
    new_points = points @ rot.transpose(0, 1) + trans
    return new_poses, new_points
def world_point_to_obj(poses, points, k, r, elevation, device):
    """Re-express poses and points in an object-centred frame.

    First, everything is moved into camera k's frame (``world_point_to_kth``).
    Then it is moved into a target frame whose origin is (0, 0, r) in that
    camera frame. The target frame's rotation is an x-axis rotation built
    from the angle (180 - elevation) degrees. Author's note on the intended
    target frame: origin at the object centre, Y up, Z out of the screen,
    X right.

    Args:
        poses: (N, 4, 4) camera-to-world matrices.
        points: Points with xyz on the last axis.
        k: Index of the reference camera.
        r: z position of the object centre in camera k's frame.
        elevation: Elevation angle in degrees.
        device: Device the target pose is built on.

    Returns:
        (new_poses, new_points) in the object frame.
    """
    poses, points = world_point_to_kth(poses, points, k, device)

    dtype = poses.dtype  # build the target pose in the poses' dtype
    elevation_rad = torch.deg2rad(torch.as_tensor(180 - elevation, dtype=dtype, device=device))
    sin_x, cos_x = torch.sin(elevation_rad), torch.cos(elevation_rad)
    pose_obj = torch.eye(4, dtype=dtype, device=device)
    pose_obj[1, 1] = cos_x
    pose_obj[1, 2] = sin_x
    pose_obj[2, 1] = -sin_x
    pose_obj[2, 2] = cos_x
    pose_obj[2, 3] = r

    # Left-multiply by the inverse of the target pose (its world-to-camera).
    inv_obj_pose = torch.inverse(pose_obj)
    new_poses = torch.bmm(inv_obj_pose.unsqueeze(0).expand_as(poses), poses)
    # The xyz part of the homogeneous product is x' = R x + t.
    new_points = points @ inv_obj_pose[:3, :3].transpose(0, 1) + inv_obj_pose[:3, 3]
    return new_poses, new_points
def txt_interpolation(input_list, n, mode='smooth'):
    """Resample ``input_list`` (treated as evenly spaced on [0, 1]) to ``n`` values.

    Args:
        input_list: Keyframe values.
        n: Number of output samples, evenly spaced on [0, 1].
        mode: 'smooth' fits a cubic ``UnivariateSpline`` with scipy's default
            smoothing, which needs at least 4 keyframes. 'linear' uses
            ``interp1d``.

    Raises:
        KeyError: For any other mode.
    """
    knots = np.linspace(0, 1, len(input_list))
    if mode not in ('smooth', 'linear'):
        raise KeyError(f"Invalid txt interpolation mode: {mode}")
    curve = UnivariateSpline(knots, input_list, k=3) if mode == 'smooth' else interp1d(knots, input_list)
    return curve(np.linspace(0, 1, n))
def visualizer(camera_poses, save_path="out.png"):
    """Scatter-plot camera positions in 3D and save the figure.

    Args:
        camera_poses: Iterable of camera-to-world matrices. The translation
            column ``pose[:3, 3]`` of each is plotted in blue.
        save_path: Image path passed to ``savefig``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    # Single pass over camera_poses, so one-shot iterators are plotted too.
    for pose in camera_poses:
        position = pose[:3, 3]
        ax.scatter(
            position[0],
            position[1],
            position[2],
            c="blue",
            marker="o",
        )
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title("Camera trajectory")
    fig.savefig(save_path)
    plt.close(fig)
def visualizer_frame(camera_poses, highlight_index):
    """Render the camera trajectory to an RGB image with one camera enlarged.

    Each camera position ``pose[:3, 3]`` is colored on a dark-to-light blue
    ramp according to its z coordinate. The camera at ``highlight_index`` is
    drawn with marker size 100 instead of 25.

    Args:
        camera_poses: Sequence of camera-to-world matrices. It is iterated
            twice, so it must not be a one-shot iterator.
        highlight_index: Index of the camera to enlarge.

    Returns:
        ``uint8`` array of shape (height, width, 3) with the rendered figure.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    z_values = [pose[:3, 3][2] for pose in camera_poses]
    cmap = mcolors.LinearSegmentedColormap.from_list("mycmap", ["#00008B", "#ADD8E6"])
    norm = mcolors.Normalize(vmin=min(z_values), vmax=max(z_values))
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

    for i, pose in enumerate(camera_poses):
        position = pose[:3, 3]
        ax.scatter(
            position[0],
            position[1],
            position[2],
            c=sm.to_rgba(position[2]),  # color keyed on z
            marker="o",
            s=100 if i == highlight_index else 25,
        )

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.view_init(90 + 30, -90)
    ax.set_ylim(-0.1, 0.2)

    fig.canvas.draw()
    # canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and removed in
    # 3.10; buffer_rgba() is the supported accessor. Drop alpha and copy, so
    # the array does not alias the canvas buffer after the figure is closed.
    img = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    plt.close(fig)
    return img