File size: 8,646 Bytes
e8bdafd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import argparse
import json
import os
from uuid import uuid4

import gradio as gr
import imageio
import numpy as np
import torch
from einops import rearrange
from PIL import Image

from image_to_video import Image2Video, _resize_for_rectangle_crop, camera_pose_lerp
from demo.preview import Previewer
from demo.prompt_extend import QwenPromptExpander

# Opt in to TensorFloat-32 for CUDA matrix multiplications (PyTorch backend
# flag); set once at import time so it applies to every model loaded below.
torch.backends.cuda.matmul.allow_tf32 = True


def load_model_name():
    """Return the selectable model names from the model metadata JSON.

    Reads the file at ``args.model_meta_path`` (``args`` is the module-level
    namespace assigned under the ``__main__`` guard) and returns its
    top-level keys, leaving out every key that contains ``"interp"``.

    Returns:
        list[str]: Model names in the order they appear in the JSON file.
    """
    # Explicit UTF-8 keeps this independent of the platform's locale encoding
    # and matches how the camera-pose metadata is opened in the preview code.
    with open(args.model_meta_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    return [name for name in data if "interp" not in name]


def load_camera_pose_type():
    """Return the selectable camera-pose names from the pose metadata JSON.

    Reads the file at ``args.camera_pose_meta_path`` (``args`` is the
    module-level namespace assigned under the ``__main__`` guard) and returns
    its top-level keys.

    Returns:
        list[str]: Camera pose type names in file order.
    """
    # The preview callback opens this same file as UTF-8; use the same
    # encoding here so both readers agree regardless of locale.
    with open(args.camera_pose_meta_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    return list(data)


def main(args):
    """Build the Gradio demo application for image-to-video generation.

    Creates the prompt captioner, the depth-based previewer and the
    image-to-video pipeline from ``args``, lays out the UI and wires up the
    button callbacks.

    Args:
        args: Parsed command-line namespace. This function reads ``device``,
            ``caption_model_path``, ``depth_model_path``, ``result_dir``,
            ``model_meta_path``, ``camera_pose_meta_path`` and ``save_fps``.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    captioner = QwenPromptExpander(args.caption_model_path, device=args.device)
    previewer = Previewer(args.depth_model_path)
    image2video = Image2Video(
        args.result_dir,
        args.model_meta_path,
        args.camera_pose_meta_path,
        args.save_fps,
        device=args.device,
    )

    # NOTE(review): no component below sets elem_id="input_img", so the first
    # CSS rule currently matches nothing.
    with gr.Blocks(analytics_enabled=False, css=r"""
        #input_img img {height: 480px !important;}
        #output_vid video {width: auto !important; margin: auto !important;}
    """) as demo:
        gr.Markdown("""
            <div align='center'>
                <h1> RealCam-I2V (CogVideoX-1.5-5B-I2V) </h1>
            </div>
        """)

        with gr.Row(equal_height=True):
            input_image = gr.Image(label="Input Image")
            output_3d = gr.Model3D(label="Camera Trajectory", clear_color=[1.0, 1.0, 1.0, 1.0], visible=False)
            preview_video = gr.Video(label="Preview Video", interactive=False, autoplay=True, loop=True)
            output_video1 = gr.Video(label="New Generated Video", elem_id="output_vid", interactive=False, autoplay=True, loop=True)
            output_video2 = gr.Video(label="Previous Generated Video", elem_id="output_vid", interactive=False, autoplay=True, loop=True, visible=False)

        with gr.Row():
            reload_btn = gr.Button("Reload")
            preview_btn = gr.Button("Preview")
            end_btn = gr.Button("Generate")

        with gr.Row(equal_height=True):
            input_text = gr.Textbox(label='Prompt', scale=4)
            caption_btn = gr.Button("Caption")

        with gr.Row():
            negative_prompt = gr.Textbox(label='Negative Prompt', value="Fast movement, jittery motion, abrupt transitions, distorted body, missing limbs, unnatural posture, blurry, cropped, extra limbs, bad anatomy, deformed, glitchy motion, artifacts.")

        with gr.Row(equal_height=True):
            with gr.Column():
                # These loaders read the module-level ``args`` rather than
                # the ``args`` parameter of this function.
                model_name = gr.Dropdown(label='Model Name', choices=load_model_name())
                camera_pose_type = gr.Dropdown(label='Camera Pose Type', choices=load_camera_pose_type())
                seed = gr.Slider(label="Random Seed", minimum=0, maximum=2**31, step=1, value=12333)

            with gr.Column(scale=2):
                with gr.Row():
                    steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps (DDPM)", value=25)
                    text_cfg = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Text CFG', value=5)
                    camera_cfg = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Camera CFG", value=1.0, visible=False)
                with gr.Row():
                    trace_extract_ratio = gr.Slider(minimum=0, maximum=1.0, step=0.1, label="Trace Extract Ratio", value=1.0)
                    trace_scale_factor = gr.Slider(minimum=0, maximum=5, step=0.1, label="Camera Trace Scale Factor", value=1.0)
                with gr.Row(equal_height=True):
                    frames = gr.Slider(minimum=17, maximum=161, step=16, label="Video Frames", value=49)
                    height = gr.Slider(minimum=256, maximum=1360, step=16, label="Video Height", value=512)
                    width = gr.Slider(minimum=448, maximum=1360, step=16, label="Video Width", value=896)
                    switch_aspect_ratio = gr.Button("Switch HW")
                with gr.Row(visible=False):
                    noise_shaping = gr.Checkbox(label='Enable Noise Shaping', value=False)
                    noise_shaping_minimum_timesteps = gr.Slider(minimum=0, maximum=1000, step=1, label="Noise Shaping Minimum Timesteps", value=900)

        # Uploading a new image clears the prompt box.
        input_image.upload(fn=lambda: "", outputs=[input_text])

        def caption(prompt, image):
            """Return the captioner's English prompt for ``prompt`` and ``image``.

            Raises:
                gr.Error: If no input image has been provided.
            """
            if image is None:
                # Image.fromarray(None) would otherwise fail with an opaque error.
                raise gr.Error("Please upload an input image first.")
            image2video.offload_cpu()
            return captioner(prompt, tar_lang="en", image=Image.fromarray(image)).prompt

        caption_btn.click(fn=caption, inputs=[input_text, input_image], outputs=[input_text])

        def preview(input_image, camera_pose_type, trace_extract_ratio, sample_frames, sample_height, sample_width):
            """Render a depth-based preview video of the selected camera path.

            Returns:
                Path of the preview ``.mp4`` written under ``args.result_dir``.

            Raises:
                gr.Error: If the image or camera pose type is missing, or if
                    the trace extract ratio is not positive.
            """
            if input_image is None:
                raise gr.Error("Please upload an input image first.")
            if not camera_pose_type:
                raise gr.Error("Please select a camera pose type first.")
            if trace_extract_ratio <= 0:
                # The slider's minimum is 0, which would divide by zero below.
                raise gr.Error("Trace Extract Ratio must be greater than 0 to build a preview.")

            # Treat the single image as a one-frame clip for the resize helper.
            image_frames = rearrange(torch.from_numpy(input_image), "h w c -> c 1 h w")
            image_frames, resized_H, resized_W = _resize_for_rectangle_crop(image_frames, sample_height, sample_width)
            image_frames = rearrange(image_frames, "c 1 h w -> 1 h w c").numpy()

            with open(args.camera_pose_meta_path, "r", encoding="utf-8") as f:
                camera_pose_file_path = json.load(f)[camera_pose_type]
            # Lines starting with "https" are treated as comments and skipped.
            camera_data = torch.from_numpy(np.loadtxt(camera_pose_file_path, comments="https"))

            # Columns from index 7 onward are reshaped into one 3x4 matrix per
            # row (named world-to-camera here), then padded to 4x4 with a
            # homogeneous [0, 0, 0, 1] row so they can be inverted.
            w2cs_3x4 = camera_data[:, 7:].reshape(-1, 3, 4)
            bottom_row = torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=w2cs_3x4.dtype)
            bottom_rows = bottom_row.expand(w2cs_3x4.shape[0], 1, 4)
            w2cs_4x4 = torch.cat([w2cs_3x4, bottom_rows], dim=1)
            c2ws_4x4 = w2cs_4x4.inverse()
            # Interpolate to sample_frames / ratio poses, keep the first sample_frames.
            c2ws_lerp_4x4 = camera_pose_lerp(c2ws_4x4, round(sample_frames / trace_extract_ratio))[: sample_frames]
            w2cs_lerp_4x4 = c2ws_lerp_4x4.inverse().numpy()

            # Intrinsics are assumed, not read from the pose file: focal length
            # is half the longer side, principal point is the image centre.
            fx = fy = 0.5 * max(resized_H, resized_W)
            cx = 0.5 * resized_W
            cy = 0.5 * resized_H
            intrinsics = [fx, fy, cx, cy]
            depths = previewer.estimate_depths(image_frames, intrinsics)
            previews = previewer.render_previews(image_frames[0], depths[0], intrinsics, w2cs_lerp_4x4)

            os.makedirs(args.result_dir, exist_ok=True)
            uid = uuid4().fields[0]
            preview_path = f"{args.result_dir}/preview_{uid:08x}.mp4"
            imageio.mimsave(preview_path, previews, fps=args.save_fps)

            return preview_path

        preview_btn.click(
            fn=preview,
            inputs=[input_image, camera_pose_type, trace_extract_ratio, frames, height, width],
            outputs=[preview_video],
        )

        def generate(*inputs):
            """Run the image-to-video pipeline.

            The last three inputs (frames, height, width) are packed into one
            tuple; everything before them is forwarded positionally.
            """
            *pipeline_inputs, num_frames, out_height, out_width = inputs
            captioner.offload_cpu()
            return image2video.get_image(*pipeline_inputs, (num_frames, out_height, out_width))

        end_btn.click(
            fn=generate,
            inputs=[model_name, input_image, input_text, negative_prompt, camera_pose_type, preview_video, steps, trace_extract_ratio, trace_scale_factor, camera_cfg, text_cfg, seed, noise_shaping, noise_shaping_minimum_timesteps, frames, height, width],
            outputs=[output_video1],
        )

        # Second listener on the same click: copies the current "new" video
        # into the hidden "previous" player.
        # NOTE(review): relies on Gradio reading inputs at click time, i.e.
        # before generate() has replaced output_video1 - confirm.
        end_btn.click(fn=lambda x: x, inputs=[output_video1], outputs=[output_video2])

        reload_btn.click(
            fn=lambda: (gr.Dropdown(choices=load_model_name()), gr.Dropdown(choices=load_camera_pose_type())),
            outputs=[model_name, camera_pose_type]
        )

        # Swap height and width in one handler so both sliders are updated
        # from the same snapshot of values.
        switch_aspect_ratio.click(
            fn=lambda h, w: (w, h),
            inputs=[height, width],
            outputs=[height, width],
        )

    return demo


def get_parser():
    """Create the command-line parser for the demo.

    Every option is optional and typed; the defaults point at the repository's
    ``demo/`` metadata files and ``pretrained/`` checkpoints.

    Returns:
        argparse.ArgumentParser: Parser with all demo options registered.
    """
    # (flag, type, default) triples, registered in this order.
    option_table = (
        ("--device", str, "cuda"),
        ("--save_fps", int, 16),
        ("--result_dir", str, "results"),
        ("--model_meta_path", str, "demo/models.json"),
        ("--example_meta_path", str, "demo/examples.json"),
        ("--camera_pose_meta_path", str, "demo/camera_poses.json"),
        ("--depth_model_path", str, "pretrained/Metric3D/metric_depth_vit_large_800k.pth"),
        ("--caption_model_path", str, "pretrained/Qwen2.5-VL-7B-Instruct"),
        ("--server_name", str, "0.0.0.0"),
        ("--server_port", int, 8080),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)

    return cli


if __name__ == "__main__":
    # parse_known_args tolerates unrecognised flags; they are discarded.
    # ``args`` has to remain a module-level name because load_model_name()
    # and load_camera_pose_type() read it as a global.
    args, _ = get_parser().parse_known_args()

    demo_app = main(args)
    demo_app.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        allowed_paths=["demo"],
    )