import gradio as gr
import numpy as np
import torch
import cv2
import os
import imageio
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import LineartDetector
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from NaRCan_model import Homography, Siren
from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting

def get_example():
    # Example videos shown in the gr.Examples gallery at the bottom of the demo.
    case = [
        ['examples/bear.mp4'],
        ['examples/boat.mp4'],
        ['examples/woman-drink.mp4'],
        ['examples/corgi.mp4'],
        ['examples/yacht.mp4'],
        ['examples/koolshooters.mp4'],
        ['examples/overlook-the-ocean.mp4'],
        ['examples/rotate.mp4'],
        ['examples/shark-ocean.mp4'],
        ['examples/surf.mp4'],
        ['examples/cactus.mp4'],
        ['examples/gold-fish.mp4'],
    ]
    return case

def set_default_prompt(video_name):
    # Default edit prompt for each bundled example video.
    video_to_prompt = {
        'bear.mp4': 'bear, Van Gogh Style',
        'boat.mp4': 'a burning boat sails on lava',
        'cactus.mp4': 'cactus, made of paper',
        'corgi.mp4': 'a hellhound',
        'gold-fish.mp4': 'Goldfish in the Milky Way',
        'koolshooters.mp4': 'Avatar',
        'overlook-the-ocean.mp4': 'ocean, pixel style',
        'rotate.mp4': 'turbine engine',
        'shark-ocean.mp4': 'A sleek shark, cartoon style',
        'surf.mp4': 'Sailing, the background is a large white cloud, sketch style',
        'woman-drink.mp4': 'a drinking zombie',
        'yacht.mp4': 'yacht, cyberpunk style',
    }
    return video_to_prompt.get(video_name, '')


def update_prompt(input_video):
    # Gradio passes the selected video as a file path; key the table on the basename.
    video_name = os.path.basename(input_video)
    return set_default_prompt(video_name)

# Map each example video to its pre-computed assets:
# [canonical image, NaRCan checkpoint directory, extracted-frames directory].
video_to_image = {
    'bear.mp4': ['canonical/bear.png', 'pth_file/bear', 'examples_frames/bear'],
    'boat.mp4': ['canonical/boat.png', 'pth_file/boat', 'examples_frames/boat'],
    'cactus.mp4': ['canonical/cactus.png', 'pth_file/cactus', 'examples_frames/cactus'],
    'corgi.mp4': ['canonical/corgi.png', 'pth_file/corgi', 'examples_frames/corgi'],
    'gold-fish.mp4': ['canonical/gold-fish.png', 'pth_file/gold-fish', 'examples_frames/gold-fish'],
    'koolshooters.mp4': ['canonical/koolshooters.png', 'pth_file/koolshooters', 'examples_frames/koolshooters'],
    'overlook-the-ocean.mp4': ['canonical/overlook-the-ocean.png', 'pth_file/overlook-the-ocean', 'examples_frames/overlook-the-ocean'],
    'rotate.mp4': ['canonical/rotate.png', 'pth_file/rotate', 'examples_frames/rotate'],
    'shark-ocean.mp4': ['canonical/shark-ocean.png', 'pth_file/shark-ocean', 'examples_frames/shark-ocean'],
    'surf.mp4': ['canonical/surf.png', 'pth_file/surf', 'examples_frames/surf'],
    'woman-drink.mp4': ['canonical/woman-drink.png', 'pth_file/woman-drink', 'examples_frames/woman-drink'],
    'yacht.mp4': ['canonical/yacht.png', 'pth_file/yacht', 'examples_frames/yacht'],
}

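# On-disk layout assumed by the mapping above (inferred from the checkpoint names
# loaded in NaRCan_make_video; it is not documented elsewhere in this file):
#   canonical/<name>.png               pre-computed canonical image
#   pth_file/<name>/homography_g.pth   trained Homography weights
#   pth_file/<name>/mlp_g.pth          trained Siren residual-MLP weights
#   examples_frames/<name>/            frames extracted from examples/<name>.mp4
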
def images_to_video(image_list, output_path, fps=10):
    # Convert PIL Images (or arrays) to uint8 numpy frames.
    frames = [np.array(img).astype(np.uint8) for img in image_list]
    # Cap the clip at 20 frames to keep encoding time short.
    frames = frames[:20]
    # Encode with H.264 via imageio/ffmpeg.
    writer = imageio.get_writer(output_path, fps=fps, codec='libx264')
    for frame in frames:
        writer.append_data(frame)
    writer.close()

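# Usage sketch (hypothetical frames, not part of the app flow): any sequence of
# same-sized RGB PIL Images or uint8 arrays can be written out, e.g.
#   frames = [Image.new('RGB', (512, 512), (i, i, i)) for i in range(20)]
#   images_to_video(frames, 'preview.mp4', fps=10)
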
def NaRCan_make_video(edit_canonical, pth_path, frames_path):
    # Load the pre-trained NaRCan deformation field: a Homography network for the
    # coarse motion plus a Siren MLP for the residual deformation.
    checkpoint_g_old = torch.load(os.path.join(pth_path, "homography_g.pth"))
    checkpoint_g = torch.load(os.path.join(pth_path, "mlp_g.pth"))
    g_old = Homography(hidden_features=256, hidden_layers=2).cuda()
    g = Siren(in_features=3, out_features=2, hidden_features=256,
              hidden_layers=5, outermost_linear=True).cuda()
    g_old.load_state_dict(checkpoint_g_old)
    g.load_state_dict(checkpoint_g)
    g_old.eval()
    g.eval()

    transform = Compose([
        Resize(512),
        ToTensor(),
        Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5])),
    ])
    v = TestVideoFitting(frames_path, transform)
    videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)

    model_input, ground_truth = next(iter(videoloader))
    model_input, ground_truth = model_input[0].cuda(), ground_truth[0].cuda()

    # The edited canonical image is converted once and sampled for every frame.
    canonical_img = np.array(edit_canonical.convert('RGB'))
    canonical_img = torch.from_numpy(canonical_img).float().cuda()
    if len(canonical_img.shape) == 3:
        canonical_img = canonical_img.unsqueeze(0)  # (1, H_c, W_c, 3)

    myoutput = None
    data_len = len(os.listdir(frames_path))
    with torch.no_grad():
        batch_size = v.H * v.W
        for step in range(data_len):
            start = (step * batch_size) % len(model_input)
            end = min(start + batch_size, len(model_input))
            # Deformation for this chunk: coarse homography plus Siren residual.
            xy, t = model_input[start:end, :-1], model_input[start:end, [-1]]
            xyt = model_input[start:end]
            h_old = apply_homography(xy, g_old(t))
            residual = g(xyt)
            xy_ = h_old + residual
            # Map the deformed coordinates into grid_sample's [-1, 1] convention;
            # the fixed 1.5 and 2.0 divisors rescale canonical-space coordinates
            # into that range.
            grid_new = xy_.clone()
            grid_new[..., 1] = xy_[..., 0] / 1.5
            grid_new[..., 0] = xy_[..., 1] / 2.0
            # Reconstruct the frame by sampling the edited canonical image.
            results = torch.nn.functional.grid_sample(
                canonical_img.permute(0, 3, 1, 2),
                grid_new.unsqueeze(1).unsqueeze(0),
                mode='bilinear',
                padding_mode='border')
            o = results.squeeze().permute(1, 0)
            myoutput = o if step == 0 else torch.cat([myoutput, o])

    myoutput = myoutput.reshape(512, 512, data_len, 3).permute(2, 0, 1, 3).clone().detach().cpu().numpy().astype(np.float32)
    frames = [Image.fromarray(np.uint8(f)).resize((512, 512)) for f in myoutput]  # alternative size: 854x480
    edit_video_path = 'NaRCan_fps_10.mp4'
    images_to_video(frames, edit_video_path)
    return edit_video_path

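# Note on the reshape above (an assumption, not stated in this file): TestVideoFitting
# must enumerate its coordinate grid with time as the fastest-varying axis, so the
# concatenated samples unpack as (512, 512, T, 3) and permute cleanly to per-frame
# (T, 512, 512, 3).
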
def edit_with_pnp(input_video, prompt, num_steps, guidance_scale, seed, n_prompt, control_type="Lineart"):
    video_name = os.path.basename(input_video)
    if video_name not in video_to_image:
        return None
    image_path, pth_path, frames_path = video_to_image[video_name]

    if control_type == "Lineart":
        # ControlNet conditioned on lineart edges.
        controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16)
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        processor_partial = partial(processor, coarse=False)
        size_ = 768
        canonical_image = Image.open(image_path)
        ori_size = canonical_image.size
        image = processor_partial(canonical_image.resize((size_, size_)), detect_resolution=size_, image_resolution=size_)
        image = image.resize(ori_size, resample=Image.BILINEAR)
    else:
        # ControlNet conditioned on Canny edges.
        controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16)
        canonical_image = cv2.imread(image_path)
        canonical_image = cv2.cvtColor(canonical_image, cv2.COLOR_BGR2RGB)
        image = cv2.cvtColor(canonical_image, cv2.COLOR_RGB2GRAY)
        image = cv2.Canny(image, 100, 200)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        image = Image.fromarray(image)

    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
    )
    pipe.to("cuda")

    # Edit the canonical image with ControlNet; a seed of -1 leaves the run unseeded.
    generator = torch.manual_seed(seed) if seed != -1 else None
    output_images = pipe(
        prompt=prompt,
        image=image,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        negative_prompt=n_prompt,
        generator=generator,
    ).images

    # Propagate the edited canonical image to every frame via the NaRCan deformation field.
    edit_video_path = NaRCan_make_video(output_images[0], pth_path, frames_path)
    return edit_video_path

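# Performance sketch, not wired into the app: edit_with_pnp rebuilds the Stable
# Diffusion pipeline on every click, which dominates the edit latency. A cached
# loader along these lines could reuse it; _PIPE_CACHE and _get_pipe are
# illustrative names, not existing identifiers in this repo.
_PIPE_CACHE = {}

def _get_pipe(controlnet_id):
    # Build the pipeline once per ControlNet checkpoint, then reuse it across calls.
    if controlnet_id not in _PIPE_CACHE:
        controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
        _PIPE_CACHE[controlnet_id] = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
        ).to("cuda")
    return _PIPE_CACHE[controlnet_id]
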
########
# demo #
########
intro = """
<div style="text-align:center">
    <h1 style="font-weight: 700; text-align: center; margin-bottom: 7px;">
        NaRCan - <small>Natural Refined Canonical Image</small>
    </h1>
    <span>[<a target="_blank" href="https://koi953215.github.io/NaRCan_page/">Project page</a>], [<a target="_blank" href="https://huggingface.co/papers/2406.06523">Paper</a>]</span>
    <div style="display:flex; justify-content: center; margin-top: 0.5em">Each edit takes ~10 sec</div>
</div>
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)

    with gr.Row():
        input_video = gr.Video(label="Input Video", interactive=False, elem_id="input_video", value='examples/bear.mp4')
        output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video")
    input_video.style(height=365, width=365)
    output_video.style(height=365, width=365)

    with gr.Row():
        prompt = gr.Textbox(
            label="Describe your edited video",
            max_lines=1,
            value="bear, Van Gogh Style"
        )

    with gr.Row():
        run_button = gr.Button("Edit your video!", visible=True)

    with gr.Accordion('Advanced options', open=False):
        control_type = gr.Dropdown(
            ["Canny", "Lineart"],
            label="Control Type",
            info="Canny or Lineart",
            value="Lineart"
        )
        num_steps = gr.Slider(label='Steps',
                              minimum=1,
                              maximum=100,
                              value=20,
                              step=1)
        guidance_scale = gr.Slider(label='Guidance Scale',
                                   minimum=0.1,
                                   maximum=30.0,
                                   value=9.0,
                                   step=0.1)
        seed = gr.Slider(label='Seed',
                         minimum=-1,
                         maximum=2147483647,
                         step=1,
                         randomize=True)
        n_prompt = gr.Textbox(
            label='Negative Prompt',
            value=""
        )

    # Refresh the default prompt whenever a different example video is loaded.
    input_video.change(
        fn=update_prompt,
        inputs=[input_video],
        outputs=[prompt],
        queue=False)

    run_button.click(fn=edit_with_pnp,
                     inputs=[input_video,
                             prompt,
                             num_steps,
                             guidance_scale,
                             seed,
                             n_prompt,
                             control_type,
                             ],
                     outputs=[output_video]
                     )

    gr.Examples(
        examples=get_example(),
        label='Examples',
        inputs=[input_video],
        outputs=[output_video],
        examples_per_page=8
    )

demo.queue()
demo.launch(share=True)