import os
import shutil
from enum import Enum

import cv2
import einops
import gradio as gr
import huggingface_hub
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from blendmodes.blend import BlendType, blendLayers
from PIL import Image
from pytorch_lightning import seed_everything
from safetensors.torch import load_file
from skimage import exposure

import src.import_util  # noqa: F401 (imported for its side effects)
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.hed import HEDdetector
from ControlNet.annotator.util import HWC3
from ControlNet.cldm.model import create_model, load_state_dict
from gmflow_module.gmflow.gmflow import GMFlow
from flow.flow_utils import get_warped_and_mask
from sd_model_cfg import model_dict
from src.config import RerenderConfig
from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor
from src.video_util import (frame_to_video, get_fps, get_frame_count,
                            prepare_frames)

repo_name = 'Anonymous-sub/Rerender'

# Download the demo videos from the Space repository.
huggingface_hub.hf_hub_download(repo_name,
                                'pexels-koolshooters-7322716.mp4',
                                local_dir='videos')
huggingface_hub.hf_hub_download(
    repo_name,
    'pexels-antoni-shkraba-8048492-540x960-25fps.mp4',
    local_dir='videos')
huggingface_hub.hf_hub_download(
    repo_name,
    'pexels-cottonbro-studio-6649832-960x506-25fps.mp4',
    local_dir='videos')

inversed_model_dict = dict()
for k, v in model_dict.items():
    inversed_model_dict[v] = k
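
# `model_dict` (defined in sd_model_cfg.py) maps the display names offered in
# the "Base model" dropdown to checkpoint paths inside this Space's repository;
# an empty path keeps the base Stable Diffusion 1.5 weights. The entry below is
# only an illustration of the expected shape (the path is hypothetical), not
# the actual contents of sd_model_cfg.py:
#
#     model_dict = {
#         'Stable Diffusion 1.5': '',
#         'revAnimated_v11': 'models/revAnimated_v11.safetensors',
#     }
#
# `inversed_model_dict` recovers the display name from a checkpoint path when
# an example config is loaded in cfg_to_input.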
					
						
to_tensor = T.PILToTensor()
blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class ProcessingState(Enum):
    NULL = 0
    FIRST_IMG = 1
    KEY_IMGS = 2


MAX_KEYFRAME = 8
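# The hosted demo translates at most MAX_KEYFRAME key frames per run; the
# input callbacks near the bottom of this file clamp the sliders to this value.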
					
						
class GlobalState:

    def __init__(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL
        # Set by process1 when color preservation is off; initialized here so
        # process2 can safely test it before the first run.
        self.color_corrections = None
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(device)

        checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.flow_model = flow_model

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    def update_sd_model(self, sd_model, control_type):
        if sd_model == self.sd_model:
            return
        self.sd_model = sd_model
        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        if control_type == 'HED':
            model.load_state_dict(
                load_state_dict(
                    huggingface_hub.hf_hub_download(
                        'lllyasviel/ControlNet',
                        './models/control_sd15_hed.pth'),
                    location=device))
        elif control_type == 'canny':
            model.load_state_dict(
                load_state_dict(
                    huggingface_hub.hf_hub_download(
                        'lllyasviel/ControlNet',
                        'models/control_sd15_canny.pth'),
                    location=device))
        model.to(device)
        sd_model_path = model_dict[sd_model]
        if len(sd_model_path) > 0:
            model_ext = os.path.splitext(sd_model_path)[1]
            downloaded_model = huggingface_hub.hf_hub_download(
                repo_name, sd_model_path)
            if model_ext == '.safetensors':
                model.load_state_dict(load_file(downloaded_model),
                                      strict=False)
            elif model_ext == '.ckpt' or model_ext == '.pth':
                model.load_state_dict(
                    torch.load(downloaded_model)['state_dict'], strict=False)

        try:
            model.first_stage_model.load_state_dict(
                torch.load(
                    huggingface_hub.hf_hub_download(
                        'stabilityai/sd-vae-ft-mse-original',
                        'vae-ft-mse-840000-ema-pruned.ckpt'))['state_dict'],
                strict=False)
        except Exception:
            print('Warning: We suggest you download the fine-tuned VAE',
                  'otherwise the generation quality will be degraded')

        self.ddim_v_sampler = DDIMVSampler(model)

    def clear_sd_model(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        if self.detector_type == control_type:
            return
        if control_type == 'HED':
            self.detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()
            low_threshold = canny_low
            high_threshold = canny_high

            def apply_canny(x):
                return canny_detector(x, low_threshold, high_threshold)

            self.detector = apply_canny
        # Remember the active detector type so the cache check above works.
        self.detector_type = control_type


global_state = GlobalState()
global_video_path = None
video_frame_count = None
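# `global_state` is shared by every request; process1/process2 refresh its SD
# model, controller and edge detector from the current UI settings before
# sampling. `global_video_path` and `video_frame_count` track the currently
# selected input video.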
					
						
def create_cfg(input_path, prompt, image_resolution, control_strength,
               color_preserve, left_crop, right_crop, top_crop, bottom_crop,
               control_type, low_threshold, high_threshold, ddim_steps, scale,
               seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
               x0_strength, use_constraints, cross_start, cross_end,
               style_update_freq, warp_start, warp_end, mask_start, mask_end,
               ada_start, ada_end, mask_strength, inner_strength,
               smooth_boundary):
    use_warp = 'shape-aware fusion' in use_constraints
    use_mask = 'pixel-aware fusion' in use_constraints
    use_ada = 'color-aware AdaIN' in use_constraints

    if not use_warp:
        warp_start = 1
        warp_end = 0

    if not use_mask:
        mask_start = 1
        mask_end = 0

    if not use_ada:
        ada_start = 1
        ada_end = 0
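
    # A (start, end) period with start > end, e.g. (1, 0), never covers any
    # denoising step, so leaving a constraint unchecked effectively disables it.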
					
						
    input_name = os.path.split(input_path)[-1].split('.')[0]
    frame_count = 2 + keyframe_count * interval
    cfg = RerenderConfig()
    cfg.create_from_parameters(
        input_path,
        os.path.join('result', input_name, 'blend.mp4'),
        prompt,
        a_prompt=a_prompt,
        n_prompt=n_prompt,
        frame_count=frame_count,
        interval=interval,
        crop=[left_crop, right_crop, top_crop, bottom_crop],
        sd_model=sd_model,
        ddim_steps=ddim_steps,
        scale=scale,
        control_type=control_type,
        control_strength=control_strength,
        canny_low=low_threshold,
        canny_high=high_threshold,
        seed=seed,
        image_resolution=image_resolution,
        x0_strength=x0_strength,
        style_update_freq=style_update_freq,
        cross_period=(cross_start, cross_end),
        warp_period=(warp_start, warp_end),
        mask_period=(mask_start, mask_end),
        ada_period=(ada_start, ada_end),
        mask_strength=mask_strength,
        inner_strength=inner_strength,
        smooth_boundary=smooth_boundary,
        color_preserve=color_preserve)
    return cfg


def cfg_to_input(filename):

    cfg = RerenderConfig()
    cfg.create_from_path(filename)
    keyframe_count = (cfg.frame_count - 2) // cfg.interval
    use_constraints = [
        'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN'
    ]

    sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5')

    args = [
        cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength,
        cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low,
        cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model,
        cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count,
        cfg.x0_strength, use_constraints, *cfg.cross_period,
        cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period,
        *cfg.ada_period, cfg.mask_strength, cfg.inner_strength,
        cfg.smooth_boundary
    ]
    return args
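
# NOTE: the list returned by cfg_to_input must stay aligned with the `ips`
# input list assembled in the UI below, since gr.Examples passes these values
# to the process functions positionally.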
					
						


def setup_color_correction(image):
    correction_target = cv2.cvtColor(np.asarray(image.copy()),
                                     cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, original_image):
    image = Image.fromarray(
        cv2.cvtColor(
            exposure.match_histograms(cv2.cvtColor(np.asarray(original_image),
                                                   cv2.COLOR_RGB2LAB),
                                      correction,
                                      channel_axis=2),
            cv2.COLOR_LAB2RGB).astype('uint8'))

    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image
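
# The two helpers above implement the color correction used when 'Preserve
# color' is off: frames are histogram-matched (in LAB space) to the first
# stylized frame, and the original frame's luminosity is blended back in to
# keep its detail.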
					
						
@torch.no_grad()
def process(*args):
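    """Run the full pipeline: translate the first frame, then all key frames.

    Returns the stylized first frame (numpy image) and the path of the
    rendered key-frame video.
    """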
					
						
    first_frame = process1(*args)

    keypath = process2(*args)

    return first_frame, keypath


@torch.no_grad()
def process0(*args):
    global global_video_path
    global_video_path = args[0]
    return process(*args[1:])


@torch.no_grad()
def process1(*args):
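    """Translate the first frame of the input video.

    Builds a RerenderConfig from the UI arguments, refreshes the shared
    GlobalState (SD model, attention controller, edge detector), extracts the
    video frames, and renders the first key frame, which is also stored on
    global_state for the later key-frame pass.
    """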
					
						
    global global_video_path
    cfg = create_cfg(global_video_path, *args)
    global global_state
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_controller(cfg.inner_strength, cfg.mask_period,
                                   cfg.cross_period, cfg.ada_period,
                                   cfg.warp_period)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.FIRST_IMG

    prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution,
                   cfg.crop)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    model.control_scales = [cfg.control_strength] * 13
    model.to(device)

    num_samples = 1
    eta = 0.0
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    model.cond_stage_model.device = device

    with torch.no_grad():
        frame = cv2.imread(imgs[0])
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, C = img.shape

        img_ = numpy2tensor(img)

        def generate_first_img(img_, strength):
            encoder_posterior = model.encode_first_stage(img_.to(device))
            x0 = model.get_first_stage_encoding(encoder_posterior).detach()

            detected_map = detector(img)
            detected_map = HWC3(detected_map)

            control = torch.from_numpy(
                detected_map.copy()).float().to(device) / 255.0
            control = torch.stack([control for _ in range(num_samples)],
                                  dim=0)
            control = einops.rearrange(control, 'b h w c -> b c h w').clone()
            cond = {
                'c_concat': [control],
                'c_crossattn': [
                    model.get_learned_conditioning(
                        [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
                ]
            }
            un_cond = {
                'c_concat': [control],
                'c_crossattn':
                [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
            }
            shape = (4, H // 8, W // 8)

            controller.set_task('initfirst')
            seed_everything(cfg.seed)

            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=strength)
            x_samples = model.decode_first_stage(samples)
            x_samples_np = (
                einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            return x_samples, x_samples_np
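
        # generate_first_img (above) encodes the frame, extracts the HED/Canny
        # control map, and runs ControlNet-guided DDIM sampling; it returns
        # both the decoded image tensor and a uint8 numpy copy for display.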
					
						
        if not cfg.color_preserve:
            first_strength = -1
        else:
            first_strength = 1 - cfg.x0_strength

        x_samples, x_samples_np = generate_first_img(img_, first_strength)

        if not cfg.color_preserve:
            color_corrections = setup_color_correction(
                Image.fromarray(x_samples_np[0]))
            global_state.color_corrections = color_corrections
            img_ = apply_color_correction(color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
            x_samples, x_samples_np = generate_first_img(
                img_, 1 - cfg.x0_strength)

        global_state.first_result = x_samples
        global_state.first_img = img

    Image.fromarray(x_samples_np[0]).save(
        os.path.join(cfg.first_dir, 'first.jpg'))

    return x_samples_np[0]


@torch.no_grad()
def process2(*args):
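    """Translate all key frames, reusing the first frame produced by process1.

    Each key frame is generated with cross-frame attention plus the selected
    shape-aware, pixel-aware and color-aware constraints, then written to
    cfg.key_dir and assembled into key.mp4.
    """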
					
						
    global global_state
    global global_video_path

    if global_state.processing_state != ProcessingState.FIRST_IMG:
        raise gr.Error('Please generate the first key image before generating'
                       ' all key images')

    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.KEY_IMGS

    shutil.rmtree(cfg.key_dir)
    os.makedirs(cfg.key_dir, exist_ok=True)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    flow_model = global_state.flow_model
    model.control_scales = [cfg.control_strength] * 13

    num_samples = 1
    eta = 0.0
    firstx0 = True
    pixelfusion = cfg.use_mask
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    first_result = global_state.first_result
    first_img = global_state.first_img
    pre_result = first_result
    pre_img = first_img

    for i in range(0, cfg.frame_count - 1, cfg.interval):
        cid = i + 1
        frame = cv2.imread(imgs[i + 1])
        print(cid)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, C = img.shape

        if cfg.color_preserve or global_state.color_corrections is None:
            img_ = numpy2tensor(img)
        else:
            img_ = apply_color_correction(global_state.color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
        encoder_posterior = model.encode_first_stage(img_.to(device))
        x0 = model.get_first_stage_encoding(encoder_posterior).detach()

        detected_map = detector(img)
        detected_map = HWC3(detected_map)

        control = torch.from_numpy(
            detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        cond = {
            'c_concat': [control],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }
        shape = (4, H // 8, W // 8)

        cond['c_concat'] = [control]
        un_cond['c_concat'] = [control]

        image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
        image2 = torch.from_numpy(img).permute(2, 0, 1).float()
        warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
            flow_model, image1, image2, pre_result, False)
        blend_mask_pre = blur(
            F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
        blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)

        image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
        warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
            flow_model, image1, image2, first_result, False)
        blend_mask_0 = blur(
            F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
        blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
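
        # Warp the previous key-frame result and the first-frame result onto
        # the current frame with GMFlow optical flow; bwd_occ_* mark occluded
        # pixels, and the max-pooled, blurred masks feather those regions for
        # the blending below.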
					
						
        if firstx0:
            mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_0 / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)
        else:
            mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_pre / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)

        controller.set_task('keepx0, keepstyle')
        seed_everything(cfg.seed)
        samples, intermediates = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=1 - cfg.x0_strength)
        direct_result = model.decode_first_stage(samples)

        if not pixelfusion:
            pre_result = direct_result
            pre_img = img
            viz = (
                einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        else:
            blend_results = ((1 - blend_mask_pre) * warped_pre +
                             blend_mask_pre * direct_result)
            blend_results = ((1 - blend_mask_0) * warped_0 +
                             blend_mask_0 * blend_results)

            bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
            blend_mask = blur(
                F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
            blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)
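
            # Pixel-aware fusion: keep the warped previous/first results where
            # the flow is reliable and fall back to the freshly sampled
            # direct_result in occluded regions; blend_mask marks the reliable
            # (non-occluded) area after feathering.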
					
						
            encoder_posterior = model.encode_first_stage(blend_results)
            xtrg = model.get_first_stage_encoding(
                encoder_posterior).detach()
            blend_results_rec = model.decode_first_stage(xtrg)
            encoder_posterior = model.encode_first_stage(blend_results_rec)
            xtrg_rec = model.get_first_stage_encoding(
                encoder_posterior).detach()
            xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec))
            blend_results_rec_new = model.decode_first_stage(xtrg_)
            tmp = (abs(blend_results_rec_new - blend_results).mean(
                dim=1, keepdims=True) > 0.25).float()
            mask_x = F.max_pool2d((F.interpolate(
                tmp, scale_factor=1 / 8., mode='bilinear') > 0).float(),
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

            mask = 1 - F.max_pool2d(1 - blend_mask, kernel_size=8)

            if cfg.smooth_boundary:
                noise_rescale = find_flat_region(mask)
            else:
                noise_rescale = torch.ones_like(mask)
            masks = []
            # Use a separate loop variable here so the key-frame index `i`
            # (needed for the style update check below) is not overwritten.
            for ddim_i in range(cfg.ddim_steps):
                if ddim_i <= cfg.ddim_steps * cfg.mask_period[
                        0] or ddim_i >= cfg.ddim_steps * cfg.mask_period[1]:
                    masks += [None]
                else:
                    masks += [mask * cfg.mask_strength]

            xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask

            tasks = 'keepstyle, keepx0'
            if not firstx0:
                tasks += ', updatex0'
            if i % cfg.style_update_freq == 0:
                tasks += ', updatestyle'
            controller.set_task(tasks, 1.0)

            seed_everything(cfg.seed)
            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=1 - cfg.x0_strength,
                xtrg=xtrg,
                mask=masks,
                noise_rescale=noise_rescale)
            x_samples = model.decode_first_stage(samples)
            pre_result = x_samples
            pre_img = img

            viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                   127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        Image.fromarray(viz[0]).save(
            os.path.join(cfg.key_dir, f'{cid:04d}.png'))

    key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
    fps = get_fps(cfg.input_path)
    fps //= cfg.interval
    frame_to_video(key_video_path, cfg.key_dir, fps, False)

    return key_video_path


DESCRIPTION = '''
## Rerender A Video
### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
### To avoid overload, we limit the maximum number of key frames (8) and the maximum frame resolution (512x768).
### Rendering a 512x640 video takes about 1 minute per key frame on a T4 GPU.
### Tips:
1. This method cannot handle large or quick motions where the optical flow is hard to estimate. **Videos with stable motions are preferred**.
2. Pixel-aware fusion may not work for large or quick motions.
3. Try different color-aware AdaIN settings, or even disable it, to avoid color jittering.
4. Use the `revAnimated_v11` model for non-photorealistic styles and the `realisticVisionV20_v20` model for photorealistic styles.
5. To use your own SD/LoRA model, you may clone the space and specify your model with [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py).
6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffusers/Automatic1111 models to the original format.
'''

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_path = gr.Video(label='Input Video',
                                  source='upload',
                                  format='mp4',
                                  visible=True)
            prompt = gr.Textbox(label='Prompt')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=2147483647,
                             step=1,
                             value=0,
                             randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Row():
                run_button1 = gr.Button(value='Run 1st Key Frame')
                run_button2 = gr.Button(value='Run Key Frames')
                run_button3 = gr.Button(value='Run Propagation')
            with gr.Accordion('Advanced options for the 1st frame translation',
                              open=False):
                image_resolution = gr.Slider(
                    label='Frame resolution',
                    minimum=256,
                    maximum=512,
                    value=512,
                    step=64,
                    info='To avoid overload, maximum 512')
                control_strength = gr.Slider(label='ControlNet strength',
                                             minimum=0.0,
                                             maximum=2.0,
                                             value=1.0,
                                             step=0.01)
                x0_strength = gr.Slider(
                    label='Denoising strength',
                    minimum=0.00,
                    maximum=1.05,
                    value=0.75,
                    step=0.05,
                    info=('0: fully recover the input. '
                          '1.05: fully rerender the input.'))
                color_preserve = gr.Checkbox(
                    label='Preserve color',
                    value=True,
                    info='Keep the color of the input video')
                with gr.Row():
                    left_crop = gr.Slider(label='Left crop length',
                                          minimum=0,
                                          maximum=512,
                                          value=0,
                                          step=1)
                    right_crop = gr.Slider(label='Right crop length',
                                           minimum=0,
                                           maximum=512,
                                           value=0,
                                           step=1)
                with gr.Row():
                    top_crop = gr.Slider(label='Top crop length',
                                         minimum=0,
                                         maximum=512,
                                         value=0,
                                         step=1)
                    bottom_crop = gr.Slider(label='Bottom crop length',
                                            minimum=0,
                                            maximum=512,
                                            value=0,
                                            step=1)
                with gr.Row():
                    control_type = gr.Dropdown(['HED', 'canny'],
                                               label='Control type',
                                               value='HED')
                    low_threshold = gr.Slider(label='Canny low threshold',
                                              minimum=1,
                                              maximum=255,
                                              value=100,
                                              step=1)
                    high_threshold = gr.Slider(label='Canny high threshold',
                                               minimum=1,
                                               maximum=255,
                                               value=200,
                                               step=1)
                ddim_steps = gr.Slider(label='Steps',
                                       minimum=1,
                                       maximum=20,
                                       value=20,
                                       step=1,
                                       info='To avoid overload, maximum 20')
                scale = gr.Slider(label='CFG scale',
                                  minimum=0.1,
                                  maximum=30.0,
                                  value=7.5,
                                  step=0.1)
                sd_model_list = list(model_dict.keys())
                sd_model = gr.Dropdown(sd_model_list,
                                       label='Base model',
                                       value='Stable Diffusion 1.5')
                a_prompt = gr.Textbox(label='Added prompt',
                                      value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label='Negative prompt',
                    value=('longbody, lowres, bad anatomy, bad hands, '
                           'missing fingers, extra digit, fewer digits, '
                           'cropped, worst quality, low quality'))
            with gr.Accordion('Advanced options for the key frame translation',
                              open=False):
                interval = gr.Slider(
                    label='Key frame frequency (K)',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='Uniformly sample the key frames every K frames')
                keyframe_count = gr.Slider(
                    label='Number of key frames',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='To avoid overload, maximum 8 key frames')

                use_constraints = gr.CheckboxGroup(
                    [
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ],
                    label='Select the cross-frame constraints to be used',
                    value=[
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ])
                with gr.Row():
                    cross_start = gr.Slider(
                        label='Cross-frame attention start',
                        minimum=0,
                        maximum=1,
                        value=0,
                        step=0.05)
                    cross_end = gr.Slider(label='Cross-frame attention end',
                                          minimum=0,
                                          maximum=1,
                                          value=1,
                                          step=0.05)
                    style_update_freq = gr.Slider(
                        label='Cross-frame attention update frequency',
                        minimum=1,
                        maximum=100,
                        value=1,
                        step=1,
                        info=('Update the key and value for '
                              'cross-frame attention every N key frames '
                              '(recommend N*K>=10)'))
                with gr.Row():
                    warp_start = gr.Slider(label='Shape-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0,
                                           step=0.05)
                    warp_end = gr.Slider(label='Shape-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.1,
                                         step=0.05)
                with gr.Row():
                    mask_start = gr.Slider(label='Pixel-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0.5,
                                           step=0.05)
                    mask_end = gr.Slider(label='Pixel-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.8,
                                         step=0.05)
                with gr.Row():
                    ada_start = gr.Slider(label='Color-aware AdaIN start',
                                          minimum=0,
                                          maximum=1,
                                          value=0.8,
                                          step=0.05)
                    ada_end = gr.Slider(label='Color-aware AdaIN end',
                                        minimum=0,
                                        maximum=1,
                                        value=1,
                                        step=0.05)
                mask_strength = gr.Slider(label='Pixel-aware fusion strength',
                                          minimum=0,
                                          maximum=1,
                                          value=0.5,
                                          step=0.01)
                inner_strength = gr.Slider(
                    label='Pixel-aware fusion detail level',
                    minimum=0.5,
                    maximum=1,
                    value=0.9,
                    step=0.01,
                    info='Use a low value to prevent artifacts')
                smooth_boundary = gr.Checkbox(
                    label='Smooth fusion boundary',
                    value=True,
                    info='Select to prevent artifacts at boundary')

            with gr.Accordion('Example configs', open=True):
                config_dir = 'config'
                config_list = os.listdir(config_dir)
                args_list = []
                for config in config_list:
                    try:
                        config_path = os.path.join(config_dir, config)
                        args = cfg_to_input(config_path)
                        args_list.append(args)
                    except FileNotFoundError:
                        # Skip configs whose referenced files are missing.
                        pass

            ips = [
                prompt, image_resolution, control_strength, color_preserve,
                left_crop, right_crop, top_crop, bottom_crop, control_type,
                low_threshold, high_threshold, ddim_steps, scale, seed,
                sd_model, a_prompt, n_prompt, interval, keyframe_count,
                x0_strength, use_constraints, cross_start, cross_end,
                style_update_freq, warp_start, warp_end, mask_start,
                mask_end, ada_start, ada_end, mask_strength,
                inner_strength, smooth_boundary
            ]

        with gr.Column():
            result_image = gr.Image(label='Output first frame',
                                    type='numpy',
                                    interactive=False)
            result_keyframe = gr.Video(label='Output key frame video',
                                       format='mp4',
                                       interactive=False)
    with gr.Row():
        gr.Examples(examples=args_list,
                    inputs=[input_path, *ips],
                    fn=process0,
                    outputs=[result_image, result_keyframe],
                    cache_examples=True)

    def input_uploaded(path):
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            raise gr.Error('The input video is too short! '
                           'Please input another video.')

        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME)

        global video_frame_count
        video_frame_count = frame_count
        global global_video_path
        global_video_path = path

        return gr.Slider.update(value=default_interval,
                                maximum=MAX_KEYFRAME), gr.Slider.update(
                                    value=max_keyframe, maximum=max_keyframe)

    def input_changed(path):
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)

        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME)

        global video_frame_count
        video_frame_count = frame_count
        global global_video_path
        global_video_path = path

        return gr.Slider.update(maximum=max_keyframe), \
            gr.Slider.update(maximum=max_keyframe)

    def interval_changed(interval):
        global video_frame_count
        if video_frame_count is None:
            return gr.Slider.update()

        max_keyframe = (video_frame_count - 2) // interval

        return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)

    input_path.change(input_changed, input_path, [interval, keyframe_count])
    input_path.upload(input_uploaded, input_path, [interval, keyframe_count])
    interval.change(interval_changed, interval, keyframe_count)
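    # The three bindings above keep the 'Key frame frequency' and 'Number of
    # key frames' sliders consistent with the length of the selected video
    # (capped at MAX_KEYFRAME).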
					
						
    run_button.click(fn=process,
                     inputs=ips,
                     outputs=[result_image, result_keyframe])
    run_button1.click(fn=process1, inputs=ips, outputs=[result_image])
    run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])

    def process3():
        raise gr.Error('Coming Soon. Full code for full video translation '
                       'will be released upon the publication of the paper.')

    run_button3.click(fn=process3, outputs=[result_keyframe])

block.launch(server_name='0.0.0.0')