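# AniDoc inference script: colorizes a line-art sketch sequence with a Stable
# Video Diffusion pipeline, conditioned on a reference image via keypoint
# matching (SuperPoint + LightGlue) and optional point tracking (CoTracker).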
import sys
sys.path.insert(0, ".")
import argparse
import os

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available

from LightGlue.lightglue import LightGlue, SuperPoint
from LightGlue.lightglue.utils import rbd
from cotracker.predictor import CoTrackerPredictor, sample_trajectories_with_ref
from lineart_extractor.annotator.lineart import LineartDetector
from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from models_diffusers.controlnet_svd import ControlNetSVDModel
from pipelines.AniDoc import AniDocPipeline
from utils import (
    load_images_from_folder,
    export_gif_with_ref,
    extract_frames_from_video,
    safe_round,
    select_multiple_points,
    generate_point_map,
    generate_point_map_frames,
    export_gif_side_by_side_complete_ablation,
)
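# CLI flags select the conditioning mode: --all_sketch extracts a sketch per
# frame, --repeat_sketch repeats only the first/last sketches, and
# --matching/--tracking/--repeat_matching control the keypoint condition.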
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path", type=str,
        default="pretrained_weights/stable-video-diffusion-img2vid-xt",
        help="Path to the pretrained Stable Video Diffusion base model.")
    parser.add_argument(
        "--pretrained_unet", type=str,
        default="pretrained_weights/anidoc",
        help="Path to the fine-tuned AniDoc UNet weights.")
    parser.add_argument(
        "--controlnet_model_name_or_path", type=str,
        default="pretrained_weights/anidoc/controlnet",
        help="Path to the ControlNet weights.")
    parser.add_argument("--output_dir", type=str, default=None, help="Directory for the output videos and GIFs.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--noise_aug", type=float, default=0.02, help="Noise augmentation strength.")
    parser.add_argument("--num_frames", type=int, default=14)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--height", type=int, default=320)
    parser.add_argument("--all_sketch", action="store_true", help="Extract a sketch from every input frame.")
    parser.add_argument("--not_quant_sketch", action="store_true", help="Skip binarizing the extracted sketch.")
    parser.add_argument("--repeat_sketch", action="store_true", help="Repeat the first and last sketches instead of using every frame.")
    parser.add_argument("--matching", action="store_true", help="Add keypoint matching.")
    parser.add_argument("--tracking", action="store_true", help="Track matched keypoints across frames.")
    parser.add_argument("--repeat_matching", action="store_true", help="Do not track; simply repeat the first-frame matching.")
    parser.add_argument(
        "--tracker_point_init", type=str, default="gaussion",
        choices=["dift", "gaussion", "both"],
        help="Initialization strategy for the tracker query points.")
    parser.add_argument(
        "--tracker_shift_grid", type=int, default=0, choices=[0, 1],
        help="Shift the regular grid used by the tracker.")
    parser.add_argument("--tracker_grid_size", type=int, default=8, help="Regular grid size.")
    parser.add_argument(
        "--tracker_grid_query_frame", type=int, default=0,
        help="Compute dense and grid tracks starting from this frame.")
    parser.add_argument(
        "--tracker_backward_tracking", action="store_true",
        help="Compute tracks in both directions, not only forward.")
    parser.add_argument("--control_image", type=str, default=None, help="Path to the input sketch video or frame folder.")
    parser.add_argument("--ref_image", type=str, default=None, help="Path to the reference character image.")
    parser.add_argument("--max_points", type=int, default=10, help="Maximum number of matched points to keep.")
    args = parser.parse_args()
    return args
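# Main flow: load the UNet and ControlNet, assemble the pipeline, build the
# sketch and matching conditions per sample, then run the diffusion pipeline.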
if __name__ == "__main__":
    args = get_args()
    dtype = torch.float16

    # Load the fine-tuned spatio-temporal UNet from the AniDoc checkpoint.
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.pretrained_unet,
        subfolder="unet",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        custom_resume=True,
    )
    unet.to("cuda", dtype)

    # Load the ControlNet, or initialize a fresh one from the UNet if no checkpoint is given.
    if args.controlnet_model_name_or_path:
        controlnet = ControlNetSVDModel.from_pretrained(
            args.controlnet_model_name_or_path,
        )
    else:
        controlnet = ControlNetSVDModel.from_unet(
            unet,
            conditioning_channels=8,
        )
    controlnet.to("cuda", dtype)
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly.")
    pipe = AniDocPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unet,
        controlnet=controlnet,
        low_cpu_mem_usage=False,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.to("cuda")

    device = "cuda"
    detector = LineartDetector(device)
    extractor = SuperPoint(max_num_keypoints=2000).eval().to(device)  # keypoint extractor
    matcher = LightGlue(features="superpoint").eval().to(device)  # keypoint matcher
    tracker = CoTrackerPredictor(
        checkpoint="pretrained_weights/cotracker2.pth",
        shift_grid=args.tracker_shift_grid,
    )
    tracker.requires_grad_(False)
    tracker.to(device, dtype=torch.float32)
    width, height = args.width, args.height
    if args.output_dir is None:
        args.output_dir = "results"
    os.makedirs(args.output_dir, exist_ok=True)

    # Bundled sample: a sketch video plus a reference character image.
    image_folder_list = [
        "data_test/sample1.mp4",
    ]
    ref_image_list = [
        "data_test/sample1.png",
    ]
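    # Command-line --ref_image/--control_image override the bundled sample.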
    if args.ref_image is not None and args.control_image is not None:
        ref_image_list = [args.ref_image]
        image_folder_list = [args.control_image]

    for val_id, each_sample in enumerate(image_folder_list):
        # Control frames can come from a folder of images or an .mp4 video.
        if os.path.isdir(each_sample):
            control_images = load_images_from_folder(each_sample)
        elif each_sample.endswith(".mp4"):
            control_images = extract_frames_from_video(each_sample)
        ref_image = load_image(ref_image_list[val_id]).resize((width, height))
        # Resize every control frame to the target resolution.
        for j in range(len(control_images)):
            control_images[j] = control_images[j].resize((width, height))
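        # Branch 1: condition on a sketch extracted from every control frame.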
        if args.all_sketch:
            controlnet_image = []
            for k in range(len(control_images)):
                sketch = np.array(control_images[k])
                sketch = detector(sketch, coarse=False)
                sketch = np.repeat(sketch[:, :, np.newaxis], 3, axis=2)
                if not args.not_quant_sketch:
                    # Binarize the sketch.
                    sketch = (sketch > 200).astype(np.uint8) * 255
                sketch = Image.fromarray(sketch).resize((width, height))
                controlnet_image.append(sketch)
            controlnet_sketch_condition = [T.ToTensor()(img).unsqueeze(0) for img in controlnet_image]
            controlnet_sketch_condition = torch.cat(controlnet_sketch_condition, dim=0).unsqueeze(0).to(device, dtype=torch.float16)
            controlnet_sketch_condition = (controlnet_sketch_condition - 0.5) / 0.5  # normalize to [-1, 1], shape (1, 14, 3, h, w)
            # Matching condition: match keypoints between the reference image and the first sketch.
            with torch.no_grad():
                ref_img_value = T.ToTensor()(ref_image).to(device, dtype=torch.float16)  # in [0, 1]
                ref_img_value = ref_img_value.to(torch.float32)
                current_img = T.ToTensor()(controlnet_image[0]).to(device, dtype=torch.float16)  # in [0, 1]
                current_img = current_img.to(torch.float32)
                feats0 = extractor.extract(ref_img_value)
                feats1 = extractor.extract(current_img)
                matches01 = matcher({"image0": feats0, "image1": feats1})
                feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]]
                matches = matches01["matches"]
                points0 = feats0["keypoints"][matches[..., 0]]
                points1 = feats1["keypoints"][matches[..., 1]]
                points0 = points0.cpu().numpy()
                points1 = points1.cpu().numpy()
                points0 = safe_round(points0, current_img.shape)
                points1 = safe_round(points1, current_img.shape)
                num_points = min(50, points0.shape[0])
                points0, points1 = select_multiple_points(points0, points1, num_points)
                # Rasterize the matched points into per-image point maps.
                mask1, mask2 = generate_point_map(size=current_img.shape, coords0=points0, coords1=points1)
                point_map1 = torch.from_numpy(mask1)
                point_map2 = torch.from_numpy(mask2)
                point_map1 = point_map1.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(device, dtype=torch.float16)
                point_map2 = point_map2.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(device, dtype=torch.float16)
                point_map = torch.cat([point_map1, point_map2], dim=2)
                conditional_pixel_values = ref_img_value.unsqueeze(0).unsqueeze(0)
                conditional_pixel_values = (conditional_pixel_values - 0.5) / 0.5
                point_map_with_ref = torch.cat([point_map, conditional_pixel_values], dim=2)
            original_shape = list(point_map_with_ref.shape)
            new_shape = original_shape.copy()
            new_shape[1] = args.num_frames - 1
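            # Build the final ControlNet condition: sketch frames concatenated with
            # point maps (repeated, tracked, or zero-padded) along the channel axis.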
            if args.repeat_matching:
                matching_controlnet_image = point_map_with_ref.repeat(1, args.num_frames, 1, 1, 1)
                controlnet_condition = torch.cat([controlnet_sketch_condition, matching_controlnet_image], dim=2)
            elif args.tracking:
                with torch.no_grad():
                    video_for_tracker = (controlnet_sketch_condition * 0.5 + 0.5) * 255.0
                    # Prepend the query frame index (0) to each matched point.
                    queries = np.insert(points1, 0, 0, axis=1)
                    queries = torch.from_numpy(queries).to(device, torch.float).unsqueeze(0)
                    if queries.shape[1] == 0:
                        pred_tracks_sampled = None
                        points0_sampled = None
                    else:
                        pred_tracks, pred_visibility = tracker(
                            video_for_tracker.to(dtype=torch.float32),
                            queries=queries,
                            grid_size=args.tracker_grid_size,
                            grid_query_frame=args.tracker_grid_query_frame,
                            backward_tracking=args.tracker_backward_tracking,
                        )
                        pred_tracks_sampled, pred_visibility_sampled, points0_sampled = sample_trajectories_with_ref(
                            pred_tracks.cpu(), pred_visibility.cpu(), torch.from_numpy(points0).unsqueeze(0).cpu(),
                            max_points=args.max_points,
                            motion_threshold=1,
                            vis_threshold=3,
                        )
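                # Fall back to zero point maps when no reliable tracks survive sampling.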
                if pred_tracks_sampled is None:
                    mask1 = np.zeros((args.height, args.width), dtype=np.uint8)
                    mask2 = np.zeros((args.num_frames, args.height, args.width), dtype=np.uint8)
                else:
                    pred_tracks_sampled = pred_tracks_sampled.squeeze(0).cpu().numpy()
                    pred_visibility_sampled = pred_visibility_sampled.squeeze(0).cpu().numpy()
                    points0_sampled = points0_sampled.squeeze(0).cpu().numpy()
                    for frame_id in range(args.num_frames):
                        pred_tracks_sampled[frame_id] = safe_round(pred_tracks_sampled[frame_id], current_img.shape)
                    points0_sampled = safe_round(points0_sampled, current_img.shape)
                    mask1, mask2 = generate_point_map_frames(size=current_img.shape, coords0=points0_sampled, coords1=pred_tracks_sampled, visibility=pred_visibility_sampled)
                point_map1 = torch.from_numpy(mask1)
                point_map2 = torch.from_numpy(mask2)
                point_map1 = point_map1.unsqueeze(0).unsqueeze(0).repeat(1, args.num_frames, 1, 1, 1).to(device, dtype=torch.float16)
                point_map2 = point_map2.unsqueeze(0).unsqueeze(2).to(device, dtype=torch.float16)
                point_map = torch.cat([point_map1, point_map2], dim=2)
                conditional_pixel_values_repeat = conditional_pixel_values.repeat(1, args.num_frames, 1, 1, 1)
                point_map_with_ref = torch.cat([point_map, conditional_pixel_values_repeat], dim=2)
                controlnet_condition = torch.cat([controlnet_sketch_condition, point_map_with_ref], dim=2)
            else:
                # Static matching: pad the first-frame maps with zeros for the remaining frames.
                zero_tensor = torch.zeros(new_shape).to(device, dtype=torch.float16)
                matching_controlnet_image = torch.cat((point_map_with_ref, zero_tensor), dim=1)
                controlnet_condition = torch.cat([controlnet_sketch_condition, matching_controlnet_image], dim=2)
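        # Branch 2: repeat the first and last sketches for half of the clip each.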
        elif args.repeat_sketch:
            controlnet_image = []
            for _ in range(len(control_images) // 2):
                sketch = np.array(control_images[0])
                sketch = detector(sketch, coarse=False)
                sketch = np.repeat(sketch[:, :, np.newaxis], 3, axis=2)
                if not args.not_quant_sketch:
                    sketch = (sketch > 200).astype(np.uint8) * 255
                sketch = Image.fromarray(sketch)
                controlnet_image.append(sketch)
            for _ in range(len(control_images) // 2):
                sketch = np.array(control_images[-1])
                sketch = detector(sketch, coarse=False)
                sketch = np.repeat(sketch[:, :, np.newaxis], 3, axis=2)
                if not args.not_quant_sketch:
                    sketch = (sketch > 200).astype(np.uint8) * 255
                sketch = Image.fromarray(sketch)
                controlnet_image.append(sketch)
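        # Run the diffusion pipeline and export the result as a side-by-side GIF.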
        # Output location: <output_dir>/<ref_name>_<sketch_name>
        ref_base_name = os.path.splitext(os.path.basename(ref_image_list[val_id]))[0]
        sketch_base_name = os.path.splitext(os.path.basename(each_sample))[0]
        supp_dir = os.path.join(args.output_dir, ref_base_name + "_" + sketch_base_name)
        os.makedirs(supp_dir, exist_ok=True)

        generator = torch.manual_seed(args.seed)
        with torch.inference_mode():
            video_frames = pipe(
                ref_image,
                controlnet_condition,
                height=args.height,
                width=args.width,
                num_frames=args.num_frames,
                decode_chunk_size=8,
                motion_bucket_id=127,
                fps=7,
                noise_aug_strength=args.noise_aug,
                generator=generator,
            ).frames[0]

        out_file = supp_dir + ".mp4"
        if args.all_sketch:
            export_gif_side_by_side_complete_ablation(ref_image, controlnet_image, video_frames, out_file.replace(".mp4", ".gif"), supp_dir, 6)
        elif args.repeat_sketch:
            export_gif_with_ref(control_images[0], video_frames, controlnet_image[-1], controlnet_image[0], out_file.replace(".mp4", ".gif"), 6)