import argparse
import logging
import math
import os
import random
import shutil
from datetime import timedelta
from pathlib import Path
from typing import List, Optional, Tuple, Union
from PIL import Image
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
export_to_video,
is_wandb_available,
load_image,
)
from torchvision.transforms import ToPILImage
import torch
from pathlib import PosixPath
from utils.utils import load_model_from_config,load_segmented_safe_weights,control_weight_files
from models.cogvideox_transformer_3d_control import Control3DModel,Controled_CogVideoXTransformer3DModel
from models.pipeline_cogvideox_image2video import Controled_CogVideoXImageToVideoPipeline,Controled_Memory_CogVideoXImageToVideoPipeline
from models.global_local_memory_module import global_local_memory
import diffusers
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
#CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
)
from lineart_extractor.annotator.lineart import LineartDetector
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import cast_training_params, free_memory
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
import torchvision.transforms as TT
import numpy as np
from videoxl.model.builder import load_pretrained_model
from videoxl.mm_utils import tokenizer_image_token, process_images,transform_input_id
from videoxl.constants import IMAGE_TOKEN_INDEX,TOKEN_PERFRAME
try:
import decord
except ImportError:
raise ImportError(
"The `decord` package is required for loading the video dataset. Install with `pip install decord`"
)
decord.bridge.set_bridge("torch")
from utils.autoreg_video_save_function import autoreg_video_save
from decord import VideoReader, cpu
from einops import rearrange
import gc
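# Resize a frame tensor so that the target (height, width) fits inside it (bicubic), then
# crop to the exact target size: "center" crops symmetrically, "random"/"none" pick a random offset.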
def _resize_for_rectangle_crop(arr,height,width,video_reshape_mode):
image_size = height,width
reshape_mode = video_reshape_mode
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)
h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)
delta_h = h - image_size[0]
delta_w = w - image_size[1]
if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
image_size = height, width
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
def get_frame_length(frame_path):
video_reader = decord.VideoReader(uri = frame_path.as_posix())
video_num_frames = len(video_reader)
return video_num_frames
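# Decode frames [frames_start, frames_end) from the video, trim the clip to a (4k + 1)
# frame count, and resize/center-crop to (args.height, args.width).
# Note: this and proccess_image rely on the module-level `args` created in the __main__ block.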
def proccess_frame(frame_path,frames_start,frames_end):
video_reader = decord.VideoReader(uri = frame_path.as_posix())
video_num_frames = len(video_reader)
start_frame = frames_start
end_frame = frames_end
indices = list(range(start_frame, end_frame))
frames = video_reader.get_batch(indices)
#frames = frames[start_frame: end_frame]
selected_num_frames = frames.shape[0]
print("selected_num_frames",selected_num_frames)
    # Trim to (4k + 1) frames, the frame count the CogVideoX VAE expects
remainder = (3 + (selected_num_frames % 4)) % 4
if remainder != 0:
frames = frames[:-remainder]
selected_num_frames = frames.shape[0]
assert (selected_num_frames - 1) % 4 == 0
# Training transforms
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
frames = _resize_for_rectangle_crop(frames,height=args.height,width=args.width,video_reshape_mode="center")
final_frames = frames.contiguous()
return final_frames
def proccess_image(frames):
# Training transforms
frames = frames.unsqueeze(0).permute(0, 3, 1, 2) # [F, C, H, W]
frames = _resize_for_rectangle_crop(frames,height=args.height,width=args.width,video_reshape_mode="center")
final_frames = frames.contiguous()
return final_frames
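# Sketch conditioning helpers: process_sketch / process_sketch_image convert frames to line art
# with LineartDetector, binarize at 0.78, invert to black-on-white, repeat to 3 channels,
# normalize to [-1, 1] and encode them to scaled VAE latents; encode_sketch returns the raw
# latent distribution of an already preprocessed clip.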
def encode_sketch(video,pipe):
video = video.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = pipe.vae.encode(video).latent_dist
return latent_dist
def process_sketch(sketch,linear_detector,pipe):
sketch = sketch.to("cuda", dtype = torch.bfloat16)
with torch.no_grad():
sketch = linear_detector(sketch,coarse=False)
sketch=(sketch>0.78).float()
sketch=1-sketch
sketch=sketch.repeat(1,3,1,1)
sketch = (sketch - 0.5) / 0.5
sketch=sketch.contiguous()
sketch = sketch.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
sketch = sketch.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
image = sketch[:, :, :1].clone()
with torch.no_grad():
sketch = pipe.vae.encode(sketch).latent_dist
sketches_first_frame=pipe.vae.encode(image).latent_dist
sketch = sketch.sample() * pipe.vae.config.scaling_factor
sketches_first_frame= sketches_first_frame.sample() * pipe.vae.config.scaling_factor
sketch = sketch.permute(0, 2, 1, 3, 4)
sketch = sketch.to(memory_format=torch.contiguous_format)
sketches_first_frame = sketches_first_frame.permute(0, 2, 1, 3, 4)
sketches_first_frame = sketches_first_frame.to(memory_format=torch.contiguous_format)
return sketch,sketches_first_frame
def process_sketch_image(sketch,linear_detector,pipe):
sketch=torch.tensor(np.array(sketch))
sketch=proccess_image(sketch)
sketch = sketch.to("cuda", dtype = torch.bfloat16)
with torch.no_grad():
sketch = linear_detector(sketch,coarse=False)
sketch=(sketch>0.78).float()
sketch=1-sketch
sketch=sketch.repeat(1,3,1,1)
sketch = (sketch - 0.5) / 0.5
sketch=sketch.contiguous()
sketch = sketch.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
sketch = sketch.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
with torch.no_grad():
sketch = pipe.vae.encode(sketch).latent_dist
sketch = sketch.sample() * pipe.vae.config.scaling_factor
sketch = sketch.permute(0, 2, 1, 3, 4)
sketch = sketch.to(memory_format=torch.contiguous_format)
return sketch
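# Run the pipeline for one segment and export the result. `past_latents` carries latents from
# the previous segment into the current call, and `clip_memory` is enabled for every segment
# after the first. Returns a single frame of the generated clip.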
def log_validation(
pipe,
args,
pipeline_args,
device,
use_glm=False,
global_memory=None,
local_memory=None,
glm=None,
past_latents=None,
):
scheduler_args = {}
idx = pipeline_args.pop("segment", None)
video_key=pipeline_args.pop("video_key", None)
    clip_memory = idx != 0
    print("clip_memory", clip_memory)
if "variance_type" in pipe.scheduler.config:
variance_type = pipe.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
pipe = pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(args.seed) if args.seed is not None else None
videos = []
os.makedirs(os.path.join(args.output_dir,video_key),exist_ok=True)
video_tensor_path=os.path.join(args.output_dir,video_key)
print(video_tensor_path,"video_tensor_path")
with torch.no_grad():
for _ in range(args.num_validation_videos):
frames_output, past_latents = pipe(**pipeline_args, generator=generator, output_type="pt",
num_inference_steps=50,use_glm=use_glm,
global_memory=global_memory,
local_memory=local_memory,
glm=glm,
video_tensor_path=video_tensor_path,
past_latents=past_latents[:,-4:-2] if (past_latents is not None) else None ,
clip_memory=clip_memory
)
pt_images=frames_output.frames[0]
            # TODO: decide here whether the first frame should be kept
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
image_np = VaeImageProcessor.pt_to_numpy(pt_images)
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
videos.append(image_pil)
phase_name = f"inference_{idx}"
video_filenames = []
for i, video in enumerate(videos):
final_output_dir=os.path.join(args.output_dir,video_key)
os.makedirs(final_output_dir,exist_ok=True)
filename = os.path.join(final_output_dir, f"{phase_name}_video.mp4")
export_to_video(video, filename, fps=args.fps)
video_filenames.append(filename)
autoreg_video_save(base_path=final_output_dir,suffix="inference_{}_video.mp4",num_videos=idx+1)
    return videos[0][65]  # a single frame (index 65) of the first generated video
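# Split `total_frames` into fixed-length windows that overlap by `overlap` frames, e.g.
# save_segments(161, 81, 16) -> [(0, 81), (65, 146)]; trailing frames that do not fill a
# whole segment are dropped.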
def save_segments(total_frames,segment_length,overlap):
start_frame = 0
segments = []
while start_frame + segment_length <= total_frames:
end_frame = start_frame + segment_length
segments.append((start_frame, end_frame))
start_frame = end_frame - overlap
return segments
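# Inference entry point: load the controlled CogVideoX transformer, the Control3DModel, the GLM
# memory module and the video LLM, then generate each test video segment by segment, feeding
# sketch latents and, from the second segment on, the LLM-derived global/local memory of the
# previously generated frames.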
def main(args):
os.makedirs(args.output_dir,exist_ok=True)
load_dtype=torch.bfloat16
transformer =Controled_CogVideoXTransformer3DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=load_dtype,
)
control_config_path = "model_json/control_model_15_small.json"
transformer_control_config = load_model_from_config(control_config_path)
transformer_control = Control3DModel(**transformer_control_config)
    control_weight_paths = [args.control_weght]  # avoid shadowing the imported control_weight_files
    transformer_control = load_segmented_safe_weights(transformer_control, control_weight_paths)
transformer_control = transformer_control.to(load_dtype)
linear_detector=LineartDetector("cuda", dtype=torch.bfloat16)
gen_kwargs = {"do_sample": True, "temperature": 1, "top_p": None, "num_beams": 1, "use_cache": True, "max_new_tokens": 2}
# try:
# video_tokenizer, video_model, clip_image_processor, _ = load_pretrained_model(args.llm_model_path, None, "llava_qwen", device_map="cuda",attn_implementation="flash_attention_2")
# except:
video_tokenizer, video_model, clip_image_processor, _ = load_pretrained_model(args.llm_model_path, None, "llava_qwen", device_map="cuda",attn_implementation="sdpa")
    video_model.config.beacon_ratio = [8]  # fix the compression ratio to 8; remove this line to allow random compression among {2, 4, 8}
vllm_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nCan you describe the scene and color in anime?<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer_image_token(vllm_prompt, video_tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(video_model.device)
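    # The video LLM is only run to populate its KV cache over the video tokens
    # (max_new_tokens=2); the cached key/value states, not the generated text, serve as memory.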
video_model.to( dtype=torch.bfloat16)
glm=global_local_memory()
glm_weight_files=[args.glm_weight]
glm = load_segmented_safe_weights(glm,glm_weight_files)
glm=glm.to(load_dtype)
glm=glm.to("cuda")
print("successful load glm")
pipe = Controled_Memory_CogVideoXImageToVideoPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
transformer=transformer,
transformer_control=transformer_control
).to("cuda")
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
del transformer,transformer_control
gc.collect()
torch.cuda.empty_cache()
#pipe.enable_sequential_cpu_offload()
if args.enable_slicing:
pipe.vae.enable_slicing()
if args.enable_tiling:
pipe.vae.enable_tiling()
#pipe = pipe.to("cuda")
import json
with open('test_json/long_testset.json',"r") as json_file:
video_info=json.load(json_file)
for video_key,value in video_info.items():
print('------------')
print(video_key)
validation_prompt=value['prompt']
video_path=PosixPath(value['video_path'])
reference_image_path=str(value["reference_image"])
use_glm=False
i=0
global_image=None
frame_path = video_path
video_num_frames=get_frame_length(frame_path)
segments=save_segments(total_frames=video_num_frames,segment_length=args.max_num_frames,overlap=16)
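        # Consecutive segments overlap by 16 frames; each segment is generated autoregressively,
        # conditioned on the output of the previous one.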
print(segments)
past_latents=None
for seg_idx,segment in enumerate(segments):
print(seg_idx)
print(segment)
videos = proccess_frame(frame_path, frames_start=segment[0], frames_end=segment[1])
#print(segment)
sketches,sketches_first_frame = process_sketch(videos,linear_detector,pipe)
torch.cuda.empty_cache()
print("sketches!!!",sketches.shape)
            if seg_idx == 0:
                # append the quality suffix only once per video, not once per segment
                validation_prompt = validation_prompt + " High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
to_pil=ToPILImage()
            if global_image is None:
print("------------------")
print(reference_image_path)
print('------------------')
if reference_image_path != "0":
image=Image.open(reference_image_path).convert("RGB")
global_image=image
sketches_first_frame = process_sketch_image(global_image,linear_detector,pipe)
else:
image=to_pil(videos[0]).convert("RGB")
global_image=image
sketches_first_frame = process_sketch_image(global_image,linear_detector,pipe)
else:
image=global_image
pipeline_args = {
"image": image,
"prompt": validation_prompt,
"guidance_scale": args.guidance_scale,
"use_dynamic_cfg": args.use_dynamic_cfg,
"height": args.height,
"width": args.width,
"sketches": sketches,
"sketches_first_frame":sketches_first_frame,
"num_frames":args.max_num_frames,
"segment": seg_idx,
"video_key":video_key
}
#load the video and process the video
if use_glm:
auto_path=os.path.join(args.output_dir,video_key,"autoreg_video_1.mp4")
vr = VideoReader(auto_path, ctx=cpu(0))
total_frame_num = len(vr)
                max_frame = min(total_frame_num, 650)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frame, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).numpy()
print(frames.shape)
global_videos = clip_image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(video_model.device, dtype=torch.bfloat16)
                local_videos = global_videos[-20:]  # the most recent 20 frames serve as local memory
beacon_skip_first = (input_ids == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[1].item()
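                # Global memory: run the LLM over up to 650 uniformly sampled frames of the
                # previously generated video and keep the key/value caches of three layers
                # (indices -9, -5, -1); local memory repeats the procedure on the last 20 frames.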
with torch.inference_mode():
num_tokens=TOKEN_PERFRAME *global_videos.shape[0]
beacon_skip_last = beacon_skip_first + num_tokens
video_model.generate(input_ids, images=[global_videos], modalities=["video"],beacon_skip_first=beacon_skip_first,beacon_skip_last=beacon_skip_last, **gen_kwargs)
indices=[-9,-5,-1]
global_memory=torch.cat([
torch.cat([rearrange(video_model.past_key_values[i][0], 'b c h w -> b h (c w)') for i in indices],dim=0).unsqueeze(0),
torch.cat([rearrange(video_model.past_key_values[i][1], 'b c h w -> b h (c w)') for i in indices],dim=0).unsqueeze(0)]
,dim=0).unsqueeze(0)
video_model.clear_past_key_values()
video_model.memory.reset()
print(global_memory.shape)
torch.cuda.empty_cache()
num_tokens=TOKEN_PERFRAME *local_videos.shape[0]
beacon_skip_last = beacon_skip_first + num_tokens
video_model.generate(input_ids, images=[local_videos], modalities=["video"],beacon_skip_first=beacon_skip_first,beacon_skip_last=beacon_skip_last, **gen_kwargs)
indices=[-9,-5,-1]
local_memory=torch.cat([
torch.cat([rearrange(video_model.past_key_values[i][0], 'b c h w -> b h (c w)') for i in indices],dim=0).unsqueeze(0),
torch.cat([rearrange(video_model.past_key_values[i][1], 'b c h w -> b h (c w)') for i in indices],dim=0).unsqueeze(0)]
,dim=0).unsqueeze(0)
video_model.clear_past_key_values()
video_model.memory.reset()
del global_videos,local_videos
torch.cuda.empty_cache()
print(local_memory.shape)
else:
global_memory=None
local_memory=None
last_image=log_validation(
pipe=pipe,
args=args,
pipeline_args=pipeline_args,
device="cuda",
use_glm=use_glm,
global_memory=global_memory,
local_memory=local_memory,
glm=glm,
past_latents=past_latents
)
torch.cuda.empty_cache()
use_glm=True
def get_args():
    parser = argparse.ArgumentParser(description="Autoregressive image-to-video inference script for CogVideoX with sketch control and global-local memory (GLM).")
parser.add_argument(
"--guidance_scale",
type=float,
default=6,
help="The guidance scale to use while sampling validation videos.",
)
# Model information
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--llm_model_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--control_weght",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--glm_weight",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--use_dynamic_cfg",
action="store_true",
default=False,
help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--num_validation_videos",
type=int,
default=1,
help="Number of videos that should be generated during validation per `validation_prompt`.",
)
# Training information
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
)
parser.add_argument(
"--output_dir",
type=str,
default="cogvideox-i2v-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--height",
type=int,
default=480,
help="All input videos are resized to this height.",
)
parser.add_argument(
"--width",
type=int,
default=720,
help="All input videos are resized to this width.",
)
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
parser.add_argument(
"--max_num_frames", type=int, default=81, help="All input videos will be truncated to these many frames."
)
parser.add_argument(
"--enable_slicing",
action="store_true",
default=False,
help="Whether or not to use VAE slicing for saving memory.",
)
parser.add_argument(
"--enable_tiling",
action="store_true",
default=False,
help="Whether or not to use VAE tiling for saving memory.",
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
return parser.parse_args()
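# Example invocation (the CogVideoX model id and the weight paths below are placeholders, not
# values shipped with this script):
#   python infernece_i2v_autoreg_glm.py \
#       --pretrained_model_name_or_path THUDM/CogVideoX-5b-I2V \
#       --llm_model_path /path/to/videoxl_checkpoint \
#       --control_weght /path/to/control_weights.safetensors \
#       --glm_weight /path/to/glm_weights.safetensors \
#       --output_dir outputs --enable_slicing --enable_tiling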
if __name__=="__main__":
args = get_args()
main(args)