import os
import shutil
from enum import Enum

import cv2
import einops
import gradio as gr
import huggingface_hub
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from blendmodes.blend import BlendType, blendLayers
from PIL import Image
from pytorch_lightning import seed_everything
from safetensors.torch import load_file
from skimage import exposure

import src.import_util  # noqa: F401
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.hed import HEDdetector
from ControlNet.annotator.util import HWC3
from ControlNet.cldm.model import create_model, load_state_dict
from flow.flow_utils import get_warped_and_mask
from gmflow_module.gmflow.gmflow import GMFlow
from sd_model_cfg import model_dict
from src.config import RerenderConfig
from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor
from src.video_util import (frame_to_video, get_fps, get_frame_count,
                            prepare_frames)

repo_name = 'Anonymous-sub/Rerender'

huggingface_hub.hf_hub_download(repo_name,
                                'pexels-koolshooters-7322716.mp4',
                                local_dir='videos')
huggingface_hub.hf_hub_download(
    repo_name,
    'pexels-antoni-shkraba-8048492-540x960-25fps.mp4',
    local_dir='videos')
huggingface_hub.hf_hub_download(
    repo_name,
    'pexels-cottonbro-studio-6649832-960x506-25fps.mp4',
    local_dir='videos')

inversed_model_dict = {v: k for k, v in model_dict.items()}

to_tensor = T.PILToTensor()
blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
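
# Frames move between HWC uint8 numpy images in [0, 255] and NCHW float
# tensors in [-1, 1]. A minimal sketch of that convention (illustrative
# helper only; the code below inlines the same `/ 127.5 - 1` arithmetic):
def _to_diffusion_range(img_uint8):
    """Convert an HWC uint8 image to a 1CHW float tensor in [-1, 1]."""
    tensor = torch.from_numpy(img_uint8).permute(2, 0, 1).float()
    return tensor.unsqueeze(0) / 127.5 - 1
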

class ProcessingState(Enum):
    NULL = 0
    FIRST_IMG = 1
    KEY_IMGS = 2


MAX_KEYFRAME = 8


class GlobalState:

    def __init__(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL
        # Initialize the attributes filled in by process1/process2 so that
        # attribute access never fails before the first run.
        self.color_corrections = None
        self.first_result = None
        self.first_img = None
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(device)
        checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.flow_model = flow_model

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    def update_sd_model(self, sd_model, control_type):
        if sd_model == self.sd_model:
            return
        self.sd_model = sd_model
        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        if control_type == 'HED':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', 'models/control_sd15_hed.pth'),
                                location=device))
        elif control_type == 'canny':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
                                location=device))
        model.to(device)
        sd_model_path = model_dict[sd_model]
        if len(sd_model_path) > 0:
            model_ext = os.path.splitext(sd_model_path)[1]
            downloaded_model = huggingface_hub.hf_hub_download(
                repo_name, sd_model_path)
            if model_ext == '.safetensors':
                model.load_state_dict(load_file(downloaded_model),
                                      strict=False)
            elif model_ext in ('.ckpt', '.pth'):
                model.load_state_dict(
                    torch.load(downloaded_model)['state_dict'], strict=False)
        try:
            model.first_stage_model.load_state_dict(torch.load(
                huggingface_hub.hf_hub_download(
                    'stabilityai/sd-vae-ft-mse-original',
                    'vae-ft-mse-840000-ema-pruned.ckpt'))['state_dict'],
                                                    strict=False)
        except Exception:
            print('Warning: we suggest you download the fine-tuned VAE, '
                  'otherwise the generation quality will be degraded.')
        self.ddim_v_sampler = DDIMVSampler(model)

    def clear_sd_model(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        if self.detector_type == control_type:
            return
        # Record the type so the early-return guard above actually works.
        self.detector_type = control_type
        if control_type == 'HED':
            self.detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()
            low_threshold = canny_low
            high_threshold = canny_high

            def apply_canny(x):
                return canny_detector(x, low_threshold, high_threshold)

            self.detector = apply_canny
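

# Usage sketch (hypothetical values; this mirrors how process1 below drives
# GlobalState and is not executed):
#     state = GlobalState()
#     state.update_sd_model('Stable Diffusion 1.5', 'HED')
#     state.update_detector('HED')
#     state.update_controller(inner_strength=0.9, mask_period=(0.5, 0.8),
#                             cross_period=(0, 1), ada_period=(0.8, 1),
#                             warp_period=(0, 0.1))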


global_state = GlobalState()
global_video_path = None
video_frame_count = None


def create_cfg(input_path, prompt, image_resolution, control_strength,
               color_preserve, left_crop, right_crop, top_crop, bottom_crop,
               control_type, low_threshold, high_threshold, ddim_steps, scale,
               seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
               x0_strength, use_constraints, cross_start, cross_end,
               style_update_freq, warp_start, warp_end, mask_start, mask_end,
               ada_start, ada_end, mask_strength, inner_strength,
               smooth_boundary):
    use_warp = 'shape-aware fusion' in use_constraints
    use_mask = 'pixel-aware fusion' in use_constraints
    use_ada = 'color-aware AdaIN' in use_constraints
    # An empty period (start > end) disables the corresponding constraint.
    if not use_warp:
        warp_start = 1
        warp_end = 0
    if not use_mask:
        mask_start = 1
        mask_end = 0
    if not use_ada:
        ada_start = 1
        ada_end = 0
    input_name = os.path.split(input_path)[-1].split('.')[0]
    frame_count = 2 + keyframe_count * interval
    cfg = RerenderConfig()
    cfg.create_from_parameters(
        input_path,
        os.path.join('result', input_name, 'blend.mp4'),
        prompt,
        a_prompt=a_prompt,
        n_prompt=n_prompt,
        frame_count=frame_count,
        interval=interval,
        crop=[left_crop, right_crop, top_crop, bottom_crop],
        sd_model=sd_model,
        ddim_steps=ddim_steps,
        scale=scale,
        control_type=control_type,
        control_strength=control_strength,
        canny_low=low_threshold,
        canny_high=high_threshold,
        seed=seed,
        image_resolution=image_resolution,
        x0_strength=x0_strength,
        style_update_freq=style_update_freq,
        cross_period=(cross_start, cross_end),
        warp_period=(warp_start, warp_end),
        mask_period=(mask_start, mask_end),
        ada_period=(ada_start, ada_end),
        mask_strength=mask_strength,
        inner_strength=inner_strength,
        smooth_boundary=smooth_boundary,
        color_preserve=color_preserve)
    return cfg


def cfg_to_input(filename):
    cfg = RerenderConfig()
    cfg.create_from_path(filename)
    keyframe_count = (cfg.frame_count - 2) // cfg.interval
    use_constraints = [
        'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN'
    ]
    sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5')
    args = [
        cfg.input_path, cfg.prompt, cfg.image_resolution,
        cfg.control_strength, cfg.color_preserve, *cfg.crop,
        cfg.control_type, cfg.canny_low, cfg.canny_high, cfg.ddim_steps,
        cfg.scale, cfg.seed, sd_model, cfg.a_prompt, cfg.n_prompt,
        cfg.interval, keyframe_count, cfg.x0_strength, use_constraints,
        *cfg.cross_period, cfg.style_update_freq, *cfg.warp_period,
        *cfg.mask_period, *cfg.ada_period, cfg.mask_strength,
        cfg.inner_strength, cfg.smooth_boundary
    ]
    return args
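

# Note: cfg_to_input is the inverse of create_cfg. It flattens a saved
# RerenderConfig back into the ordered argument list the UI widgets expect,
# so example configs can be loaded straight into gr.Examples.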


def setup_color_correction(image):
    correction_target = cv2.cvtColor(np.asarray(image.copy()),
                                     cv2.COLOR_RGB2LAB)
    return correction_target


def apply_color_correction(correction, original_image):
    image = Image.fromarray(
        cv2.cvtColor(
            exposure.match_histograms(cv2.cvtColor(np.asarray(original_image),
                                                   cv2.COLOR_RGB2LAB),
                                      correction,
                                      channel_axis=2),
            cv2.COLOR_LAB2RGB).astype('uint8'))
    image = blendLayers(image, original_image, BlendType.LUMINOSITY)
    return image
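

# Note: apply_color_correction matches the LAB-space histogram of the input
# frame to the stylized reference, then blends with the original frame in
# LUMINOSITY mode, so the frame inherits the reference palette while keeping
# its own structure and brightness detail.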


def process(*args):
    first_frame = process1(*args)
    keypath = process2(*args)
    return first_frame, keypath


def process0(*args):
    global global_video_path
    global_video_path = args[0]
    return process(*args[1:])


def process1(*args):
    global global_video_path
    cfg = create_cfg(global_video_path, *args)
    global global_state
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_controller(cfg.inner_strength, cfg.mask_period,
                                   cfg.cross_period, cfg.ada_period,
                                   cfg.warp_period)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.FIRST_IMG
    prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution,
                   cfg.crop)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    model.control_scales = [cfg.control_strength] * 13
    model.to(device)

    num_samples = 1
    eta = 0.0
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]
    model.cond_stage_model.device = device

    with torch.no_grad():
        frame = cv2.imread(imgs[0])
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, C = img.shape
        img_ = numpy2tensor(img)
        def generate_first_img(img_, strength):
            encoder_posterior = model.encode_first_stage(img_.to(device))
            x0 = model.get_first_stage_encoding(encoder_posterior).detach()
            detected_map = detector(img)
            detected_map = HWC3(detected_map)
            control = torch.from_numpy(
                detected_map.copy()).float().to(device) / 255.0
            control = torch.stack([control for _ in range(num_samples)],
                                  dim=0)
            control = einops.rearrange(control, 'b h w c -> b c h w').clone()
            cond = {
                'c_concat': [control],
                'c_crossattn': [
                    model.get_learned_conditioning(
                        [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
                ]
            }
            un_cond = {
                'c_concat': [control],
                'c_crossattn':
                [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
            }
            shape = (4, H // 8, W // 8)
            controller.set_task('initfirst')
            seed_everything(cfg.seed)
            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=strength)
            x_samples = model.decode_first_stage(samples)
            x_samples_np = (
                einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            return x_samples, x_samples_np
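
        # generate_first_img encodes the frame into a latent x0, conditions
        # ControlNet on the detected edge map, and runs DDIM sampling where
        # `strength` controls how strongly the result is anchored to x0;
        # below, a negative strength appears to decouple sampling from x0
        # entirely for a free first rendering pass.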

        # When color is not preserved, first draw a frame freely, then use
        # its color statistics to recolor the input and redraw the first
        # frame from that.
        if not cfg.color_preserve:
            first_strength = -1
        else:
            first_strength = 1 - cfg.x0_strength
        x_samples, x_samples_np = generate_first_img(img_, first_strength)

        if not cfg.color_preserve:
            color_corrections = setup_color_correction(
                Image.fromarray(x_samples_np[0]))
            global_state.color_corrections = color_corrections
            img_ = apply_color_correction(color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
            x_samples, x_samples_np = generate_first_img(
                img_, 1 - cfg.x0_strength)

        global_state.first_result = x_samples
        global_state.first_img = img

    Image.fromarray(x_samples_np[0]).save(
        os.path.join(cfg.first_dir, 'first.jpg'))
    return x_samples_np[0]


def process2(*args):
    global global_state
    global global_video_path
    if global_state.processing_state != ProcessingState.FIRST_IMG:
        raise gr.Error('Please generate the first key image before '
                       'generating all key images.')
    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.KEY_IMGS

    # Reset the key frame directory.
    shutil.rmtree(cfg.key_dir, ignore_errors=True)
    os.makedirs(cfg.key_dir, exist_ok=True)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    flow_model = global_state.flow_model
    model.control_scales = [cfg.control_strength] * 13

    num_samples = 1
    eta = 0.0
    firstx0 = True
    pixelfusion = cfg.use_mask
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    first_result = global_state.first_result
    first_img = global_state.first_img
    pre_result = first_result
    pre_img = first_img

    for i in range(0, cfg.frame_count - 1, cfg.interval):
        cid = i + 1
        frame = cv2.imread(imgs[i + 1])
        print(cid)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, C = img.shape

        if cfg.color_preserve or global_state.color_corrections is None:
            img_ = numpy2tensor(img)
        else:
            img_ = apply_color_correction(global_state.color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
        encoder_posterior = model.encode_first_stage(img_.to(device))
        x0 = model.get_first_stage_encoding(encoder_posterior).detach()

        detected_map = detector(img)
        detected_map = HWC3(detected_map)
        control = torch.from_numpy(
            detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        cond = {
            'c_concat': [control],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }
        shape = (4, H // 8, W // 8)

        image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
        image2 = torch.from_numpy(img).permute(2, 0, 1).float()
        warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
            flow_model, image1, image2, pre_result, False)
        blend_mask_pre = blur(
            F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
        blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)

        image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
        warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
            flow_model, image1, image2, first_result, False)
        blend_mask_0 = blur(
            F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
        blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
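
        # The backward occlusion maps mark pixels without a reliable flow
        # correspondence; dilating them with max-pooling and softening them
        # with a Gaussian blur gives blend masks that favor the fresh
        # diffusion output in occluded regions and the warped previous/first
        # results everywhere else.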
        if firstx0:
            mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_0 / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)
        else:
            mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_pre / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)
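        # The latent grid is 8x smaller than pixel space, so both the flow
        # field and its displacement values are downscaled by 8 before being
        # handed to the attention controller.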

        controller.set_task('keepx0, keepstyle')
        seed_everything(cfg.seed)
        samples, intermediates = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=1 - cfg.x0_strength)
        direct_result = model.decode_first_stage(samples)

        if not pixelfusion:
            pre_result = direct_result
            pre_img = img
            viz = (
                einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5
                + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        else:
            blend_results = (1 - blend_mask_pre
                             ) * warped_pre + blend_mask_pre * direct_result
            blend_results = (
                1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results

            bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
            blend_mask = blur(
                F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
            blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)
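            # blend_mask is now 1 where both warped references are valid
            # (pixel-aware fusion can be applied) and 0 in regions occluded
            # in both references.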

            encoder_posterior = model.encode_first_stage(blend_results)
            xtrg = model.get_first_stage_encoding(
                encoder_posterior).detach()  # * mask
            blend_results_rec = model.decode_first_stage(xtrg)
            encoder_posterior = model.encode_first_stage(blend_results_rec)
            xtrg_rec = model.get_first_stage_encoding(
                encoder_posterior).detach()
            xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec))  # * mask
            blend_results_rec_new = model.decode_first_stage(xtrg_)
            tmp = (abs(blend_results_rec_new - blend_results).mean(
                dim=1, keepdims=True) > 0.25).float()
            mask_x = F.max_pool2d((F.interpolate(
                tmp, scale_factor=1 / 8., mode='bilinear') > 0).float(),
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)
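            # A VAE encode-decode round trip loses detail, so the latent
            # reconstruction error (xtrg - xtrg_rec) is added back below as
            # a fidelity correction; mask_x marks the latent regions where
            # that correction overshoots and should be suppressed.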
            mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8)
                    )  # * (1 - mask_x)

            if cfg.smooth_boundary:
                noise_rescale = find_flat_region(mask)
            else:
                noise_rescale = torch.ones_like(mask)

            masks = []
            # Use a separate index so the outer frame index `i` (needed for
            # the style update check below) is not clobbered.
            for j in range(cfg.ddim_steps):
                if j <= cfg.ddim_steps * cfg.mask_period[
                        0] or j >= cfg.ddim_steps * cfg.mask_period[1]:
                    masks += [None]
                else:
                    masks += [mask * cfg.mask_strength]
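            # Pixel-aware fusion is only enforced inside the configured
            # fraction of the denoising schedule (mask_period); outside this
            # window the per-step mask is None and the step is unconstrained.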

            # mask 3
            # xtrg = ((1 - mask_x) *
            #         (xtrg + xtrg - xtrg_rec) + mask_x * samples) * mask
            # mask 2
            # xtrg = (xtrg + 1 * (xtrg - xtrg_rec)) * mask
            xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask  # mask 1

            tasks = 'keepstyle, keepx0'
            if not firstx0:
                tasks += ', updatex0'
            if i % cfg.style_update_freq == 0:
                tasks += ', updatestyle'
            controller.set_task(tasks, 1.0)

            seed_everything(cfg.seed)
            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=1 - cfg.x0_strength,
                xtrg=xtrg,
                mask=masks,
                noise_rescale=noise_rescale)
            x_samples = model.decode_first_stage(samples)
            pre_result = x_samples
            pre_img = img

            viz = (
                einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        Image.fromarray(viz[0]).save(
            os.path.join(cfg.key_dir, f'{cid:04d}.png'))

    key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
    fps = get_fps(cfg.input_path)
    fps //= cfg.interval
    frame_to_video(key_video_path, cfg.key_dir, fps, False)
    return key_video_path


DESCRIPTION = '''
## Rerender A Video

### This space provides the key frame translation function. Full code for full video translation will be released upon the publication of the paper.

### To avoid overload, we limit the number of key frames to at most 8 and the frame resolution to at most 512x768.

### Rendering a 512x640 video takes about 1 minute per key frame on a T4 GPU.

### Tips:
1. This method cannot handle large or quick motions for which optical flow is hard to estimate. **Videos with stable motions are preferred**.
2. Pixel-aware fusion may not work well for large or quick motions.
3. Try different color-aware AdaIN settings, or disable it entirely, to avoid color jittering.
4. Use the `revAnimated_v11` model for non-photorealistic styles and the `realisticVisionV20_v20` model for photorealistic styles.
5. To use your own SD/LoRA model, you may clone the space and specify your model in [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py).
6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffusers/Automatic1111 models to the original format.
'''
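
# sd_model_cfg.model_dict maps a display name to a checkpoint path inside
# this repo; a hypothetical entry for illustration (see sd_model_cfg.py for
# the real mapping, where an empty path means the base SD 1.5 weights):
#     model_dict = {
#         'Stable Diffusion 1.5': '',
#         'revAnimated_v11': 'models/revAnimated_v11.safetensors',
#     }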

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_path = gr.Video(label='Input Video',
                                  source='upload',
                                  format='mp4',
                                  visible=True)
            prompt = gr.Textbox(label='Prompt')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=2147483647,
                             step=1,
                             value=0,
                             randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Row():
                run_button1 = gr.Button(value='Run 1st Key Frame')
                run_button2 = gr.Button(value='Run Key Frames')
                run_button3 = gr.Button(value='Run Propagation')
            with gr.Accordion(
                    'Advanced options for the 1st frame translation',
                    open=False):
                image_resolution = gr.Slider(
                    label='Frame resolution',
                    minimum=256,
                    maximum=512,
                    value=512,
                    step=64,
                    info='To avoid overload, maximum 512')
                control_strength = gr.Slider(label='ControlNet strength',
                                             minimum=0.0,
                                             maximum=2.0,
                                             value=1.0,
                                             step=0.01)
                x0_strength = gr.Slider(
                    label='Denoising strength',
                    minimum=0.00,
                    maximum=1.05,
                    value=0.75,
                    step=0.05,
                    info=('0: fully recover the input. '
                          '1.05: fully rerender the input.'))
                color_preserve = gr.Checkbox(
                    label='Preserve color',
                    value=True,
                    info='Keep the color of the input video')
                with gr.Row():
                    left_crop = gr.Slider(label='Left crop length',
                                          minimum=0,
                                          maximum=512,
                                          value=0,
                                          step=1)
                    right_crop = gr.Slider(label='Right crop length',
                                           minimum=0,
                                           maximum=512,
                                           value=0,
                                           step=1)
                with gr.Row():
                    top_crop = gr.Slider(label='Top crop length',
                                         minimum=0,
                                         maximum=512,
                                         value=0,
                                         step=1)
                    bottom_crop = gr.Slider(label='Bottom crop length',
                                            minimum=0,
                                            maximum=512,
                                            value=0,
                                            step=1)
                with gr.Row():
                    control_type = gr.Dropdown(['HED', 'canny'],
                                               label='Control type',
                                               value='HED')
                    low_threshold = gr.Slider(label='Canny low threshold',
                                              minimum=1,
                                              maximum=255,
                                              value=100,
                                              step=1)
                    high_threshold = gr.Slider(label='Canny high threshold',
                                               minimum=1,
                                               maximum=255,
                                               value=200,
                                               step=1)
                ddim_steps = gr.Slider(label='Steps',
                                       minimum=1,
                                       maximum=20,
                                       value=20,
                                       step=1,
                                       info='To avoid overload, maximum 20')
                scale = gr.Slider(label='CFG scale',
                                  minimum=0.1,
                                  maximum=30.0,
                                  value=7.5,
                                  step=0.1)
                sd_model_list = list(model_dict.keys())
                sd_model = gr.Dropdown(sd_model_list,
                                       label='Base model',
                                       value='Stable Diffusion 1.5')
                a_prompt = gr.Textbox(
                    label='Added prompt',
                    value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label='Negative prompt',
                    value=('longbody, lowres, bad anatomy, bad hands, '
                           'missing fingers, extra digit, fewer digits, '
                           'cropped, worst quality, low quality'))
            with gr.Accordion(
                    'Advanced options for the key frame translation',
                    open=False):
                interval = gr.Slider(
                    label='Key frame frequency (K)',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='Uniformly sample the key frames every K frames')
                keyframe_count = gr.Slider(
                    label='Number of key frames',
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    info='To avoid overload, maximum 8 key frames')
                use_constraints = gr.CheckboxGroup(
                    [
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ],
                    label='Select the cross-frame constraints to be used',
                    value=[
                        'shape-aware fusion', 'pixel-aware fusion',
                        'color-aware AdaIN'
                    ])
                with gr.Row():
                    cross_start = gr.Slider(
                        label='Cross-frame attention start',
                        minimum=0,
                        maximum=1,
                        value=0,
                        step=0.05)
                    cross_end = gr.Slider(label='Cross-frame attention end',
                                          minimum=0,
                                          maximum=1,
                                          value=1,
                                          step=0.05)
                style_update_freq = gr.Slider(
                    label='Cross-frame attention update frequency',
                    minimum=1,
                    maximum=100,
                    value=1,
                    step=1,
                    info=('Update the key and value for cross-frame '
                          'attention every N key frames (recommend N*K>=10)'))
                with gr.Row():
                    warp_start = gr.Slider(label='Shape-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0,
                                           step=0.05)
                    warp_end = gr.Slider(label='Shape-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.1,
                                         step=0.05)
                with gr.Row():
                    mask_start = gr.Slider(label='Pixel-aware fusion start',
                                           minimum=0,
                                           maximum=1,
                                           value=0.5,
                                           step=0.05)
                    mask_end = gr.Slider(label='Pixel-aware fusion end',
                                         minimum=0,
                                         maximum=1,
                                         value=0.8,
                                         step=0.05)
                with gr.Row():
                    ada_start = gr.Slider(label='Color-aware AdaIN start',
                                          minimum=0,
                                          maximum=1,
                                          value=0.8,
                                          step=0.05)
                    ada_end = gr.Slider(label='Color-aware AdaIN end',
                                        minimum=0,
                                        maximum=1,
                                        value=1,
                                        step=0.05)
                mask_strength = gr.Slider(label='Pixel-aware fusion strength',
                                          minimum=0,
                                          maximum=1,
                                          value=0.5,
                                          step=0.01)
                inner_strength = gr.Slider(
                    label='Pixel-aware fusion detail level',
                    minimum=0.5,
                    maximum=1,
                    value=0.9,
                    step=0.01,
                    info='Use a low value to prevent artifacts')
                smooth_boundary = gr.Checkbox(
                    label='Smooth fusion boundary',
                    value=True,
                    info='Select to prevent artifacts at the boundary')
            with gr.Accordion('Example configs', open=True):
                config_dir = 'config'
                config_list = os.listdir(config_dir)
                args_list = []
                for config in config_list:
                    try:
                        config_path = os.path.join(config_dir, config)
                        args = cfg_to_input(config_path)
                        args_list.append(args)
                    except FileNotFoundError:
                        # The referenced video file does not exist; skip
                        # this config.
                        pass
            ips = [
                prompt, image_resolution, control_strength, color_preserve,
                left_crop, right_crop, top_crop, bottom_crop, control_type,
                low_threshold, high_threshold, ddim_steps, scale, seed,
                sd_model, a_prompt, n_prompt, interval, keyframe_count,
                x0_strength, use_constraints, cross_start, cross_end,
                style_update_freq, warp_start, warp_end, mask_start,
                mask_end, ada_start, ada_end, mask_strength,
                inner_strength, smooth_boundary
            ]
        with gr.Column():
            result_image = gr.Image(label='Output first frame',
                                    type='numpy',
                                    interactive=False)
            result_keyframe = gr.Video(label='Output key frame video',
                                       format='mp4',
                                       interactive=False)
    with gr.Row():
        gr.Examples(examples=args_list,
                    inputs=[input_path, *ips],
                    fn=process0,
                    outputs=[result_image, result_keyframe],
                    cache_examples=True)

    def input_uploaded(path):
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            raise gr.Error('The input video is too short! '
                           'Please input another video.')
        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval,
                           MAX_KEYFRAME)
        global video_frame_count
        video_frame_count = frame_count
        global global_video_path
        global_video_path = path
        return gr.Slider.update(value=default_interval,
                                maximum=MAX_KEYFRAME), gr.Slider.update(
                                    value=max_keyframe, maximum=max_keyframe)

    def input_changed(path):
        frame_count = get_frame_count(path)
        if frame_count <= 2:
            return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1)
        default_interval = min(10, frame_count - 2)
        max_keyframe = min((frame_count - 2) // default_interval,
                           MAX_KEYFRAME)
        global video_frame_count
        video_frame_count = frame_count
        global global_video_path
        global_video_path = path
        return gr.Slider.update(maximum=max_keyframe), \
            gr.Slider.update(maximum=max_keyframe)

    def interval_changed(interval):
        global video_frame_count
        if video_frame_count is None:
            return gr.Slider.update()
        max_keyframe = (video_frame_count - 2) // interval
        return gr.Slider.update(value=max_keyframe, maximum=max_keyframe)

    input_path.change(input_changed, input_path, [interval, keyframe_count])
    input_path.upload(input_uploaded, input_path, [interval, keyframe_count])
    interval.change(interval_changed, interval, keyframe_count)

    run_button.click(fn=process,
                     inputs=ips,
                     outputs=[result_image, result_keyframe])
    run_button1.click(fn=process1, inputs=ips, outputs=[result_image])
    run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])

    def process3():
        raise gr.Error('Coming soon. Full code for full video translation '
                       'will be released upon the publication of the paper.')

    run_button3.click(fn=process3, outputs=[result_keyframe])

block.launch(server_name='0.0.0.0')