""" | |
The original experimental code for this project can be found at: | |
https://gist.github.com/a-r-r-o-w/d070cce059ab4ceab3a9f289ff83c69c | |
By using this code, description prompts will be generated through a local large language model, and images will be | |
generated using the black-forest-labs/FLUX.1-dev model, followed by video generation via CogVideoX. | |
The entire process utilizes open-source solutions, without the need for any API keys. | |
You can use the generate.sh file in the same folder to automate running this code | |
for batch generation of videos and images. | |
bash generate.sh | |
""" | |
import argparse
import gc
import json
import os
import pathlib
import random
from typing import Any, Dict

# Verbose torch.compile / dynamo logging; set before torch (and libraries that import it) is loaded.
os.environ["TORCH_LOGS"] = "+dynamo,recompiles,graph_breaks"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

import numpy as np
import torch
import transformers
from transformers import AutoTokenizer
from diffusers import CogVideoXDPMScheduler, CogVideoXImageToVideoPipeline, DiffusionPipeline
from diffusers.utils import export_to_video
from diffusers.utils.logging import get_logger

torch.set_float32_matmul_precision("high")

logger = get_logger(__name__)
SYSTEM_PROMPT = """
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
For example, if you respond with "A beautiful morning in the woods with the sun peeking through the trees", the video generation model will create a video exactly as described. Your task is to summarize the descriptions of videos provided by users, and create detailed prompts to feed into the generative model.
There are a few rules to follow:
- You will only ever output a single video description per request.
- If the user asks you to summarize the prompt in [X] words, make sure not to exceed the limit.
Your responses should just be the video generation prompt. Here are examples:
- "A lone figure stands on a city rooftop at night, gazing up at the full moon. The moon glows brightly, casting a gentle light over the quiet cityscape. Below, the windows of countless homes shine with warm lights, creating a contrast between the bustling life below and the peaceful solitude above. The scene captures the essence of the Mid-Autumn Festival, where despite the distance, the figure feels connected to loved ones through the shared beauty of the moonlit sky."
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart of the city, holding a can of spray paint, spray-painting a colorful bird on a mottled wall."
""".strip()

USER_PROMPT = """
Could you generate a prompt for a video generation model?
Please limit the prompt to [{0}] words.
""".strip()
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_videos",
        type=int,
        default=5,
        help="Number of unique videos you would like to generate.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="THUDM/CogVideoX-5b-I2V",
        help="Path or Hub ID of the CogVideoX image-to-video model.",
    )
    parser.add_argument(
        "--caption_generator_model_id",
        type=str,
        default="THUDM/glm-4-9b-chat",
        help="Caption generation model (default: THUDM/glm-4-9b-chat).",
    )
    parser.add_argument(
        "--caption_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for the caption generation model.",
    )
    parser.add_argument(
        "--image_generator_model_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Image generation model.",
    )
    parser.add_argument(
        "--image_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for the image generation model.",
    )
    parser.add_argument(
        "--image_generator_num_inference_steps",
        type=int,
        default=50,
        help="Number of inference steps for image generation.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7,
        help="Guidance scale to be used for video generation.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        help="Whether or not to use cosine dynamic guidance for generation [Recommended].",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/",
        help="Location where generated images and videos should be stored.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Whether or not to compile the transformer of the image and video generators.",
    )
    parser.add_argument(
        "--enable_vae_tiling",
        action="store_true",
        help="Whether or not to use VAE tiling when encoding/decoding.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for reproducibility.",
    )
    return parser.parse_args()
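# Free cached CUDA memory and reset memory statistics between pipeline stages so that each
# large model (LLM, FLUX, CogVideoX) can be loaded on a single GPU one at a time.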
def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
def main(args: Dict[str, Any]) -> None:
    output_dir = pathlib.Path(args.output_dir)
    os.makedirs(output_dir.as_posix(), exist_ok=True)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    reset_memory()
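    # Stage 1: generate video description prompts with a local LLM (GLM-4-9B by default).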
    tokenizer = AutoTokenizer.from_pretrained(args.caption_generator_model_id, trust_remote_code=True)
    caption_generator = transformers.pipeline(
        "text-generation",
        model=args.caption_generator_model_id,
        device_map="auto",
        model_kwargs={
            "local_files_only": True,
            "cache_dir": args.caption_generator_cache_dir,
            "torch_dtype": torch.bfloat16,
        },
        trust_remote_code=True,
        tokenizer=tokenizer,
    )
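    # Request one caption per video, varying the requested word count to diversify prompt length.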
    captions = []
    for i in range(args.num_videos):
        num_words = random.choice([50, 75, 100])
        user_prompt = USER_PROMPT.format(num_words)

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        outputs = caption_generator(messages, max_new_tokens=226)
        caption = outputs[0]["generated_text"][-1]["content"]
        if caption.startswith("\"") and caption.endswith("\""):
            caption = caption[1:-1]
        captions.append(caption)
        logger.info(f"Generated caption: {caption}")

    with open(output_dir / "captions.json", "w") as file:
        json.dump(captions, file)

    del caption_generator
    reset_memory()
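    # Stage 2: render a 720x480 conditioning image for each caption with FLUX.1-dev.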
    image_generator = DiffusionPipeline.from_pretrained(
        args.image_generator_model_id,
        cache_dir=args.image_generator_cache_dir,
        torch_dtype=torch.bfloat16,
    )
    image_generator.to("cuda")

    if args.compile:
        image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True)

    if args.enable_vae_tiling:
        image_generator.vae.enable_tiling()

    images = []
    for index, caption in enumerate(captions):
        image = image_generator(
            prompt=caption,
            height=480,
            width=720,
            num_inference_steps=args.image_generator_num_inference_steps,
            guidance_scale=3.5,
        ).images[0]
        filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
        image.save(output_dir / f"{index}_{filename}.png")
        images.append(image)

    del image_generator
    reset_memory()
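    # Stage 3: animate each conditioning image into a 49-frame clip with CogVideoX image-to-video.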
    video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    ).to("cuda")
    video_generator.scheduler = CogVideoXDPMScheduler.from_config(
        video_generator.scheduler.config,
        timestep_spacing="trailing",
    )

    if args.compile:
        video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True)

    if args.enable_vae_tiling:
        video_generator.vae.enable_tiling()

    generator = torch.Generator().manual_seed(args.seed)
    for index, (caption, image) in enumerate(zip(captions, images)):
        video = video_generator(
            image=image,
            prompt=caption,
            height=480,
            width=720,
            num_frames=49,
            num_inference_steps=50,
            guidance_scale=args.guidance_scale,
            use_dynamic_cfg=args.use_dynamic_cfg,
            generator=generator,
        ).frames[0]
        filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
        export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8)
if __name__ == "__main__":
    args = get_args()
    main(args)