# RealCam-I2V/finetune/image_to_video.py
import sys
from pathlib import Path
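# Make the repository root importable so the "models.*" and "utils.*" imports below resolve.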
sys.path.append(str(Path(__file__).parent.parent))
import importlib
import json
import os
from functools import partial
from pprint import pprint
from uuid import uuid4
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.utils.export_utils import export_to_video
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from torch import Tensor
from torchvision import transforms
from transformers import AutoTokenizer, T5EncoderModel
from torchcodec.decoders import VideoDecoder
def relative_pose(rt: Tensor, mode, ref_index) -> Tensor:
    '''
    Make camera poses relative to the pose at ref_index.
    :param rt: F,4,4 camera matrices
    :param mode: "left" (pre-multiply by the inverse reference pose) or "right" (post-multiply)
    :param ref_index: index of the reference frame
    :return: F,4,4 relative poses
    '''
if mode == "left":
rt = rt[ref_index].inverse() @ rt
elif mode == "right":
rt = rt @ rt[ref_index].inverse()
return rt
def camera_pose_lerp(c2w: Tensor, target_frames: int):
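    '''Linearly interpolate the F input camera-to-world matrices to exactly target_frames poses.'''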
weights = torch.linspace(0, c2w.size(0) - 1, target_frames, dtype=c2w.dtype)
left_indices = weights.floor().long()
right_indices = weights.ceil().long()
return torch.lerp(c2w[left_indices], c2w[right_indices], weights.unsqueeze(-1).unsqueeze(-1).frac())
def get_obj_from_str(string, reload=False):
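    '''Resolve a dotted path string such as "package.module.ClassName" to the named object.'''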
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def _resize_for_rectangle_crop(frames, H, W):
    '''
    Resize (keeping aspect ratio) so the frames cover the target size, then center-crop to H x W.
    :param frames: C,F,H,W
    :param H, W: target height and width
    :return: frames: C,F,H,W (center-cropped); resized_H, resized_W: size after resizing, before cropping
    '''
ori_H, ori_W = frames.shape[-2:]
# if ori_W / ori_H < 1.0:
# tmp_H, tmp_W = int(H), int(W)
# H, W = tmp_W, tmp_H
if ori_W / ori_H > W / H:
frames = transforms.functional.resize(frames, size=[H, int(ori_W * H / ori_H)])
else:
frames = transforms.functional.resize(frames, size=[int(ori_H * W / ori_W), W])
resized_H, resized_W = frames.shape[2], frames.shape[3]
frames = frames.squeeze(0)
delta_H = resized_H - H
delta_W = resized_W - W
top, left = delta_H // 2, delta_W // 2
frames = transforms.functional.crop(frames, top=top, left=left, height=H, width=W)
return frames, resized_H, resized_W
def _resize(frames, H, W):
    '''
    Resize frames directly to H x W without preserving aspect ratio.
    :param frames: C,F,H,W
    :param H, W: target height and width
    :return: frames: C,F,H,W (resized); resized_H, resized_W: size after resizing
    '''
frames = transforms.functional.resize(frames, size=[H, W])
resized_H, resized_W = frames.shape[2], frames.shape[3]
frames = frames.squeeze(0)
return frames, resized_H, resized_W
class Image2Video:
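    '''
    Camera-controlled image-to-video generation: a CogVideoX-I2V backbone combined with a
    ControlnetXs branch that injects Plucker-embedding camera conditioning into the denoiser.
    '''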
def __init__(
self,
result_dir: str = "results",
model_meta_path: str = "models.json",
camera_pose_meta_path: str = "camera_poses.json",
save_fps: int = 16,
device: str = "cuda",
):
self.result_dir = result_dir
self.model_meta_file = model_meta_path
self.camera_pose_meta_path = camera_pose_meta_path
self.save_fps = save_fps
self.device = torch.device(device)
self.pipe = None
def init_model(self, model_name):
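        '''Load the transformer, ControlnetXs camera branch, tokenizer, T5 text encoder, VAE, and scheduler registered under model_name in the model metadata file.'''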
from models.camera_controller.cogvideox_with_controlnetxs import CogVideoXTransformer3DModel
from models.camera_controller.controlnetxs import ControlnetXs
with open(self.model_meta_file, "r", encoding="utf-8") as f:
model_metadata = json.load(f)[model_name]
pretrained_model_path = model_metadata["pretrained_model_path"]
controlnetxs_model_path = model_metadata["controlnetxs_model_path"]
self.transformer = CogVideoXTransformer3DModel.from_pretrained(pretrained_model_path, subfolder="transformer", torch_dtype=torch.bfloat16)
self.controlnetxs = ControlnetXs("models/camera_controller/CogVideoX1.5-5B-I2V", self.transformer.config)
self.controlnetxs.load_state_dict(torch.load(controlnetxs_model_path)['module'], strict=True)
self.controlnetxs.to(torch.bfloat16)
# self.controlnetxs.to(torch.float32)
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
self.text_encoder = T5EncoderModel.from_pretrained(pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.bfloat16)
self.vae = AutoencoderKLCogVideoX.from_pretrained(pretrained_model_path, subfolder="vae", torch_dtype=torch.bfloat16)
self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.scheduler = CogVideoXDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
self.controlnetxs.eval()
self.text_encoder.eval()
self.vae.eval()
self.transformer.eval()
self.prepare_models()
def prepare_models(self) -> None:
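        '''Enable VAE slicing and tiling to reduce peak memory during encoding and decoding.'''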
if self.vae is not None:
self.vae.enable_slicing()
self.vae.enable_tiling()
if self.controlnetxs.vae_encoder is not None:
self.controlnetxs.vae_encoder.enable_slicing()
self.controlnetxs.vae_encoder.enable_tiling()
def init_pipe(self):
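        '''Assemble the CogVideoX image-to-video pipeline (no text encoder attached, prompts are pre-encoded) and move it to the target device.'''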
from models.camera_controller.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
self.pipe = CogVideoXImageToVideoPipeline(
tokenizer=self.tokenizer,
text_encoder=None,
vae=self.vae,
transformer=self.transformer,
scheduler=self.scheduler
)
self.pipe.scaling_flag = True
self.pipe.to(self.device)
def offload_cpu(self):
if hasattr(self, "transformer"):
self.transformer.cpu()
if hasattr(self, "controlnetxs"):
self.controlnetxs.cpu()
if hasattr(self, "text_encoder"):
self.text_encoder.cpu()
if hasattr(self, "vae"):
self.vae.cpu()
torch.cuda.empty_cache()
def prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
transformer_config: dict,
vae_scale_factor_spatial: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
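        '''Build the 3D rotary positional embedding (cos/sin) tables for the latent grid implied by the frame size, frame count, and patch size.'''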
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
if transformer_config.patch_size_t is None:
base_num_frames = num_frames
else:
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=transformer_config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(grid_height, grid_width),
device=device,
)
return freqs_cos, freqs_sin
    def validation_step(self, input_kwargs: dict):
        '''
        Run one sampling pass with the camera-conditioned pipeline and return the generated
        video frames (as returned by the pipeline, ready for export_to_video).
        Expected inputs (selected): image with shape (1, C, H, W) and values in [0, 1];
        plucker_embedding with shape (1, 6, F, H, W); prompt embeddings and sampling options.
        '''
plucker_embedding = input_kwargs["plucker_embedding"]
image = input_kwargs["image"]
H, W = image.shape[-2:]
# camera
plucker_embedding = plucker_embedding.to(self.controlnetxs.vae_encoder.device, dtype=self.controlnetxs.vae_encoder.dtype) # [C=6, F, H, W]
latent_plucker_embedding_dist = self.controlnetxs.vae_encoder.encode(plucker_embedding).latent_dist # B,C=6,F,H,W --> B,128,(F-1)//4+1,H//4,W//4
latent_plucker_embedding = latent_plucker_embedding_dist.sample()
latent_plucker_embedding = latent_plucker_embedding.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] to [B, F, C, H, W]
latent_plucker_embedding = latent_plucker_embedding.repeat(2, 1, 1, 1, 1) # cfg
patch_size_t = self.transformer.config.patch_size_t
if patch_size_t is not None:
ncopy = patch_size_t - latent_plucker_embedding.shape[1] % patch_size_t
if ncopy > 0:
# Copy the first frame ncopy times to match patch_size_t
first_frame = latent_plucker_embedding[:, :1, :, :, :] # Get first frame [B, F, C, H, W]
latent_plucker_embedding = torch.cat([first_frame.repeat(1, ncopy, 1, 1, 1), latent_plucker_embedding], dim=1)
                if 'latent_scene_frames' in input_kwargs:
                    input_kwargs['latent_scene_frames'] = torch.cat([
                        input_kwargs['latent_scene_frames'][:, :1, :, :, :].repeat(1, ncopy, 1, 1, 1),
                        input_kwargs['latent_scene_frames'],
                    ], dim=1)
                    input_kwargs['latent_scene_mask'] = torch.cat([
                        input_kwargs['latent_scene_mask'][:, :1, :, :, :].repeat(1, ncopy, 1, 1, 1),
                        input_kwargs['latent_scene_mask'],
                    ], dim=1)
assert latent_plucker_embedding.shape[1] % patch_size_t == 0
num_latent_frames = latent_plucker_embedding.shape[1]
vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
rotary_emb_for_controlnetxs = (
self.prepare_rotary_positional_embeddings(
height=H,
width=W,
num_frames=num_latent_frames,
transformer_config=self.controlnetxs.transformer.config,
vae_scale_factor_spatial=vae_scale_factor_spatial,
device=self.device,
)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
self.init_pipe()
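        # Temporarily wrap transformer.forward so every denoising step also receives the ControlnetXs
        # branch, the camera latents, and their rotary embeddings; the original forward is restored after sampling.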
original_forward = self.pipe.transformer.forward
self.pipe.transformer.forward = partial(
self.pipe.transformer.forward,
controlnetxs=self.controlnetxs,
latent_plucker_embedding=latent_plucker_embedding,
image_rotary_emb_for_controlnetxs=rotary_emb_for_controlnetxs,
)
forward_kwargs = dict(
num_frames=input_kwargs["video_length"],
height=H,
width=W,
prompt_embeds=input_kwargs["prompt_embedding"],
negative_prompt_embeds=input_kwargs["negative_prompt_embedding"],
image=image.to(self.device).to_dense(),
num_inference_steps=input_kwargs['num_inference_steps'],
guidance_scale=input_kwargs['text_cfg'],
noise_shaping=input_kwargs['noise_shaping'],
            noise_shaping_minimum_timesteps=input_kwargs['noise_shaping_minimum_timesteps'],
            latent_scene_frames=input_kwargs.get('latent_scene_frames', None),  # B,F,C,H,W
            latent_scene_mask=input_kwargs.get('latent_scene_mask', None),
            generator=input_kwargs['generator'],
)
video_generate = self.pipe(**forward_kwargs).frames[0]
self.pipe.transformer.forward = original_forward
return video_generate
def encode_text(self, prompt: str) -> torch.Tensor:
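        '''Tokenize the prompt and return T5 encoder embeddings of shape [1, max_text_seq_length, dim] in bfloat16 on the target device.'''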
prompt_token_ids = self.tokenizer(
prompt,
padding="max_length",
max_length=self.transformer.config.max_text_seq_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.text_encoder(prompt_token_ids.to(self.device))[0]
return prompt_embedding.to(torch.bfloat16).to(self.device)
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W]
video = video.to(self.vae.device, dtype=self.vae.dtype)
latent_dist = self.vae.encode(video).latent_dist
latent = latent_dist.sample() * self.vae.config.scaling_factor
return latent.to(torch.bfloat16).to(self.device)
@torch.inference_mode()
def get_image(
self,
model_name: str,
ref_img: np.ndarray,
prompt: str,
negative_prompt: str,
camera_pose_type: str,
preview_video: str = None,
steps: int = 25,
trace_extract_ratio: float = 1.0,
trace_scale_factor: float = 1.0,
camera_cfg: float = 1.0,
text_cfg: float = 6.0,
seed: int = 123,
noise_shaping: bool = False,
noise_shaping_minimum_timesteps: int = 800,
video_shape: tuple[int, int, int] = (81, 768, 1360),
resize_for_rectangle_crop: bool = True,
):
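        '''Generate a camera-controlled video from a single reference image and a named camera trajectory; returns the path of the saved MP4.'''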
if self.pipe is None:
self.init_model(model_name)
video_length, self.sample_height, self.sample_width = video_shape
print(video_length, self.sample_height, self.sample_width)
seed_everything(seed)
input_kwargs = {
'video_length': video_length,
'camera_cfg': camera_cfg,
'num_inference_steps': steps,
'text_cfg': text_cfg,
'noise_shaping': noise_shaping,
'noise_shaping_minimum_timesteps': noise_shaping_minimum_timesteps,
'generator': torch.Generator(device=self.device).manual_seed(seed)
}
ref_img = rearrange(torch.from_numpy(ref_img), 'h w c -> c 1 h w')
if resize_for_rectangle_crop:
ref_img, resized_H, resized_W = _resize_for_rectangle_crop(
ref_img, self.sample_height, self.sample_width,
)
else:
ref_img, resized_H, resized_W = _resize(
ref_img, self.sample_height, self.sample_width,
)
ref_img = rearrange(ref_img / 255, "c 1 h w -> 1 c h w")
H, W = ref_img.shape[-2:]
input_kwargs["image"] = ref_img.to(self.device).to(torch.bfloat16)
with open(self.camera_pose_meta_path, "r", encoding="utf-8") as f:
camera_pose_file_path = json.load(f)[camera_pose_type]
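        # Pose files: lines starting with "https" (e.g. a source-URL header) are treated as comments and
        # skipped by loadtxt; columns 7 onward hold each frame's flattened 3x4 world-to-camera matrix.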
camera_data = torch.from_numpy(np.loadtxt(camera_pose_file_path, comments="https")) # t, -1
fx = 0.5 * max(resized_H, resized_W)
fy = fx
cx = 0.5 * W
cy = 0.5 * H
intrinsics_matrix = torch.tensor([
[fx, 0, cx],
[0, fy, cy],
[0, 0, 1.0]
])
w2cs_3x4 = camera_data[:, 7:].reshape(-1, 3, 4) # [t, 3, 4]
dummy = torch.tensor([[[0, 0, 0, 1]]] * w2cs_3x4.shape[0]) # [t, 1, 4]
w2cs_4x4 = torch.cat([w2cs_3x4, dummy], dim=1) # [t, 4, 4]
c2ws_4x4 = w2cs_4x4.inverse() # [t, 4, 4]
c2ws_lerp_4x4 = camera_pose_lerp(c2ws_4x4, round(video_length / trace_extract_ratio))[: video_length]
from utils.camera_utils import get_camera_condition
plucker_embedding, relative_c2w_RT_4x4 = get_camera_condition(
H, W, intrinsics_matrix[None, None], c2ws_lerp_4x4[None], mode="c2w",
cond_frame_index=0, align_factor=trace_scale_factor
) # [B=1, C=6, F, H, W]
input_kwargs["plucker_embedding"] = plucker_embedding.to(self.device).to(torch.bfloat16)
uid = uuid4().fields[0]
if noise_shaping:
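            # Noise shaping: encode the rendered preview video into VAE latents and build a latent-space
            # mask of its non-white pixels, so the pipeline can anchor early (high-timestep) denoising
            # to the rendered scene content.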
scene_frames = VideoDecoder(preview_video, device=str(self.device))[:]
scene_frames = rearrange(scene_frames / 255 * 2 - 1, "t c h w -> c t h w") # c,f,h,w, value in [-1, 1]
latent_scene_frames = self.encode_video(scene_frames.unsqueeze(0)) # b=1,c,f,h,w
input_kwargs['latent_scene_frames'] = latent_scene_frames.permute(0, 2, 1, 3, 4) # b=1,c,f,h,w --> b=1,f,c,h,w
from models.camera_controller.utils import apply_thresholded_conv
scene_mask = (scene_frames < 1).float().amax(0, keepdim=True) # c,f,h,w --> 1,f,h,w
scene_mask = apply_thresholded_conv(scene_mask, kernel_size=5, threshold=1.0) # 1,f,h,w
latent_scene_mask = torch.cat([
F.interpolate(scene_mask[:, :1].unsqueeze(1), (1, H // 8, W // 8), mode="trilinear", align_corners=True),
F.interpolate(scene_mask[:, 1:].unsqueeze(1), ((video_length - 1) // 4, H // 8, W // 8), mode="trilinear", align_corners=True)
], dim=2).bool()
input_kwargs['latent_scene_mask'] = latent_scene_mask.permute(0, 2, 1, 3, 4)
self.vae.cpu()
self.transformer.cpu()
self.controlnetxs.cpu()
torch.cuda.empty_cache()
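        # Sequential offloading: keep only the text encoder on the GPU while encoding prompts,
        # then swap it out for the VAE, transformer, and ControlnetXs before sampling.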
self.text_encoder.to(self.device)
input_kwargs |= {
"prompt_embedding": self.encode_text(prompt),
"negative_prompt_embedding": self.encode_text(negative_prompt),
}
self.text_encoder.cpu()
torch.cuda.empty_cache()
self.vae.to(self.device)
self.transformer.to(self.device)
self.controlnetxs.to(self.device)
generated_video = self.validation_step(input_kwargs)
video_path = f"{self.result_dir}/{model_name}_{uid:08x}.mp4"
os.makedirs(self.result_dir, exist_ok=True)
export_to_video(generated_video, video_path, fps=self.save_fps)
torch.cuda.empty_cache()
return video_path
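

# Minimal usage sketch (assumptions, not part of the original script): a "models.json" entry named
# "realcam_i2v", a trajectory named "pan_right" in "camera_poses.json", and a local "example.jpg"
# reference image; substitute your own metadata keys and paths.
if __name__ == "__main__":
    from PIL import Image

    i2v = Image2Video(
        result_dir="results",
        model_meta_path="models.json",
        camera_pose_meta_path="camera_poses.json",
        save_fps=16,
        device="cuda",
    )
    ref = np.array(Image.open("example.jpg").convert("RGB"))  # H,W,C uint8, as expected by get_image
    out_path = i2v.get_image(
        model_name="realcam_i2v",       # hypothetical key in models.json
        ref_img=ref,
        prompt="a scenic mountain valley, the camera slowly pans to the right",
        negative_prompt="blurry, distorted, low quality",
        camera_pose_type="pan_right",   # hypothetical key in camera_poses.json
        steps=25,
        trace_scale_factor=1.0,
        seed=123,
    )
    print("saved to", out_path)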