File size: 8,237 Bytes
7ccc423
 
2be0048
7ccc423
 
 
 
 
 
2be0048
7ccc423
 
 
 
 
 
 
 
172285b
7ccc423
 
172285b
 
2be0048
 
7ccc423
 
2be0048
7ccc423
 
 
 
2be0048
 
7ccc423
 
 
 
2be0048
 
 
7ccc423
 
 
 
 
 
2be0048
 
7ccc423
 
2be0048
 
 
7ccc423
 
 
 
 
 
 
 
 
 
 
 
 
 
dd93214
 
7ccc423
 
 
 
 
 
dd93214
7ccc423
 
 
 
 
dd93214
2be0048
dd93214
 
 
 
2be0048
 
 
7ccc423
 
2be0048
7ccc423
 
 
 
 
2be0048
7ccc423
 
 
 
 
 
 
daa45ed
7ccc423
daa45ed
 
 
 
 
2be0048
 
daa45ed
 
7ccc423
 
2be0048
7ccc423
 
2be0048
7ccc423
2be0048
 
 
7ccc423
 
2be0048
 
7ccc423
2be0048
 
7ccc423
2be0048
 
7ccc423
 
2be0048
7ccc423
 
 
 
 
2be0048
 
 
7ccc423
572be6e
7ccc423
 
 
 
 
172285b
572be6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import random
import spaces
from datetime import datetime
import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.download_models import prepare_base_model, prepare_image_encoder
from src.utils.util import get_fps, read_frames, save_videos_grid

# Download model weights at import time, before the controller and the UI
# below are created.
prepare_base_model()
prepare_image_encoder()
for _repo_id, _local_dir in (
    ("stabilityai/sd-vae-ft-mse", "./pretrained_weights/sd-vae-ft-mse"),
    ("patrolli/AnimateAnyone", "./pretrained_weights"),
):
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
del _repo_id, _local_dir  # keep the loop variables out of the module namespace

class AnimateController:
    """Build the AnimateAnyone pose-to-video pipeline lazily and run it for the UI.

    All model/weight paths come from an OmegaConf YAML file. The models are
    only instantiated on the first call to :meth:`animate` (which is wrapped
    in ``spaces.GPU``), and the assembled pipeline is cached on the instance.
    """

    def __init__(self, config_path="./configs/prompts/animation.yaml", weight_dtype=torch.float16):
        """Load the config; model loading is deferred to the first ``animate`` call.

        Args:
            config_path: YAML file providing the paths read by ``_load_pipeline``
                (``pretrained_vae_path``, ``pretrained_base_model_path``,
                ``inference_config``, ``motion_module_path``,
                ``image_encoder_path`` and the three ``*_path`` checkpoints).
            weight_dtype: dtype every sub-model and the pipeline are cast to.
        """
        self.config = OmegaConf.load(config_path)
        self.pipeline = None  # built on first use
        self.weight_dtype = weight_dtype

    def _load_pipeline(self):
        """Instantiate every sub-model on CUDA, load its checkpoint and return the pipeline."""
        config = self.config
        dtype = self.weight_dtype

        vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path).to("cuda", dtype=dtype)
        reference_unet = UNet2DConditionModel.from_pretrained(
            config.pretrained_base_model_path, subfolder="unet"
        ).to(dtype=dtype, device="cuda")
        infer_config = OmegaConf.load(config.inference_config)
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            config.pretrained_base_model_path,
            config.motion_module_path,
            subfolder="unet",
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(dtype=dtype, device="cuda")
        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device="cuda")
        image_enc = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(
            dtype=dtype, device="cuda"
        )
        scheduler = DDIMScheduler(**OmegaConf.to_container(infer_config.noise_scheduler_kwargs))

        # NOTE(review): torch.load unpickles these files; they presumably come
        # from the Hub snapshot downloaded at import time. Consider
        # weights_only=True if the installed torch version supports it.
        # strict=False (denoising UNet only) tolerates missing/unexpected keys;
        # the reason is not visible here - presumably the checkpoint does not
        # cover every parameter. Confirm.
        denoising_unet.load_state_dict(torch.load(config.denoising_unet_path, map_location="cpu"), strict=False)
        reference_unet.load_state_dict(torch.load(config.reference_unet_path, map_location="cpu"))
        pose_guider.load_state_dict(torch.load(config.pose_guider_path, map_location="cpu"))

        pipe = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler,
        )
        return pipe.to("cuda", dtype=dtype)

    @spaces.GPU(duration=60)
    def animate(self, ref_image, pose_video_path, width=512, height=768, length=24, num_inference_steps=25, cfg=3.5, seed=123):
        """Animate ``ref_image`` along the poses of a video and save a comparison grid.

        Args:
            ref_image: PIL image or numpy array (converted with ``Image.fromarray``).
            pose_video_path: path of the pose/motion video; its frames are read
                with ``read_frames`` and its frame rate with ``get_fps``.
            width, height: output size requested from the pipeline.
            length: maximum number of frames; clipped to the video's frame count.
            num_inference_steps: denoising steps.
            cfg: classifier-free guidance scale.
            seed: passed to ``torch.manual_seed``; the returned generator is
                given to the pipeline.

        Returns:
            Path of an ``.mp4`` in ``./output/gradio`` holding a 3-row grid:
            the reference image repeated per frame, the pose frames, and the
            generated frames, written at the source video's fps.

        Raises:
            gr.Error: if an input is missing or there are no frames to animate.
        """
        if ref_image is None:
            raise gr.Error("Please provide a reference image.")
        if not pose_video_path:
            raise gr.Error("Please provide a motion sequence video.")

        generator = torch.manual_seed(seed)
        if isinstance(ref_image, np.ndarray):
            ref_image = Image.fromarray(ref_image)
        # Force 3 channels: an RGBA/greyscale image would otherwise yield a
        # channel count that cannot be concatenated with the pose/video rows.
        ref_image = ref_image.convert("RGB")

        try:
            if self.pipeline is None:
                self.pipeline = self._load_pipeline()

            pose_images = read_frames(pose_video_path)
            src_fps = get_fps(pose_video_path)
            # int(): UI sliders may deliver numbers as floats, which cannot
            # be used as slice bounds.
            total_length = min(int(length), len(pose_images))
            if total_length <= 0:
                raise gr.Error("No frames to animate: the motion sequence is empty or the length is < 1.")
            pose_list = list(pose_images[:total_length])

            video = self.pipeline(
                ref_image,
                pose_list,
                width=int(width),
                height=int(height),
                video_length=total_length,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=cfg,
                generator=generator,
            ).videos

            # `video` is presumably (batch, channels, frames, H, W) - the einops
            # pattern below builds the extra rows in that layout.
            new_h, new_w = video.shape[-2:]
            to_tensor = transforms.Compose([transforms.Resize((new_h, new_w)), transforms.ToTensor()])
            pose_tensor = torch.stack([to_tensor(frame) for frame in pose_list], dim=0).transpose(0, 1).unsqueeze(0)
            ref_image_tensor = to_tensor(ref_image).unsqueeze(1).unsqueeze(0)
            ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=total_length)
            video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)

            save_dir = "./output/gradio"
            os.makedirs(save_dir, exist_ok=True)
            # A single timestamp (so date and time cannot straddle midnight),
            # down to microseconds so runs within the same minute no longer
            # overwrite each other's output file.
            timestamp = datetime.now().strftime("%Y%m%dT%H%M%S_%f")
            out_path = os.path.join(save_dir, f"{timestamp}.mp4")
            save_videos_grid(video, out_path, n_rows=3, fps=src_fps)
            return out_path
        finally:
            # Release cached CUDA memory even when generation or saving fails.
            torch.cuda.empty_cache()

