# EPiC-fps / inference / cli_demo_camera_i2v_pcd.py
# Uploaded by roll-ai ("Upload 161 files"), commit b14067d (verified).
import sys
import os
sys.path.insert(0, os.getcwd())
sys.path.append('.')
sys.path.append('..')
import argparse
import os
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import (
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
AutoencoderKLCogVideoX
)
from diffusers.utils import export_to_video, load_video
from controlnet_pipeline import ControlnetCogVideoXImageToVideoPCDPipeline
from cogvideo_transformer import CustomCogVideoXTransformer3DModel
from cogvideo_controlnet_pcd import CogVideoXControlnetPCD
from training.controlnet_datasets_camera_pcd_mask import RealEstate10KPCDRenderDataset
from torchvision.transforms.functional import to_pil_image
from inference.utils import stack_images_horizontally
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import cv2
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import torch
def get_black_region_mask_tensor(video_tensor, threshold=2, kernel_size=15):
    """
    Generate cleaned binary masks for near-black regions of a video tensor.

    The video is mapped from [-1, 1] floats to uint8 RGB. Each frame is then
    converted to grayscale, and pixels with intensity <= ``threshold`` are
    marked. A morphological closing with an elliptical kernel fills small holes
    in the marked regions.

    Args:
        video_tensor (torch.Tensor): shape (T, 3, H, W), RGB. Values are expected
            in [-1, 1]; they are mapped to [0, 255] via ``(x + 1) * 127.5``
            and clamped.
        threshold (int): grayscale intensity (0-255) at or below which a pixel
            counts as black (default: 2).
        kernel_size (int): side length of the elliptical structuring element
            used for the morphological closing (default: 15).

    Returns:
        torch.Tensor: uint8 mask of shape (T, H, W), where 1 marks a black region.
    """
    # detach()/cpu() so tensors on the GPU or tracking gradients can be converted to numpy.
    video_uint8 = (
        ((video_tensor.detach().cpu() + 1.0) * 127.5)
        .clamp(0, 255)
        .to(torch.uint8)
        .permute(0, 2, 3, 1)  # (T, C, H, W) -> (T, H, W, C)
    )
    # permute() only yields a strided view; hand OpenCV one contiguous buffer.
    video_np = np.ascontiguousarray(video_uint8.numpy())
    num_frames, height, width, _ = video_np.shape
    masks = np.empty((num_frames, height, width), dtype=np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    for t in range(num_frames):
        gray = cv2.cvtColor(video_np[t], cv2.COLOR_RGB2GRAY)
        # THRESH_BINARY_INV: pixels <= threshold -> 255, everything else -> 0.
        _, mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY_INV)
        mask_cleaned = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        masks[t] = (mask_cleaned > 0).astype(np.uint8)
    return torch.from_numpy(masks)
