import argparse
import gc
import os
import random
import time

import imageio
import torch
from diffusers.utils import load_image

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="diffusion_forcing")
    parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P")
    parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
    parser.add_argument("--num_frames", type=int, default=97)
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--ar_step", type=int, default=0)
    parser.add_argument("--causal_attention", action="store_true")
    parser.add_argument("--causal_block_size", type=int, default=1)
    parser.add_argument("--base_num_frames", type=int, default=97)
    parser.add_argument("--overlap_history", type=int, default=None)
    parser.add_argument("--addnoise_condition", type=int, default=0)
    parser.add_argument("--guidance_scale", type=float, default=6.0)
    parser.add_argument("--shift", type=float, default=8.0)
    parser.add_argument("--inference_steps", type=int, default=30)
    parser.add_argument("--use_usp", action="store_true")
    parser.add_argument("--offload", action="store_true")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
    )
    parser.add_argument("--prompt_enhancer", action="store_true")
    parser.add_argument("--teacache", action="store_true")
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="Higher values give a larger speedup at the cost of quality: 0.1 for ~2.0x speedup, 0.2 for ~3.0x speedup.",
    )
    parser.add_argument(
        "--use_ret_steps",
        action="store_true",
        help="Use retention steps for faster generation speed and better generation quality.",
    )
    args = parser.parse_args()

    args.model_id = download_model(args.model_id)
    print("model_id:", args.model_id)

    # All USP (sequence-parallel) ranks must share the same seed, so require an explicit one.
    assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode needs a fixed seed"

    if args.seed is None:
        random.seed(time.time())
        args.seed = int(random.randrange(4294967294))

    if args.resolution == "540P":
        height = 544
        width = 960
    elif args.resolution == "720P":
        height = 720
        width = 1280
    else:
        raise ValueError(f"Invalid resolution: {args.resolution}")

    num_frames = args.num_frames
    fps = args.fps

    # Long-video mode: when the requested length exceeds the base window, the video is
    # generated in overlapping chunks, so an overlap length must be provided.
    if num_frames > args.base_num_frames:
        assert (
            args.overlap_history is not None
        ), 'You must specify "overlap_history" to support long video generation. Recommended values are 17 and 37.'
    if args.addnoise_condition > 60:
        print(
            f'You have set "addnoise_condition" to {args.addnoise_condition}. Such a large value can cause inconsistency in long video generation; the recommended value is 20.'
        )
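
    # Rough sketch of the frame accounting in long-video mode (illustrative only; the
    # actual pipeline stitches chunks in latent space): with num_frames=257,
    # base_num_frames=97, and overlap_history=17, each window after the first re-uses
    # the last 17 frames as context and contributes 97 - 17 = 80 new frames, so
    # 1 + ceil((257 - 97) / 80) = 3 windows cover the full clip.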

    guidance_scale = args.guidance_scale
    shift = args.shift
    if args.image:
        args.image = load_image(args.image)
        image_width, image_height = args.image.size
        # Swap to a portrait canvas when the conditioning image is taller than it is wide.
        if image_height > image_width:
            height, width = width, height
        args.image = resizecrop(args.image, height, width)
    image = args.image.convert("RGB") if args.image else None

    # Default negative prompt (Chinese): vivid tones, overexposure, static scenes, blurry
    # details, subtitles, worst/low quality, JPEG artifacts, deformed or fused fingers,
    # malformed limbs, cluttered backgrounds, walking backwards, etc.
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    save_dir = os.path.join("result", args.outdir)
    os.makedirs(save_dir, exist_ok=True)
    local_rank = 0
    if args.use_usp:
        assert (
            not args.prompt_enhancer
        ), "`--prompt_enhancer` is not allowed when using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt before enabling the `--use_usp` parameter."
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        import torch.distributed as dist

        dist.init_process_group("nccl")
        local_rank = dist.get_rank()
        torch.cuda.set_device(dist.get_rank())
        device = "cuda"

        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )

    prompt_input = args.prompt
    if args.prompt_enhancer and args.image is None:
        print("init prompt enhancer")
        prompt_enhancer = PromptEnhancer()
        prompt_input = prompt_enhancer(prompt_input)
        print(f"enhanced prompt: {prompt_input}")
        # Free the enhancer before loading the diffusion pipeline.
        del prompt_enhancer
        gc.collect()
        torch.cuda.empty_cache()

    pipe = DiffusionForcingPipeline(
        args.model_id,
        dit_path=args.model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=args.use_usp,
        offload=args.offload,
    )

    if args.causal_attention:
        pipe.transformer.set_ar_attention(args.causal_block_size)

    if args.teacache:
        if args.ar_step > 0:
            # Asynchronous schedule: (base_num_frames - 1) // 4 + 1 latent frames are
            # grouped into causal blocks; each additional block adds ar_step denoising steps.
            num_steps = (
                args.inference_steps
                + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
            )
            print("num_steps:", num_steps)
        else:
            num_steps = args.inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=args.teacache_thresh,
            use_ret_steps=args.use_ret_steps,
            ckpt_dir=args.model_id,
        )

    print(f"prompt:{prompt_input}")
    print(f"guidance_scale:{guidance_scale}")

    with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=args.inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(args.seed),
            overlap_history=args.overlap_history,
            addnoise_condition=args.addnoise_condition,
            base_num_frames=args.base_num_frames,
            ar_step=args.ar_step,
            causal_block_size=args.causal_block_size,
            fps=fps,
        )[0]

    # Only rank 0 writes the output file in distributed runs.
    if local_rank == 0:
        current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        video_out_file = f"{args.prompt[:100].replace('/', '')}_{args.seed}_{current_time}.mp4"
        output_path = os.path.join(save_dir, video_out_file)
        imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
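
# Example invocations (the script filename "generate_video_df.py" is an assumption;
# substitute the actual file name in this repo):
#
#   # Short 540P clip on a single GPU:
#   python generate_video_df.py --resolution 540P --num_frames 97 --seed 42
#
#   # Long video: num_frames > base_num_frames requires --overlap_history
#   # (recommended 17 or 37), and --addnoise_condition around 20 helps consistency:
#   python generate_video_df.py --resolution 540P --num_frames 257 \
#       --base_num_frames 97 --overlap_history 17 --addnoise_condition 20
#
#   # Multi-GPU sequence parallelism via torchrun (--use_usp requires a fixed --seed):
#   torchrun --nproc_per_node=2 generate_video_df.py --resolution 720P --use_usp --seed 42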