import argparse
import json
import os
from uuid import uuid4
import gradio as gr
import imageio
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from image_to_video import Image2Video, _resize_for_rectangle_crop, camera_pose_lerp
from demo.preview import Previewer
from demo.prompt_extend import QwenPromptExpander
# Allow PyTorch to use TF32 for CUDA matrix multiplications (set once at import time).
torch.backends.cuda.matmul.allow_tf32 = True
def load_model_name(model_meta_path=None):
    """Return the selectable model names from the model metadata JSON.

    The JSON file is expected to hold an object; its keys are the model
    names. Keys containing the substring ``"interp"`` are skipped.

    Args:
        model_meta_path: Path to the metadata JSON. When omitted, the
            module-level ``args.model_meta_path`` is used, which only exists
            when this file is run as a script.

    Returns:
        A list of model-name strings, in file order.
    """
    if model_meta_path is None:
        model_meta_path = args.model_meta_path
    # Explicit encoding so non-ASCII names load the same on every platform.
    with open(model_meta_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [name for name in data if "interp" not in name]
def load_camera_pose_type(camera_pose_meta_path=None):
    """Return the selectable camera pose type names from the pose metadata JSON.

    The JSON file is expected to hold an object; its keys are the pose type
    names.

    Args:
        camera_pose_meta_path: Path to the metadata JSON. When omitted, the
            module-level ``args.camera_pose_meta_path`` is used, which only
            exists when this file is run as a script.

    Returns:
        A list of pose-type strings, in file order.
    """
    if camera_pose_meta_path is None:
        camera_pose_meta_path = args.camera_pose_meta_path
    # Same explicit encoding the preview callback uses for this file.
    with open(camera_pose_meta_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return list(data)
def main(args):
    """Build the RealCam-I2V Gradio demo and return it without launching.

    Args:
        args: Parsed command-line namespace. This function reads
            ``caption_model_path``, ``depth_model_path``, ``result_dir``,
            ``model_meta_path``, ``camera_pose_meta_path``, ``save_fps`` and
            ``device`` from it.

    Returns:
        The assembled ``gr.Blocks`` app; the caller calls ``launch`` on it.

    Note:
        ``load_model_name()`` and ``load_camera_pose_type()`` are called
        without arguments, so they read the module-level ``args`` rather than
        this function's parameter.
    """
    captioner = QwenPromptExpander(args.caption_model_path, device=args.device)
    previewer = Previewer(args.depth_model_path)
    image2video = Image2Video(
        args.result_dir,
        args.model_meta_path,
        args.camera_pose_meta_path,
        args.save_fps,
        device=args.device,
    )

    # NOTE(review): no component below sets elem_id="input_img", so the first
    # CSS rule currently matches nothing - confirm whether that is intended.
    with gr.Blocks(analytics_enabled=False, css=r"""
#input_img img {height: 480px !important;}
#output_vid video {width: auto !important; margin: auto !important;}
""") as demo:
        gr.Markdown("""
RealCam-I2V (CogVideoX-1.5-5B-I2V)
""")

        # --- Media row: input image, preview and generated videos. ---
        with gr.Row(equal_height=True):
            input_image = gr.Image(label="Input Image")
            output_3d = gr.Model3D(label="Camera Trajectory", clear_color=[1.0, 1.0, 1.0, 1.0], visible=False)
            preview_video = gr.Video(label="Preview Video", interactive=False, autoplay=True, loop=True)
            output_video1 = gr.Video(label="New Generated Video", elem_id="output_vid", interactive=False, autoplay=True, loop=True)
            output_video2 = gr.Video(label="Previous Generated Video", elem_id="output_vid", interactive=False, autoplay=True, loop=True, visible=False)

        # --- Action buttons. ---
        with gr.Row():
            reload_btn = gr.Button("Reload")
            preview_btn = gr.Button("Preview")
            end_btn = gr.Button("Generate")

        # --- Prompts. ---
        with gr.Row(equal_height=True):
            input_text = gr.Textbox(label='Prompt', scale=4)
            caption_btn = gr.Button("Caption")
        with gr.Row():
            negative_prompt = gr.Textbox(label='Negative Prompt', value="Fast movement, jittery motion, abrupt transitions, distorted body, missing limbs, unnatural posture, blurry, cropped, extra limbs, bad anatomy, deformed, glitchy motion, artifacts.")

        # --- Model / sampling / trajectory / resolution settings. ---
        with gr.Row(equal_height=True):
            with gr.Column():
                model_name = gr.Dropdown(label='Model Name', choices=load_model_name())
                camera_pose_type = gr.Dropdown(label='Camera Pose Type', choices=load_camera_pose_type())
                seed = gr.Slider(label="Random Seed", minimum=0, maximum=2**31, step=1, value=12333)
            with gr.Column(scale=2):
                with gr.Row():
                    steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps (DDPM)", value=25)
                    text_cfg = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Text CFG', value=5)
                    camera_cfg = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Camera CFG", value=1.0, visible=False)
                with gr.Row():
                    trace_extract_ratio = gr.Slider(minimum=0, maximum=1.0, step=0.1, label="Trace Extract Ratio", value=1.0)
                    trace_scale_factor = gr.Slider(minimum=0, maximum=5, step=0.1, label="Camera Trace Scale Factor", value=1.0)
                with gr.Row(equal_height=True):
                    frames = gr.Slider(minimum=17, maximum=161, step=16, label="Video Frames", value=49)
                    height = gr.Slider(minimum=256, maximum=1360, step=16, label="Video Height", value=512)
                    width = gr.Slider(minimum=448, maximum=1360, step=16, label="Video Width", value=896)
                    switch_aspect_ratio = gr.Button("Switch HW")
                with gr.Row(visible=False):
                    noise_shaping = gr.Checkbox(label='Enable Noise Shaping', value=False)
                    noise_shaping_minimum_timesteps = gr.Slider(minimum=0, maximum=1000, step=1, label="Noise Shaping Minimum Timesteps", value=900)

        # Uploading a new image clears the prompt box.
        input_image.upload(fn=lambda : "", outputs=[input_text])

        def caption(prompt, image):
            """Return the captioner's English prompt for ``image``, seeded with ``prompt``."""
            if image is None:
                # Image.fromarray(None) would otherwise fail with an opaque error.
                raise gr.Error("Please upload an input image before captioning.")
            # Offload the generator first, as the original flow does, before
            # running the captioner.
            image2video.offload_cpu()
            return captioner(prompt, tar_lang="en", image=Image.fromarray(image)).prompt

        caption_btn.click(fn=caption, inputs=[input_text, input_image], outputs=[input_text])

        def preview(input_image, camera_pose_type, trace_extract_ratio, sample_frames, sample_height, sample_width):
            """Render a preview video of the selected camera trajectory.

            The input image is resized/cropped to the requested size, its
            depth is estimated by ``previewer``, and preview frames are
            rendered for each pose and written to ``args.result_dir``.

            Returns:
                Path of the saved ``.mp4`` preview.

            Raises:
                gr.Error: If no image or pose type is selected, or if the
                    trace extract ratio is not positive.
            """
            if input_image is None:
                raise gr.Error("Please upload an input image before previewing.")
            if not camera_pose_type:
                raise gr.Error("Please select a camera pose type before previewing.")
            if trace_extract_ratio <= 0:
                # The slider allows 0, which would divide by zero below.
                raise gr.Error("Trace Extract Ratio must be greater than 0.")

            # (h, w, c) image -> (c, 1, h, w) for the resize helper, then back
            # to a single-frame (1, h, w, c) numpy batch.
            image_frames = rearrange(torch.from_numpy(input_image), "h w c -> c 1 h w")
            image_frames, resized_H, resized_W = _resize_for_rectangle_crop(image_frames, sample_height, sample_width)
            image_frames = rearrange(image_frames, "c 1 h w -> 1 h w c").numpy()

            with open(args.camera_pose_meta_path, "r", encoding="utf-8") as f:
                camera_pose_file_path = json.load(f)[camera_pose_type]
            # Lines starting with "https" are skipped as comments. Columns from
            # index 7 onward are reshaped into one 3x4 matrix per row.
            camera_data = torch.from_numpy(np.loadtxt(camera_pose_file_path, comments="https"))
            w2cs_3x4 = camera_data[:, 7:].reshape(-1, 3, 4)
            # Append the homogeneous row [0, 0, 0, 1] in the pose dtype so the
            # concatenation does not rely on implicit type promotion.
            bottom_row = torch.tensor([0, 0, 0, 1], dtype=w2cs_3x4.dtype).expand(w2cs_3x4.shape[0], 1, 4)
            w2cs_4x4 = torch.cat([w2cs_3x4, bottom_row], dim=1)
            c2ws_4x4 = w2cs_4x4.inverse()
            # camera_pose_lerp is asked for round(frames / ratio) poses
            # (presumably interpolated - see image_to_video); only the first
            # `sample_frames` of them are kept.
            c2ws_lerp_4x4 = camera_pose_lerp(c2ws_4x4, round(sample_frames / trace_extract_ratio))[: sample_frames]
            w2cs_lerp_4x4 = c2ws_lerp_4x4.inverse().numpy()

            # Intrinsics come from the resized frame size: focal length is half
            # the longer side, principal point is the frame centre.
            fx = fy = 0.5 * max(resized_H, resized_W)
            cx = 0.5 * resized_W
            cy = 0.5 * resized_H
            intrinsics = [fx, fy, cx, cy]

            depths = previewer.estimate_depths(image_frames, intrinsics)
            previews = previewer.render_previews(image_frames[0], depths[0], intrinsics, w2cs_lerp_4x4)

            # First UUID field (32 bits) gives a short random file suffix.
            uid = uuid4().fields[0]
            preview_path = f"{args.result_dir}/preview_{uid:08x}.mp4"
            os.makedirs(args.result_dir, exist_ok=True)
            imageio.mimsave(preview_path, previews, fps=args.save_fps)
            return preview_path

        preview_btn.click(
            fn=preview,
            inputs=[input_image, camera_pose_type, trace_extract_ratio, frames, height, width],
            outputs=[preview_video],
        )

        def generate(*inputs):
            """Run generation; the last three inputs are passed as one ``(frames, height, width)`` tuple."""
            *inputs, frames, height, width = inputs
            captioner.offload_cpu()
            return image2video.get_image(*inputs, (frames, height, width))

        end_btn.click(
            fn=generate,
            inputs=[model_name, input_image, input_text, negative_prompt, camera_pose_type, preview_video, steps, trace_extract_ratio, trace_scale_factor, camera_cfg, text_cfg, seed, noise_shaping, noise_shaping_minimum_timesteps, frames, height, width],
            outputs=[output_video1],
        )
        # Second listener on the same button: copies output_video1 into the
        # hidden "previous" slot (presumably the value at click time, i.e. the
        # prior result - confirm against Gradio's event handling).
        end_btn.click(fn=lambda x: x, inputs=[output_video1], outputs=[output_video2])

        # Re-read both metadata files so newly added entries show up.
        reload_btn.click(
            fn=lambda: (gr.Dropdown(choices=load_model_name()), gr.Dropdown(choices=load_camera_pose_type())),
            outputs=[model_name, camera_pose_type]
        )

        # Swap height and width in a single event so both sliders update together.
        switch_aspect_ratio.click(
            fn=lambda h, w: (w, h),
            inputs=[height, width],
            outputs=[height, width],
        )
    return demo
def get_parser():
    """Create the command-line parser for the demo.

    Every option takes a single value and has a default, so the parser can
    be used with no arguments at all.

    Returns:
        An ``argparse.ArgumentParser`` with the demo's options registered.
    """
    # (flag, value type, default) in registration order.
    option_table = (
        ("--device", str, "cuda"),
        ("--save_fps", int, 16),
        ("--result_dir", str, "results"),
        ("--model_meta_path", str, "demo/models.json"),
        ("--example_meta_path", str, "demo/examples.json"),
        ("--camera_pose_meta_path", str, "demo/camera_poses.json"),
        ("--depth_model_path", str, "pretrained/Metric3D/metric_depth_vit_large_800k.pth"),
        ("--caption_model_path", str, "pretrained/Qwen2.5-VL-7B-Instruct"),
        ("--server_name", str, "0.0.0.0"),
        ("--server_port", int, 8080),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli
if __name__ == "__main__":
    # `args` must stay a module-level name: load_model_name() and
    # load_camera_pose_type() read it as a global.
    # parse_known_args() is used so unrecognised CLI arguments are ignored.
    parser = get_parser()
    args, _ = parser.parse_known_args()
    demo = main(args)
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        allowed_paths=["demo"],
    )