# Spaces:
# Running
# Running
import argparse
import os

import safetensors
import safetensors.torch
import torch
from diffusers.utils import export_to_video, load_image

from cogvideox_interpolation.pipeline import CogVideoXInterpolationPipeline
def parse_args(argv=None):
    """Parse command-line options for checkpoint-based video interpolation.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with ``model_path``, ``base_ckpt_path``,
        ``output_dir``, ``first_image`` and ``last_image`` attributes.
        ``base_ckpt_path`` is ``None`` unless given.
    """
    parser = argparse.ArgumentParser(description='Video interpolation with different checkpoints')
    # main() uses these four unconditionally; requiring them turns a late,
    # obscure TypeError/OSError into an immediate argparse usage error.
    parser.add_argument('--model_path', type=str, required=True, help='Path to the base model')
    parser.add_argument('--base_ckpt_path', type=str, default=None, help='Base path for checkpoints')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory for output videos')
    parser.add_argument('--first_image', type=str, required=True, help='Path to the first image')
    parser.add_argument('--last_image', type=str, required=True, help='Path to the last image')
    return parser.parse_args(argv)
def _load_pipeline(model_path, ckpt_dir=None):
    """Load the interpolation pipeline, optionally overriding transformer weights.

    Args:
        model_path: Base model location passed to
            ``CogVideoXInterpolationPipeline.from_pretrained`` (loaded in bfloat16).
        ckpt_dir: Directory containing ``model.safetensors``. Its state dict is
            loaded into ``pipe.transformer`` with ``load_state_dict``'s default
            strict matching. ``None`` keeps the base weights.

    Returns:
        The pipeline with sequential CPU offload and VAE tiling/slicing enabled.
    """
    pipe = CogVideoXInterpolationPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
    )
    if ckpt_dir is not None:
        weights_file = os.path.join(ckpt_dir, "model.safetensors")
        pipe.transformer.load_state_dict(safetensors.torch.load_file(weights_file))
    # Same order as before: weights are swapped in before the offload hooks
    # are attached; then the memory-saving options are enabled.
    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    return pipe


def _output_filename(ckpt_num):
    """Return the MP4 filename for a checkpoint label (``'ori'`` or a step number)."""
    prefix = "ori" if ckpt_num == 'ori' else "video"
    return f"{prefix}_ckpt_{ckpt_num}.mp4"


def main():
    """Run first/last-frame video interpolation for each configured checkpoint.

    Reads CLI options via ``parse_args``. For every entry in ``checkpoints``, it
    loads a fresh pipeline and generates one 49-frame video from the
    first/last image pair with a fixed seed (42). The video is written as an
    8 fps MP4 into ``--output_dir``, which is created if missing.
    """
    args = parse_args()
    prompt = 'a 3D consistent video scene'
    # Checkpoint labels. 'ori' selects the base model (no fine-tuned weights);
    # to compare against it, use e.g. ['ori', 800].
    checkpoints = [800]

    os.makedirs(args.output_dir, exist_ok=True)
    # The conditioning frames do not depend on the checkpoint; load them once.
    first_image = load_image(args.first_image)
    last_image = load_image(args.last_image)

    for ckpt_num in checkpoints:
        print(f"Processing checkpoint-{ckpt_num}")
        # 'ori' is the untouched base model, so never overlay fine-tuned weights.
        # NOTE: every other label reuses --base_ckpt_path as-is; a per-step
        # layout would be os.path.join(base, f"checkpoint-{ckpt_num}").
        ckpt_dir = None if ckpt_num == 'ori' else args.base_ckpt_path
        pipe = _load_pipeline(args.model_path, ckpt_dir)

        videos = pipe(
            prompt=prompt,
            first_image=first_image,
            last_image=last_image,
            # Only the first generated video is exported below, so don't
            # spend time generating extras that would be discarded.
            num_videos_per_prompt=1,
            num_inference_steps=50,
            num_frames=49,
            guidance_scale=6,
            generator=torch.Generator(device="cuda").manual_seed(42),
        )
        # Assumes videos[0] is the batch of generated videos and [0] the first
        # video's frames, as the original indexing did — TODO confirm.
        frames = videos[0][0]

        filename = _output_filename(ckpt_num)
        export_to_video(frames, os.path.join(args.output_dir, filename), fps=8)
        print(f"{filename} saved")

        # Release this checkpoint's pipeline before the next one is loaded.
        del pipe
        torch.cuda.empty_cache()

    print("All checkpoints processing completed!")
if __name__ == "__main__":
    # Entry point: only run the interpolation when executed as a script,
    # not when this module is imported.
    main()