# SkyReels-V2 / generate_video_df.py
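# Example invocation (a sketch, not the only supported configuration; the prompt
# and flag values below are illustrative):
#   python generate_video_df.py --model_id Skywork/SkyReels-V2-DF-1.3B-540P \
#       --resolution 540P --num_frames 97 --inference_steps 30 \
#       --prompt "A woman riding a vintage motorcycle through a desert at sunset."
# For long videos (num_frames > base_num_frames), also pass --overlap_history
# (17 or 37 recommended).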
import argparse
import gc
import os
import random
import time
import imageio
import torch
from diffusers.utils import load_image
from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop
if __name__ == "__main__":
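    # Command-line configuration for the diffusion-forcing video pipeline.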
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="diffusion_forcing")
    parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P")
    parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
    parser.add_argument("--num_frames", type=int, default=97)
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--ar_step", type=int, default=0)
    parser.add_argument("--causal_attention", action="store_true")
    parser.add_argument("--causal_block_size", type=int, default=1)
    parser.add_argument("--base_num_frames", type=int, default=97)
    parser.add_argument("--overlap_history", type=int, default=None)
    parser.add_argument("--addnoise_condition", type=int, default=0)
    parser.add_argument("--guidance_scale", type=float, default=6.0)
    parser.add_argument("--shift", type=float, default=8.0)
    parser.add_argument("--inference_steps", type=int, default=30)
    parser.add_argument("--use_usp", action="store_true")
    parser.add_argument("--offload", action="store_true")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
    )
    parser.add_argument("--prompt_enhancer", action="store_true")
    parser.add_argument("--teacache", action="store_true")
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="Higher values trade quality for speed: ~0.1 for a 2.0x speedup, ~0.2 for a 3.0x speedup.",
    )
    parser.add_argument(
        "--use_ret_steps",
        action="store_true",
        help="Use retention steps for faster generation and better quality.",
    )
    args = parser.parse_args()
    args.model_id = download_model(args.model_id)
    print("model_id:", args.model_id)
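    # USP runs the model across multiple ranks, so an explicit shared seed is required.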
    assert not args.use_usp or args.seed is not None, "USP mode requires an explicit --seed"
    if args.seed is None:
        random.seed(time.time())
        args.seed = int(random.randrange(4294967294))
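    # Map the requested resolution to landscape output dimensions; swapped below for portrait inputs.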
    if args.resolution == "540P":
        height = 544
        width = 960
    elif args.resolution == "720P":
        height = 720
        width = 1280
    else:
        raise ValueError(f"Invalid resolution: {args.resolution}")
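    # Long-video mode: when num_frames exceeds base_num_frames, the video is generated
    # in overlapping windows, which requires overlap_history to be set.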
    num_frames = args.num_frames
    fps = args.fps
    if num_frames > args.base_num_frames:
        assert (
            args.overlap_history is not None
        ), 'You must specify "overlap_history" to enable long video generation; 17 and 37 are recommended values.'
    if args.addnoise_condition > 60:
        print(
            f'You have set "addnoise_condition" to {args.addnoise_condition}. Values this large can cause inconsistency in long video generation; around 20 is recommended.'
        )
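    # Optional image conditioning: load the reference image, switch to a portrait canvas
    # if it is taller than wide, then resize-and-crop it to fit.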
    guidance_scale = args.guidance_scale
    shift = args.shift
    if args.image:
        args.image = load_image(args.image)
        image_width, image_height = args.image.size
        if image_height > image_width:
            height, width = width, height
        args.image = resizecrop(args.image, height, width)
    image = args.image.convert("RGB") if args.image else None
    # Negative prompt (kept verbatim in Chinese, as shipped with the original script); it roughly
    # translates to: "vivid tones, overexposed, static, blurry details, subtitles, style, artwork,
    # painting, frame, stillness, overall gray, worst quality, low quality, JPEG compression
    # artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed,
    # disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
    # three legs, many people in the background, walking backwards"
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
    save_dir = os.path.join("result", args.outdir)
    os.makedirs(save_dir, exist_ok=True)
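    # USP (sequence-parallel) setup via xFuser: one process per GPU over NCCL, with all
    # GPUs assigned to Ulysses-style sequence parallelism (ring degree 1).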
    local_rank = 0
    if args.use_usp:
        assert (
            not args.prompt_enhancer
        ), "`--prompt_enhancer` cannot be combined with `--use_usp`. Run skyreels_v2_infer/pipelines/prompt_enhancer.py first to generate an enhanced prompt, then enable `--use_usp`."
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        import torch.distributed as dist

        dist.init_process_group("nccl")
        local_rank = dist.get_rank()
        torch.cuda.set_device(dist.get_rank())
        device = "cuda"
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )
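    # Optionally rewrite the text prompt with PromptEnhancer (text-to-video only), then
    # release it to free GPU memory before loading the pipeline.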
    prompt_input = args.prompt
    if args.prompt_enhancer and args.image is None:
        print("init prompt enhancer")
        prompt_enhancer = PromptEnhancer()
        prompt_input = prompt_enhancer(prompt_input)
        print(f"enhanced prompt: {prompt_input}")
        del prompt_enhancer
        gc.collect()
        torch.cuda.empty_cache()
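    # Build the diffusion-forcing pipeline in bfloat16 (--offload trades speed for lower GPU memory use).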
    pipe = DiffusionForcingPipeline(
        args.model_id,
        dit_path=args.model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=args.use_usp,
        offload=args.offload,
    )
    if args.causal_attention:
        pipe.transformer.set_ar_attention(args.causal_block_size)
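    # TeaCache must know the total number of denoising steps up front. With asynchronous
    # generation (ar_step > 0), (base_num_frames - 1) // 4 + 1 latent frames (after temporal
    # compression) are grouped into causal blocks, and each block after the first adds
    # ar_step extra steps.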
    if args.teacache:
        if args.ar_step > 0:
            num_steps = (
                args.inference_steps
                + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
            )
            print("num_steps:", num_steps)
        else:
            num_steps = args.inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=args.teacache_thresh,
            use_ret_steps=args.use_ret_steps,
            ckpt_dir=args.model_id,
        )
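    # Generate under autocast in the transformer's dtype; the pipeline returns a batch of
    # videos, so keep the first.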
    print(f"prompt: {prompt_input}")
    print(f"guidance_scale: {guidance_scale}")
    with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=args.inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(args.seed),
            overlap_history=args.overlap_history,
            addnoise_condition=args.addnoise_condition,
            base_num_frames=args.base_num_frames,
            ar_step=args.ar_step,
            causal_block_size=args.causal_block_size,
            fps=fps,
        )[0]
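    # Only rank 0 writes the result; the filename encodes a prompt prefix, the seed, and a timestamp.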
    if local_rank == 0:
        current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        video_out_file = f"{args.prompt[:100].replace('/', '')}_{args.seed}_{current_time}.mp4"
        output_path = os.path.join(save_dir, video_out_file)
        imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])