Spaces:
Configuration error
Configuration error
| import sys | |
| sys.path.append('./extern/dust3r') | |
| from dust3r.inference import inference, load_model | |
| from dust3r.utils.image import load_images | |
| from dust3r.image_pairs import make_pairs | |
| from dust3r.cloud_opt import global_aligner, GlobalAlignerMode | |
| from dust3r.utils.device import to_numpy | |
| import trimesh | |
| import torch | |
| import numpy as np | |
| import torchvision | |
| import os | |
| import copy | |
| import cv2 | |
| from PIL import Image | |
| import pytorch3d | |
| print(pytorch3d.__version__) | |
| from pytorch3d.structures import Pointclouds | |
| from torchvision.utils import save_image | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from utils.pvd_utils import * | |
| from omegaconf import OmegaConf | |
| from pytorch_lightning import seed_everything | |
| from utils.diffusion_utils import instantiate_from_config,load_model_checkpoint,image_guided_synthesis | |
| from pathlib import Path | |
| from torchvision.utils import save_image | |
class ViewCrafter:
    """Point-cloud-guided novel view synthesis pipeline.

    The class reconstructs a point cloud from the input image(s) with DUSt3R
    (``run_dust3r``), renders that point cloud along a camera trajectory
    (``run_render``), and refines the renderings with a video diffusion model
    (``run_diffusion``).  Intermediate videos and point clouds are written to
    ``opts.save_dir``.
    """

    # Modes whose first (single-image) pass follows the user-specified offsets
    # d_theta[0], d_phi[0], d_r[0].  The iterative modes are included because
    # their drivers call ``nvs_single_view`` for iteration 0; without them that
    # call raised ``KeyError`` and the iterative drivers could never run.
    # NOTE(review): treating iteration 0 of the iterative modes like
    # 'single_view_target' is inferred from how later iterations index
    # d_theta[iter] -- confirm against the intended behaviour.
    _TARGET_LIKE_MODES = (
        'single_view_target',
        'single_view_ref_iterative',
        'single_view_1drc_iterative',
    )

    # Resolution (H, W) the renderings are resized to before diffusion.
    _DIFFUSION_SIZE = (576, 1024)

    def __init__(self, opts, gradio=False):
        """Load DUSt3R and, unless running under gradio, reconstruct the scene.

        Args:
            opts: Option namespace; this constructor reads ``device`` and
                ``image_dir`` (other methods read many more fields).
            gradio: When True, skip loading ``opts.image_dir``; the image is
                supplied later through ``run_gradio``.
        """
        self.opts = opts
        self.device = opts.device
        self.setup_dust3r()
        # The diffusion model is not loaded here; ``run_diffusion`` loads it
        # on first use if ``setup_diffusion`` has not been called explicitly.
        if not gradio:
            self.images, self.img_ori = self.load_initial_images(image_dir=self.opts.image_dir)
            self.run_dust3r(input_images=self.images)

    def run_dust3r(self, input_images, clean_pc=False):
        """Run DUSt3R on ``input_images`` and store the aligned scene in ``self.scene``.

        Args:
            input_images: List of image dicts as produced by ``load_images`` /
                ``get_input_dict``.
            clean_pc: When True, store ``scene.clean_pointcloud()`` instead of
                the raw aligned scene.
        """
        pairs = make_pairs(input_images, scene_graph='complete', prefilter=None, symmetrize=True)
        output = inference(pairs, self.dust3r, self.device, batch_size=self.opts.batch_size)
        # The point-cloud optimizer is always used, even for two views.
        mode = GlobalAlignerMode.PointCloudOptimizer
        scene = global_aligner(output, device=self.device, mode=mode)
        if mode == GlobalAlignerMode.PointCloudOptimizer:
            scene.compute_global_alignment(
                init='mst',
                niter=self.opts.niter,
                schedule=self.opts.schedule,
                lr=self.opts.lr,
            )
        self.scene = scene.clean_pointcloud() if clean_pc else scene

    def render_pcd(self, pts3d, imgs, masks, views, renderer, device, nbv=False):
        """Render a coloured point cloud from ``views`` cameras.

        Args:
            pts3d: Per-image 3D points (converted with ``to_numpy``).
            imgs: Per-image colours matching ``pts3d``.
            masks: Optional per-image boolean masks selecting the points to
                keep; ``None`` keeps every point.
            views: Number of views; the point cloud is replicated this many
                times via ``Pointclouds.extend``.
            renderer: Callable renderer applied to the point cloud batch.
            device: Device the point/colour tensors are moved to.
            nbv: When True, additionally render an all-ones feature cloud so
                the result can serve as a per-view coverage mask (needed by
                the next-best-view selection).

        Returns:
            ``(images, view_masks)`` where ``view_masks`` is ``None`` unless
            ``nbv`` is True.
        """
        imgs = to_numpy(imgs)
        pts3d = to_numpy(pts3d)

        # ``is None`` rather than ``== None``: the latter is ambiguous (or
        # raises) when ``masks`` is array-like.
        if masks is None:
            pts = torch.from_numpy(np.concatenate(list(pts3d))).view(-1, 3).to(device)
            col = torch.from_numpy(np.concatenate(list(imgs))).view(-1, 3).to(device)
        else:
            pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device)
            col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device)

        point_cloud = Pointclouds(points=[pts], features=[col]).extend(views)
        images = renderer(point_cloud)

        view_masks = None
        if nbv:
            color_mask = torch.ones(col.shape).to(device)
            point_cloud_mask = Pointclouds(points=[pts], features=[color_mask]).extend(views)
            view_masks = renderer(point_cloud_mask)
        return images, view_masks

    def run_render(self, pcd, imgs, masks, H, W, camera_traj, num_views, use_cpu=False, nbv=False):
        """Build a renderer for ``camera_traj`` and render the point cloud.

        Args:
            pcd, imgs, masks: Forwarded to ``render_pcd``.
            H, W: Render resolution passed to ``setup_renderer``.
            camera_traj: Camera trajectory passed to ``setup_renderer``.
            num_views: Number of views in the trajectory.
            use_cpu: Place the point tensors on the CPU instead of ``self.device``.
            nbv: Also render the coverage mask (see ``render_pcd``).

        Returns:
            ``(render_results, viewmask)`` as returned by ``render_pcd``.
        """
        device = torch.device("cpu") if use_cpu else self.device
        render_setup = setup_renderer(camera_traj, image_size=(H, W))
        renderer = render_setup['renderer']
        return self.render_pcd(pcd, imgs, masks, num_views, renderer, device, nbv=nbv)

    def run_diffusion(self, renderings):
        """Refine point-cloud renderings with the video diffusion model.

        Args:
            renderings: Tensor indexed (t, h, w, c); values are rescaled with
                ``x * 2 - 1`` before being fed to the model.

        Returns:
            The first sample, permuted back to (t, h, w, c) and clamped to [-1, 1].
        """
        # Load the model on demand; otherwise ``self.diffusion`` would be
        # missing whenever ``setup_diffusion`` was not called beforehand.
        if getattr(self, 'diffusion', None) is None:
            self.setup_diffusion()

        prompts = [self.opts.prompt]
        videos = (renderings * 2. - 1.).permute(3, 0, 1, 2).unsqueeze(0).to(self.device)
        condition_index = [0]
        with torch.no_grad(), torch.cuda.amp.autocast():
            batch_samples = image_guided_synthesis(
                self.diffusion, prompts, videos, self.noise_shape,
                self.opts.n_samples, self.opts.ddim_steps, self.opts.ddim_eta,
                self.opts.unconditional_guidance_scale, self.opts.cfg_img,
                self.opts.frame_stride, self.opts.text_input,
                self.opts.multiple_cond_cfg, self.opts.timestep_spacing,
                self.opts.guidance_rescale, condition_index,
            )
        return torch.clamp(batch_samples[0][0].permute(1, 2, 3, 0), -1., 1.)

    def _resize_for_diffusion(self, render_results):
        """Bilinearly resize (t, h, w, c) renderings to ``_DIFFUSION_SIZE``."""
        resized = F.interpolate(
            render_results.permute(0, 3, 1, 2),
            size=self._DIFFUSION_SIZE,
            mode='bilinear',
            align_corners=False,
        )
        return resized.permute(0, 2, 3, 1)

    def _plan_nbv_trajectory(self, c2ws, focals, principal_points, pcd, imgs, masks,
                             H, W, num_candidates, tag):
        """Pick the next-best view among candidates and build a trajectory to it.

        Candidate poses are generated on a grid spaced by ``opts.d_theta[0]``
        and ``opts.d_phi[0]`` (the sign of ``d_phi[0]`` selects left/right).
        Each candidate's coverage mask is rendered and the candidate with the
        smallest mask sum is selected.  The candidate masks are saved to
        ``candidate_mask{tag}_nbv{id}.png`` and ``opts.elevation`` is reduced
        by the selected theta.

        Returns:
            ``(camera_traj, num_views)`` from ``generate_traj_specified``.
        """
        candidate_poses, thetas, phis = generate_candidate_poses(
            c2ws, H, W, focals, principal_points,
            self.opts.d_theta[0], self.opts.d_phi[0], num_candidates, self.device,
        )
        _, viewmask = self.run_render(pcd, imgs, masks, H, W, candidate_poses, num_candidates, nbv=True)
        nbv_id = torch.argmin(viewmask.sum(dim=[1, 2, 3])).item()
        save_image(
            viewmask.permute(0, 3, 1, 2),
            os.path.join(self.opts.save_dir, f"candidate_mask{tag}_nbv{nbv_id}.png"),
            normalize=True,
            value_range=(0, 1),
        )
        theta_nbv = thetas[nbv_id]
        phi_nbv = phis[nbv_id]
        # Camera trajectory from the current pose to the selected pose.
        camera_traj, num_views = generate_traj_specified(
            c2ws, H, W, focals, principal_points, theta_nbv, phi_nbv,
            self.opts.d_r[0], self.opts.video_length, self.device,
        )
        # Reset the elevation for the next iteration.
        self.opts.elevation -= theta_nbv
        return camera_traj, num_views

    def nvs_single_view(self, gradio=False):
        """Synthesize a video from the single reference image.

        Renders the last image's point cloud along a trajectory chosen by
        ``opts.mode`` and refines it with diffusion.  Writes ``render0.mp4``,
        ``pcd0.ply`` and ``diffusion0.mp4`` to ``opts.save_dir``.

        Args:
            gradio: In 'single_view_txt' mode, read the trajectory from
                ``self.gradio_traj`` instead of ``opts.traj_txt``.

        Returns:
            Diffusion output indexed (t, h, w, c) in [-1, 1].

        Raises:
            KeyError: If ``opts.mode`` is not supported.
        """
        # The last view is the zero pose.
        c2ws = self.scene.get_im_poses().detach()[1:]
        principal_points = self.scene.get_principal_points().detach()[1:]  # cx, cy
        focals = self.scene.get_focals().detach()[1:]
        shape = self.images[0]['true_shape']
        H, W = int(shape[0][0]), int(shape[0][1])
        pcd = [p.detach() for p in self.scene.get_pts3d(clip_thred=self.opts.dpt_trd)]
        depth = [d.detach() for d in self.scene.get_depthmaps()]
        # Rotate around a sphere centred at the depth of the image centre,
        # scaled by ``center_scale``.
        depth_avg = depth[-1][H // 2, W // 2]
        radius = depth_avg * self.opts.center_scale

        # Change coordinates.
        c2ws, pcd = world_point_to_obj(
            poses=c2ws, points=torch.stack(pcd), k=-1, r=radius,
            elevation=self.opts.elevation, device=self.device,
        )
        imgs = np.array(self.scene.imgs)
        masks = None
        mode = self.opts.mode

        if mode == 'single_view_nbv':
            # FIXME: hard-coded candidate count.  Taking "left" as an example,
            # the first iteration chooses from [left, upper-left]; later
            # iterations may choose from [left, upper-left, lower-left].
            camera_traj, num_views = self._plan_nbv_trajectory(
                c2ws, focals, principal_points, [pcd[-1]], [imgs[-1]], masks,
                H, W, num_candidates=2, tag=0,
            )
        elif mode in self._TARGET_LIKE_MODES:
            camera_traj, num_views = generate_traj_specified(
                c2ws, H, W, focals, principal_points,
                self.opts.d_theta[0], self.opts.d_phi[0], self.opts.d_r[0],
                self.opts.video_length, self.device,
            )
        elif mode == 'single_view_txt':
            if not gradio:
                # The file holds three whitespace-separated rows: phi, theta, r.
                with open(self.opts.traj_txt, 'r') as file:
                    lines = file.readlines()
                phi = [float(v) for v in lines[0].split()]
                theta = [float(v) for v in lines[1].split()]
                r = [float(v) for v in lines[2].split()]
            else:
                phi, theta, r = self.gradio_traj
            camera_traj, num_views = generate_traj_txt(
                c2ws, H, W, focals, principal_points, phi, theta, r,
                self.opts.video_length, self.device,
                viz_traj=True, save_dir=self.opts.save_dir,
            )
        else:
            raise KeyError(f"Invalid Mode: {self.opts.mode}")

        render_results, _ = self.run_render(
            [pcd[-1]], [imgs[-1]], masks, H, W, camera_traj, num_views, use_cpu=False,
        )
        render_results = render_results.to(self.device)
        render_results = self._resize_for_diffusion(render_results)
        # The first frame is the untouched reference image.
        render_results[0] = self.img_ori
        if mode == 'single_view_txt':
            # A trajectory that ends at zero offsets returns to the reference view.
            if phi[-1] == 0. and theta[-1] == 0. and r[-1] == 0.:
                render_results[-1] = self.img_ori

        save_video(render_results, os.path.join(self.opts.save_dir, 'render0.mp4'))
        save_pointcloud_with_normals(
            [imgs[-1]], [pcd[-1]], msk=None,
            save_path=os.path.join(self.opts.save_dir, 'pcd0.ply'),
            mask_pc=False, reduce_pc=False,
        )
        diffusion_results = self.run_diffusion(render_results)
        save_video((diffusion_results + 1.0) / 2.0, os.path.join(self.opts.save_dir, 'diffusion0.mp4'))
        return diffusion_results

    def nvs_sparse_view(self, iter):
        """Synthesize a video from the accumulated multi-view point cloud.

        Writes ``render{iter}.mp4``, ``pcd{iter}.ply`` and
        ``diffusion{iter}.mp4`` to ``opts.save_dir``.

        Args:
            iter: Iteration index; selects ``d_theta[iter]`` / ``d_phi[iter]``
                / ``d_r[iter]`` and names the output files.

        Returns:
            Diffusion output indexed (t, h, w, c) in [-1, 1].

        Raises:
            KeyError: If ``opts.mode`` is not supported.
        """
        c2ws = self.scene.get_im_poses().detach()
        principal_points = self.scene.get_principal_points().detach()
        focals = self.scene.get_focals().detach()
        shape = self.images[0]['true_shape']
        H, W = int(shape[0][0]), int(shape[0][1])
        pcd = [p.detach() for p in self.scene.get_pts3d(clip_thred=self.opts.dpt_trd)]
        depth = [d.detach() for d in self.scene.get_depthmaps()]
        # Rotate around a sphere centred at the depth of the reference image
        # centre, scaled by ``center_scale``.
        depth_avg = depth[0][H // 2, W // 2]
        radius = depth_avg * self.opts.center_scale

        # Masks for a cleaner point cloud: confidence mask combined with a
        # far-depth mask computed from rows 40..-40 of each depth map.
        self.scene.min_conf_thr = float(self.scene.conf_trf(torch.tensor(self.opts.min_conf_thr)))
        conf_masks = self.scene.get_masks()
        bgs_mask = [
            dpt > self.opts.bg_trd * (torch.max(dpt[40:-40, :]) + torch.min(dpt[40:-40, :]))
            for dpt in depth
        ]
        masks = to_numpy([m + mb for m, mb in zip(conf_masks, bgs_mask)])

        imgs = np.array(self.scene.imgs)
        mode = self.opts.mode

        if mode == 'single_view_ref_iterative':
            # Start from c2ws[0], the camera of the reference image.
            c2ws, pcd = world_point_to_obj(
                poses=c2ws, points=torch.stack(pcd), k=0, r=radius,
                elevation=self.opts.elevation, device=self.device,
            )
            camera_traj, num_views = generate_traj_specified(
                c2ws[0:1], H, W, focals[0:1], principal_points[0:1],
                self.opts.d_theta[iter], self.opts.d_phi[iter], self.opts.d_r[iter],
                self.opts.video_length, self.device,
            )
            first_frame = self.img_ori
        elif mode == 'single_view_1drc_iterative':
            self.opts.elevation -= self.opts.d_theta[iter - 1]
            c2ws, pcd = world_point_to_obj(
                poses=c2ws, points=torch.stack(pcd), k=-1, r=radius,
                elevation=self.opts.elevation, device=self.device,
            )
            camera_traj, num_views = generate_traj_specified(
                c2ws[-1:], H, W, focals[-1:], principal_points[-1:],
                self.opts.d_theta[iter], self.opts.d_phi[iter], self.opts.d_r[iter],
                self.opts.video_length, self.device,
            )
            first_frame = (self.images[-1]['img_ori'].squeeze(0).permute(1, 2, 0) + 1.) / 2.
        elif mode == 'single_view_nbv':
            c2ws, pcd = world_point_to_obj(
                poses=c2ws, points=torch.stack(pcd), k=-1, r=radius,
                elevation=self.opts.elevation, device=self.device,
            )
            # FIXME: hard-coded candidate count (see ``nvs_single_view``).
            camera_traj, num_views = self._plan_nbv_trajectory(
                c2ws[-1:], focals[-1:], principal_points[-1:], pcd, imgs, masks,
                H, W, num_candidates=3, tag=iter,
            )
            first_frame = (self.images[-1]['img_ori'].squeeze(0).permute(1, 2, 0) + 1.) / 2.
        else:
            raise KeyError(f"Invalid Mode: {self.opts.mode}")

        render_results, _ = self.run_render(pcd, imgs, masks, H, W, camera_traj, num_views)
        render_results = self._resize_for_diffusion(render_results)
        render_results[0] = first_frame

        save_video(render_results, os.path.join(self.opts.save_dir, f'render{iter}.mp4'))
        save_pointcloud_with_normals(
            imgs, pcd, msk=masks,
            save_path=os.path.join(self.opts.save_dir, f'pcd{iter}.ply'),
            mask_pc=True, reduce_pc=False,
        )
        diffusion_results = self.run_diffusion(render_results)
        save_video((diffusion_results + 1.0) / 2.0, os.path.join(self.opts.save_dir, f'diffusion{iter}.mp4'))
        return diffusion_results

    def _run_iterative(self, num_iters, sample_rate=6):
        """Shared driver for the iterative single-view modes.

        Iteration 0 runs ``nvs_single_view`` on the reference image alone.
        Each later iteration appends every ``sample_rate``-th generated frame
        (skipping frame 0) to ``self.images``, re-runs DUSt3R with point-cloud
        cleaning, and calls ``nvs_sparse_view``.

        Returns:
            List with one tensor per iteration, each permuted to (t, c, h, w).
        """
        all_results = []
        idx = 1  # self.images initially holds one reference image
        diffusion_results_itr = None
        for itr in range(num_iters):
            if itr == 0:
                # Drop the duplicated copy of the reference image.
                self.images = [self.images[0]]
                diffusion_results_itr = self.nvs_single_view()
            else:
                for i in range(sample_rate, diffusion_results_itr.shape[0], sample_rate):
                    self.images.append(
                        get_input_dict(diffusion_results_itr[i:i + 1, ...], idx, dtype=torch.float32)
                    )
                    idx += 1
                self.run_dust3r(input_images=self.images, clean_pc=True)
                diffusion_results_itr = self.nvs_sparse_view(itr)
            diffusion_results_itr = diffusion_results_itr.permute(0, 3, 1, 2)
            all_results.append(diffusion_results_itr)
        return all_results

    def nvs_single_view_ref_iterative(self):
        """Run one iteration per entry of ``opts.d_phi`` (reference-anchored mode)."""
        return self._run_iterative(len(self.opts.d_phi))

    def nvs_single_view_1drc_iterative(self):
        """Run one iteration per entry of ``opts.d_phi`` (last-view-anchored mode)."""
        return self._run_iterative(len(self.opts.d_phi))

    def nvs_single_view_nbv(self):
        """Run the next-best-view mode for a fixed number of iterations.

        In this mode ``d_theta`` and ``d_phi`` are the grid spacing of the
        candidate search space.
        """
        # FIXME: hard-coded iteration count.
        max_itr = 3
        return self._run_iterative(max_itr)

    def setup_diffusion(self):
        """Instantiate the diffusion model and store it in ``self.diffusion``.

        Also sets ``self.noise_shape`` to
        ``[bs, out_channels, video_length, height // 8, width // 8]``.

        Raises:
            FileNotFoundError: If ``opts.ckpt_path`` does not exist.
        """
        # Fail fast (and independently of ``python -O``) before building the model.
        if not os.path.exists(self.opts.ckpt_path):
            raise FileNotFoundError("Error: checkpoint Not Found!")

        seed_everything(self.opts.seed)
        config = OmegaConf.load(self.opts.config)
        model_config = config.pop("model", OmegaConf.create())
        # use_checkpoint is disabled: with deepspeed it triggers the error
        # "deepspeed backend not set".
        model_config['params']['unet_config']['params']['use_checkpoint'] = False

        model = instantiate_from_config(model_config)
        model = model.to(self.device)
        model.cond_stage_model.device = self.device
        model.perframe_ae = self.opts.perframe_ae
        model = load_model_checkpoint(model, self.opts.ckpt_path)
        model.eval()
        self.diffusion = model

        h, w = self.opts.height // 8, self.opts.width // 8
        channels = model.model.diffusion_model.out_channels
        n_frames = self.opts.video_length
        self.noise_shape = [self.opts.bs, channels, n_frames, h, w]

    def setup_dust3r(self):
        """Load the DUSt3R model from ``opts.model_path`` into ``self.dust3r``."""
        self.dust3r = load_model(self.opts.model_path, self.device)

    def load_initial_images(self, image_dir):
        """Load the input image(s) and the full-resolution reference frame.

        Args:
            image_dir: Path handed to ``load_images`` (wrapped in a list).

        Returns:
            ``(images, img_ori)``.  ``images`` is a list of image dicts; a
            single input is duplicated (second copy gets ``idx`` 1) so that an
            image pair exists.  ``img_ori`` is the first image's ``img_ori``
            permuted to (h, w, c) and mapped from [-1, 1] to [0, 1].
        """
        # Each dict has keys 'img', 'true_shape', 'idx', 'instance', 'img_ori'.
        images = load_images([image_dir], size=512, force_1024=True)
        img_ori = (images[0]['img_ori'].squeeze(0).permute(1, 2, 0) + 1.) / 2.
        if len(images) == 1:
            images = [images[0], copy.deepcopy(images[0])]
            images[1]['idx'] = 1
        return images, img_ori

    def run_gradio(self, i2v_input_image, i2v_elevation, i2v_center_scale, i2v_d_phi,
                   i2v_d_theta, i2v_d_r, i2v_steps, i2v_seed):
        """Gradio entry point: run single-view synthesis on an uploaded image.

        Args:
            i2v_input_image: Image as an (h, w, c) numpy array with 0-255 values.
            i2v_elevation, i2v_center_scale: Converted to float and stored on ``opts``.
            i2v_d_phi, i2v_d_theta, i2v_d_r: Whitespace-separated number strings
                describing the trajectory.
            i2v_steps: Stored as ``opts.ddim_steps``.
            i2v_seed: Seed passed to ``seed_everything``.

        Returns:
            Paths of the trajectory visualisation and the generated video
            inside ``opts.save_dir``.
        """
        self.opts.elevation = float(i2v_elevation)
        self.opts.center_scale = float(i2v_center_scale)
        self.opts.ddim_steps = i2v_steps
        self.gradio_traj = (
            [float(v) for v in i2v_d_phi.split()],
            [float(v) for v in i2v_d_theta.split()],
            [float(v) for v in i2v_d_r.split()],
        )
        seed_everything(i2v_seed)

        transform = transforms.Compose([
            transforms.Resize(576),
            transforms.CenterCrop((576, 1024)),
        ])
        torch.cuda.empty_cache()

        img_tensor = torch.from_numpy(i2v_input_image).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
        img_tensor = (img_tensor / 255. - 0.5) * 2  # map 0..255 to [-1, 1]
        image_tensor_resized = transform(img_tensor)  # (1, 3, h, w)

        first = get_input_dict(image_tensor_resized, idx=0, dtype=torch.float32)
        # Duplicate the single image so that an image pair exists.
        images = [first, copy.deepcopy(first)]
        images[1]['idx'] = 1
        self.images = images
        # Reference frame as (h, w, c) in [0, 1].
        self.img_ori = (image_tensor_resized.squeeze(0).permute(1, 2, 0) + 1.) / 2.

        self.run_dust3r(input_images=self.images)
        self.nvs_single_view(gradio=True)

        traj_dir = os.path.join(self.opts.save_dir, "viz_traj.mp4")
        gen_dir = os.path.join(self.opts.save_dir, "diffusion0.mp4")
        return traj_dir, gen_dir