# Single shared controller; the click handler in ui() calls its `animate` method.
controller = AnimateController(config_path="./configs/prompts/animation.yaml")

def ui():
    """Build and return the Gradio Blocks interface for the AnimateAnyone demo.

    The reference image, motion video, width/height/length sliders, sampling
    steps, guidance scale and seed are passed positionally to
    ``controller.animate``; the returned video path is shown in the
    "Animation Results" player.
    """
    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1 style="color:#dc5b1c;text-align:center">
                Moore-AnimateAnyone Gradio Demo 
            </h1>
            <div style="text-align:center">
            <div style="display: inline-block; text-align: left;">
            <p>This is a quick preview demo of Moore-AnimateAnyone. We appreciate the assistance provided by the HuggingFace team in setting up this demo.</p> 
            <p>If you like this project, please consider giving a star on <a href="https://github.com/MooreThreads/Moore-AnimateAnyone">our GitHub repo</a> 🤗.</p>
            </div>
            </div>
            """
        )
        animation = gr.Video(format="mp4", label="Animation Results", height=448, autoplay=True)
        with gr.Row():
            reference_image = gr.Image(label="Reference Image")
            motion_sequence = gr.Video(format="mp4", label="Motion Sequence", height=512)
            with gr.Column():
                width_slider = gr.Slider(label="Width", minimum=448, maximum=768, value=512, step=64)
                height_slider = gr.Slider(label="Height", minimum=512, maximum=960, value=768, step=64)
                length_slider = gr.Slider(label="Video Length", minimum=24, maximum=128, value=72, step=24)
                with gr.Row():
                    seed_textbox = gr.Textbox(label="Seed", value=-1)
                    seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                    # Return the new value directly: `gr.Textbox.update` no
                    # longer exists in Gradio 4. Integer bounds are required,
                    # because randint(1, 1e8) raises TypeError on Python 3.12+.
                    seed_button.click(
                        fn=lambda: str(random.randint(1, 10**8)),
                        inputs=[],
                        outputs=[seed_textbox],
                    )
                with gr.Row():
                    sampling_steps = gr.Slider(label="Sampling steps", value=15, info="default: 15", step=5, maximum=20, minimum=10)
                    guidance_scale = gr.Slider(label="Guidance scale", value=3.5, info="default: 3.5", step=0.5, maximum=6.5, minimum=2.0)
                submit = gr.Button("Animate")
        motion_sequence.upload(lambda x: x, motion_sequence, motion_sequence, queue=False)
        reference_image.upload(lambda x: Image.fromarray(x), reference_image, reference_image, queue=False)
        submit.click(
            controller.animate,
            [reference_image, motion_sequence, width_slider, height_slider, length_slider, sampling_steps, guidance_scale, seed_textbox],
            animation,
        )
        gr.Markdown("## Examples")
        gr.Examples(
            examples=[
                ["./configs/inference/ref_images/anyone-5.png", "./configs/inference/pose_videos/anyone-video-2_kps.mp4", 512, 768, 72],
                ["./configs/inference/ref_images/anyone-10.png", "./configs/inference/pose_videos/anyone-video-1_kps.mp4", 512, 768, 72],
                ["./configs/inference/ref_images/anyone-2.png", "./configs/inference/pose_videos/anyone-video-5_kps.mp4", 512, 768, 72],
            ],
            inputs=[reference_image, motion_sequence, width_slider, height_slider, length_slider],
            outputs=animation,
        )
    return demo

# Build the interface, enable queuing with max_size=10, and launch with
# share=True and the API page hidden. `queue` returns the Blocks, so the calls chain.
demo = ui()
demo.queue(max_size=10).launch(share=True, show_api=False)