def maxpool_mask_tensor(mask_tensor, latent_frames=12, latent_height=30, latent_width=45):
    """
    Downsample a per-pixel binary mask to the latent grid with max pooling.

    Any marked (1) pixel inside a spatial window, or inside any frame of a
    temporal group, marks the whole latent cell. A latent frame of zeros is then
    prepended and the result is inverted. In the output, 0 marks cells that
    contained at least one marked pixel and 1 marks cells that were completely
    unmarked. The prepended first frame is therefore all ones.

    Args:
        mask_tensor (torch.Tensor): shape (T, H, W), binary mask (0 or 1).
        latent_frames (int): number of temporal groups (default 12); T must be
            divisible by it (e.g. 48 frames -> groups of 4).
        latent_height (int): pooled height (default 30); H must be divisible by it.
        latent_width (int): pooled width (default 45); W must be divisible by it.

    Returns:
        torch.Tensor: int32 tensor of shape
        (latent_frames + 1, latent_height, latent_width).

    Raises:
        ValueError: if T, H or W is not divisible by the corresponding latent size.
    """
    num_frames, height, width = mask_tensor.shape
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if num_frames % latent_frames != 0:
        raise ValueError(
            f"T must be divisible by {latent_frames} (e.g., 48), got T={num_frames}"
        )
    if height % latent_height != 0 or width % latent_width != 0:
        raise ValueError(
            f"H and W must be divisible by {latent_height} and {latent_width}, "
            f"got H={height}, W={width}"
        )
    # Treat frames as the batch dimension for 2D spatial pooling: (T, 1, H, W).
    x = mask_tensor.unsqueeze(1).float()
    x_pooled = F.max_pool2d(x, kernel_size=(height // latent_height, width // latent_width))
    # Consecutive frames form one temporal group; take the max within each group.
    x_pooled = x_pooled.view(
        latent_frames, num_frames // latent_frames, latent_height, latent_width
    )
    pooled_mask = torch.amax(x_pooled, dim=1)  # (latent_frames, h, w)
    # Prepend one latent frame (becomes all ones after the inversion below).
    zero_frame = torch.zeros_like(pooled_mask[:1])
    pooled_mask = torch.cat([zero_frame, pooled_mask], dim=0)
    return 1 - pooled_mask.int()
def avgpool_mask_tensor(
    mask_tensor, latent_frames=12, latent_height=30, latent_width=45, threshold=0.5
):
    """
    Downsample a per-pixel binary mask to the latent grid with average pooling.

    Each latent cell averages its spatial window and its temporal group. Only
    cells whose mean is strictly greater than ``threshold`` are kept as marked.
    A latent frame of zeros is then prepended and the result is inverted. In the
    output, 0 marks majority-marked cells and 1 marks the rest. The prepended
    first frame is therefore all ones.

    Args:
        mask_tensor (torch.Tensor): shape (T, H, W), binary mask (0 or 1).
        latent_frames (int): number of temporal groups (default 12); T must be
            divisible by it (e.g. 48 frames -> groups of 4).
        latent_height (int): pooled height (default 30); H must be divisible by it.
        latent_width (int): pooled width (default 45); W must be divisible by it.
        threshold (float): a cell counts as marked when its mean is > threshold
            (default 0.5).

    Returns:
        torch.Tensor: int32 tensor of shape
        (latent_frames + 1, latent_height, latent_width).

    Raises:
        ValueError: if T, H or W is not divisible by the corresponding latent size.
    """
    num_frames, height, width = mask_tensor.shape
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if num_frames % latent_frames != 0:
        raise ValueError(
            f"T must be divisible by {latent_frames} (e.g., 48), got T={num_frames}"
        )
    if height % latent_height != 0 or width % latent_width != 0:
        raise ValueError(
            f"H and W must be divisible by {latent_height} and {latent_width}, "
            f"got H={height}, W={width}"
        )
    # Spatial average pooling with frames as the batch dimension: (T, 1, H, W).
    x = mask_tensor.unsqueeze(1).float()
    x_pooled = F.avg_pool2d(x, kernel_size=(height // latent_height, width // latent_width))
    # Temporal average over groups of consecutive frames.
    x_pooled = x_pooled.view(
        latent_frames, num_frames // latent_frames, latent_height, latent_width
    )
    pooled_avg = torch.mean(x_pooled, dim=1)  # (latent_frames, h, w)
    pooled_mask = (pooled_avg > threshold).int()
    # Prepend one latent frame (becomes all ones after the inversion below).
    zero_frame = torch.zeros_like(pooled_mask[:1])
    pooled_mask = torch.cat([zero_frame, pooled_mask], dim=0)
    return 1 - pooled_mask
def _load_input_images(image, reference_video, height, width):
    """
    Build the conditioning image batch of shape (1, 3, height, width) in [-1, 1].

    Args:
        image (str | None): Path to an image file. When None, the first frame of
            ``reference_video`` is used instead.
        reference_video (torch.Tensor): Dataset video, presumably (T, 3, H, W),
            assumed to already be in [-1, 1]. The exports in ``generate_video``
            undo that range with ``/ 2 + 0.5``.
        height (int): Target height.
        width (int): Target width.

    Returns:
        torch.Tensor: float tensor of shape (1, 3, height, width).
    """
    resize = transforms.Resize((height, width))
    if image is None:
        # Dataset frames are already in [-1, 1]; normalizing them again would map
        # them to [-3, 1]. clone() so the result never aliases reference_video
        # (Resize can return its input unchanged when the size already matches).
        return resize(reference_video[0].unsqueeze(0)).clone()
    with Image.open(image) as pil_image:
        # convert("RGB") guarantees 3 channels (drops alpha, expands grayscale).
        pixels = np.array(pil_image.convert("RGB"))
    input_images = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).float() / 255
    input_images = resize(input_images)
    # [0, 1] -> [-1, 1], matching the range of the dataset frames.
    return transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(input_images)


def _load_controlnet_output_mask(eval_dataset, camera_idx, anchor_video, pool_style):
    """
    Build the pooled latent mask for one sample, flattened to shape (1, N, 1) on CUDA.

    The mask is read from ``<root_path>/masks/<sample>.npz`` (key ``'mask'``) and
    inverted. If that fails for any reason, it is instead derived from the
    near-black regions of ``anchor_video``. Frame 0 is dropped before pooling;
    the pooling helpers prepend a latent frame in its place.

    Args:
        eval_dataset: Dataset exposing ``root_path`` and ``dataset[camera_idx]``
            (the sample name used for the mask file).
        camera_idx (int): Dataset index of the sample.
        anchor_video (torch.Tensor): Anchor video used for the derived fallback.
        pool_style (str): ``'max'`` or ``'avg'``.
    """
    pool_fns = {'max': maxpool_mask_tensor, 'avg': avgpool_mask_tensor}
    try:
        mask_path = os.path.join(
            eval_dataset.root_path, 'masks', eval_dataset.dataset[camera_idx] + '.npz'
        )
        # Context manager closes the underlying .npz file handle.
        with np.load(mask_path) as mask_file:
            video_mask = 1 - torch.from_numpy(mask_file['mask'] * 1)
    except Exception:  # best-effort: any failure falls back to the derived mask
        print('using derived mask')
        video_mask = get_black_region_mask_tensor(anchor_video)
    pooled = pool_fns[pool_style](video_mask[1:])
    return pooled.flatten().unsqueeze(0).unsqueeze(-1).to('cuda')


@torch.no_grad()
def generate_video(
    prompt,
    image,
    video_root_dir: str,
    base_model_path: str,
    use_zero_conv: bool,
    controlnet_model_path: str,
    controlnet_weights: float = 1.0,
    controlnet_guidance_start: float = 0.0,
    controlnet_guidance_end: float = 1.0,
    use_dynamic_cfg: bool = True,
    lora_path: str = None,
    lora_rank: int = 128,
    output_path: str = "./output/",
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 42,
    num_frames: int = 49,
    height: int = 480,
    width: int = 720,
    start_camera_idx: int = 0,
    end_camera_idx: int = 1,
    controlnet_transformer_num_attn_heads: int = None,
    controlnet_transformer_attention_head_dim: int = None,
    controlnet_transformer_out_proj_dim_factor: int = None,
    controlnet_transformer_out_proj_dim_zero_init: bool = False,
    controlnet_transformer_num_layers: int = 8,
    downscale_coef: int = 8,
    controlnet_input_channels: int = 6,
    infer_with_mask: bool = False,
    pool_style: str = 'avg',
    pipe_cpu_offload: bool = False,
):
    """
    Generate camera-controlled videos for a range of dataset samples and save them as MP4.

    Each index in ``range(start_camera_idx, end_camera_idx)`` is read from
    ``RealEstate10KPCDRenderDataset``. The sample's anchor video conditions the
    controlnet. Three files are written to ``output_path``: the generated video
    (``*_out.mp4``), the reference video (``*_reference.mp4``) and a side-by-side
    comparison, reference | generated | anchor (``*_out_reference.mp4``).

    Parameters:
    - prompt (str | None): Prompt shared by all samples. When empty/None, each sample's ``caption`` is used.
    - image (str | None): Conditioning image path. When None, the first reference frame of each sample is used.
    - video_root_dir (str): Root directory of the camera dataset.
    - base_model_path (str): Path (or hub id) of the pre-trained CogVideoX model.
    - use_zero_conv (bool): Forwarded to ``CogVideoXControlnetPCD``.
    - controlnet_model_path (str): Checkpoint whose ``state_dict`` is loaded non-strictly into the controlnet; skipped when empty.
    - controlnet_weights (float): Strength of the controlnet.
    - controlnet_guidance_start (float): Stage at which the controlnet starts being applied.
    - controlnet_guidance_end (float): Stage at which the controlnet stops being applied.
    - use_dynamic_cfg (bool): Forwarded to the pipeline (meant for the DPM scheduler).
    - lora_path (str): Optional LoRA directory containing ``pytorch_lora_weights.safetensors``.
    - lora_rank (int): LoRA rank; the LoRA is fused with scale ``1 / lora_rank``.
    - output_path (str): Output directory (created if missing).
    - num_inference_steps (int): Number of denoising steps.
    - guidance_scale (float): Classifier-free guidance scale.
    - num_videos_per_prompt (int): Videos generated per prompt; only the first is exported.
    - dtype (torch.dtype): Computation dtype of the pipeline.
    - seed (int): Seed of the CPU ``torch.Generator``.
    - num_frames (int): Frames per dataset sample / generated video.
    - height, width (int): Sample size used by the dataset and for resizing the conditioning image.
    - start_camera_idx, end_camera_idx (int): Half-open range of dataset indices to process.
    - controlnet_transformer_* / downscale_coef / controlnet_input_channels: Controlnet architecture options.
    - infer_with_mask (bool): Pass a pooled latent mask (``controlnet_output_mask``) to the pipeline.
    - pool_style (str): ``'max'`` or ``'avg'`` pooling for that mask.
    - pipe_cpu_offload (bool): Use diffusers model CPU offload instead of moving the whole pipeline to CUDA.

    Raises:
    - ValueError: if ``infer_with_mask`` is set and ``pool_style`` is not ``'max'`` or ``'avg'``.
    """
    # Validate up front: an unknown style used to leave the mask undefined (or
    # stale from the previous sample) only after all models were loaded.
    if infer_with_mask and pool_style not in ('max', 'avg'):
        raise ValueError(f"pool_style must be 'max' or 'avg', got {pool_style!r}")

    os.makedirs(output_path, exist_ok=True)

    # 1. Load the pre-trained CogVideoX components.
    tokenizer = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder")
    transformer = CustomCogVideoXTransformer3DModel.from_pretrained(
        base_model_path, subfolder="transformer"
    )
    vae = AutoencoderKLCogVideoX.from_pretrained(base_model_path, subfolder="vae")
    scheduler = CogVideoXDDIMScheduler.from_pretrained(base_model_path, subfolder="scheduler")

    # ControlNet
    num_attention_heads_orig = 48 if "5b" in base_model_path.lower() else 30
    controlnet_kwargs = {
        # Read the function argument, not the CLI-global `args` (undefined when imported).
        "num_attention_heads": (
            controlnet_transformer_num_attn_heads
            if controlnet_transformer_num_attn_heads is not None
            else num_attention_heads_orig
        ),
        "out_proj_dim_zero_init": controlnet_transformer_out_proj_dim_zero_init,
    }
    if controlnet_transformer_attention_head_dim is not None:
        controlnet_kwargs["attention_head_dim"] = controlnet_transformer_attention_head_dim
    if controlnet_transformer_out_proj_dim_factor is not None:
        controlnet_kwargs["out_proj_dim"] = (
            num_attention_heads_orig * controlnet_transformer_out_proj_dim_factor
        )
    controlnet = CogVideoXControlnetPCD(
        num_layers=controlnet_transformer_num_layers,
        downscale_coef=downscale_coef,
        in_channels=controlnet_input_channels,
        use_zero_conv=use_zero_conv,
        **controlnet_kwargs,
    )
    if controlnet_model_path:
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        ckpt = torch.load(controlnet_model_path, map_location='cpu', weights_only=False)
        missing, unexpected = controlnet.load_state_dict(dict(ckpt['state_dict']), strict=False)
        print(
            f'[ Weights from pretrained controlnet was loaded into controlnet ] '
            f'[M: {len(missing)} | U: {len(unexpected)}]'
        )

    # Full pipeline
    pipe = ControlnetCogVideoXImageToVideoPCDPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        controlnet=controlnet,
        scheduler=scheduler,
    )
    if not pipe_cpu_offload:
        # With model CPU offload the sub-models are placed on the GPU on demand;
        # moving everything to CUDA first would defeat that memory saving.
        pipe = pipe.to('cuda')

    # Optional LoRA weights.
    if lora_path:
        pipe.load_lora_weights(
            lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
        )
        pipe.fuse_lora(lora_scale=1 / lora_rank)

    # 2. Set Scheduler.
    # `CogVideoXDDIMScheduler` is recommended for CogVideoX-2B,
    # `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
    pipe.scheduler = CogVideoXDPMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    # 3. Precision / memory options.
    pipe = pipe.to(dtype=dtype)
    if pipe_cpu_offload:
        pipe.enable_model_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()

    # 4. Load dataset
    eval_dataset = RealEstate10KPCDRenderDataset(
        video_root_dir=video_root_dir,
        image_size=(height, width),
        sample_n_frames=num_frames,
    )
    use_dataset_caption = not prompt
    print(eval_dataset.dataset)

    for camera_idx in range(start_camera_idx, end_camera_idx):
        data_dict = eval_dataset[camera_idx]
        reference_video = data_dict['video']
        anchor_video = data_dict['anchor_video']
        print(eval_dataset.dataset[camera_idx], seed)

        if use_dataset_caption:
            current_prompt = data_dict['caption']
            file_stem = f"{camera_idx:05d}_{seed}"
        else:
            current_prompt = prompt
            # A path separator in the prompt prefix would point into a non-existent directory.
            prompt_tag = prompt[:10].replace(os.sep, "_")
            file_stem = f"{prompt_tag}_{camera_idx:05d}_{seed}"
        output_path_file = os.path.join(output_path, f"{file_stem}_out.mp4")
        output_path_file_reference = os.path.join(output_path, f"{file_stem}_reference.mp4")
        output_path_file_out_reference = os.path.join(
            output_path, f"{file_stem}_out_reference.mp4"
        )

        input_images = _load_input_images(image, reference_video, height, width)
        # Dataset videos are in [-1, 1]; map back to [0, 1] for PIL export.
        reference_frames = [to_pil_image(frame) for frame in (reference_video / 2 + 0.5)]

        controlnet_output_mask = (
            _load_controlnet_output_mask(eval_dataset, camera_idx, anchor_video, pool_style)
            if infer_with_mask
            else None
        )

        # 5. Generate the video frames based on the prompt.
        video_generate_all = pipe(
            image=input_images,
            anchor_video=anchor_video,
            controlnet_output_mask=controlnet_output_mask,
            prompt=current_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            use_dynamic_cfg=use_dynamic_cfg,  # meant for the DPM scheduler; use False with DDIM
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
            controlnet_weights=controlnet_weights,
            controlnet_guidance_start=controlnet_guidance_start,
            controlnet_guidance_end=controlnet_guidance_end,
        ).frames
        video_generate = video_generate_all[0]

        # 6. Export the generated frames (fps 8, as for the original video).
        export_to_video(video_generate, output_path_file, fps=8)
        export_to_video(reference_frames, output_path_file_reference, fps=8)
        anchor_frames = [to_pil_image(frame) for frame in (anchor_video / 2 + 0.5)]
        # Side-by-side comparison: reference | generated | anchor.
        out_reference_frames = [
            stack_images_horizontally(
                stack_images_horizontally(frame_reference, frame_out), frame_anchor
            )
            for frame_out, frame_reference, frame_anchor in zip(
                video_generate, reference_frames, anchor_frames
            )
        ]
        export_to_video(out_reference_frames, output_path_file_out_reference, fps=8)
def parse_bool_arg(value):
    """
    Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
    parser.add_argument("--prompt", type=str, default=None, help="The description of the video to be generated")
    parser.add_argument("--image", type=str, default=None, help="The reference image of the video to be generated")
    parser.add_argument(
        "--video_root_dir",
        type=str,
        required=True,
        help="The path of the video for controlnet processing.",
    )
    parser.add_argument(
        "--base_model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
    )
    parser.add_argument(
        "--controlnet_model_path", type=str, default="TheDenk/cogvideox-5b-controlnet-hed-v1", help="The path of the controlnet pre-trained model to be used"
    )
    parser.add_argument("--controlnet_weights", type=float, default=0.5, help="Strenght of controlnet")
    parser.add_argument("--use_zero_conv", action="store_true", default=False, help="Use zero conv")
    parser.add_argument("--infer_with_mask", action="store_true", default=False, help="add mask to controlnet")
    parser.add_argument("--pool_style", default='max', choices=["max", "avg"], help="max pool or avg pool")
    parser.add_argument("--controlnet_guidance_start", type=float, default=0.0, help="The stage when the controlnet starts to be applied")
    parser.add_argument("--controlnet_guidance_end", type=float, default=0.5, help="The stage when the controlnet end to be applied")
    # type=bool would turn "--use_dynamic_cfg False" into True; parse explicitly.
    parser.add_argument("--use_dynamic_cfg", type=parse_bool_arg, default=True, help="Use dynamic cfg")
    parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
    parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
    parser.add_argument(
        "--output_path", type=str, default="./output", help="The path where the generated video will be saved"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
    )
    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16"],
        help="The data type for computation (e.g., 'float16' or 'bfloat16')",
    )
    parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
    parser.add_argument("--height", type=int, default=480)
    parser.add_argument("--width", type=int, default=720)
    parser.add_argument("--num_frames", type=int, default=49)
    parser.add_argument("--start_camera_idx", type=int, default=0)
    parser.add_argument("--end_camera_idx", type=int, default=1)
    parser.add_argument("--controlnet_transformer_num_attn_heads", type=int, default=None)
    parser.add_argument("--controlnet_transformer_attention_head_dim", type=int, default=None)
    parser.add_argument("--controlnet_transformer_out_proj_dim_factor", type=int, default=None)
    parser.add_argument(
        "--controlnet_transformer_out_proj_dim_zero_init",
        action="store_true",
        default=False,
        help=("Init project zero."),
    )
    parser.add_argument("--downscale_coef", type=int, default=8)
    # NOTE(review): --vae_channels is parsed but not used by generate_video.
    parser.add_argument("--vae_channels", type=int, default=16)
    parser.add_argument("--controlnet_input_channels", type=int, default=6)
    parser.add_argument("--controlnet_transformer_num_layers", type=int, default=8)
    parser.add_argument("--enable_model_cpu_offload", action="store_true", default=False, help="Enable model CPU offload")
    args = parser.parse_args()

    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
    generate_video(
        prompt=args.prompt,
        image=args.image,
        video_root_dir=args.video_root_dir,
        base_model_path=args.base_model_path,
        use_zero_conv=args.use_zero_conv,
        controlnet_model_path=args.controlnet_model_path,
        controlnet_weights=args.controlnet_weights,
        controlnet_guidance_start=args.controlnet_guidance_start,
        controlnet_guidance_end=args.controlnet_guidance_end,
        use_dynamic_cfg=args.use_dynamic_cfg,
        lora_path=args.lora_path,
        lora_rank=args.lora_rank,
        output_path=args.output_path,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_videos_per_prompt=args.num_videos_per_prompt,
        dtype=dtype,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        start_camera_idx=args.start_camera_idx,
        end_camera_idx=args.end_camera_idx,
        controlnet_transformer_num_attn_heads=args.controlnet_transformer_num_attn_heads,
        controlnet_transformer_attention_head_dim=args.controlnet_transformer_attention_head_dim,
        controlnet_transformer_out_proj_dim_factor=args.controlnet_transformer_out_proj_dim_factor,
        # Previously parsed but never forwarded, so the flag had no effect.
        controlnet_transformer_out_proj_dim_zero_init=args.controlnet_transformer_out_proj_dim_zero_init,
        controlnet_transformer_num_layers=args.controlnet_transformer_num_layers,
        downscale_coef=args.downscale_coef,
        controlnet_input_channels=args.controlnet_input_channels,
        infer_with_mask=args.infer_with_mask,
        pool_style=args.pool_style,
        pipe_cpu_offload=args.enable_model_cpu_offload,
    )