import argparse
import gc
import os
import random
import time

import imageio
import torch
from diffusers.utils import load_image

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer
from skyreels_v2_infer.pipelines import resizecrop
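
# Generate videos with the SkyReels-V2 diffusion-forcing pipeline. The script supports
# text-to-video and image-to-video, long videos via overlapped windows (overlap_history),
# TeaCache acceleration, and multi-GPU inference through xFuser USP (--use_usp).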
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="diffusion_forcing")
    parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P")
    parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
    parser.add_argument("--num_frames", type=int, default=97)
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--ar_step", type=int, default=0)
    parser.add_argument("--causal_attention", action="store_true")
    parser.add_argument("--causal_block_size", type=int, default=1)
    parser.add_argument("--base_num_frames", type=int, default=97)
    parser.add_argument("--overlap_history", type=int, default=None)
    parser.add_argument("--addnoise_condition", type=int, default=0)
    parser.add_argument("--guidance_scale", type=float, default=6.0)
    parser.add_argument("--shift", type=float, default=8.0)
    parser.add_argument("--inference_steps", type=int, default=30)
    parser.add_argument("--use_usp", action="store_true")
    parser.add_argument("--offload", action="store_true")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
    )
    parser.add_argument("--prompt_enhancer", action="store_true")
    parser.add_argument("--teacache", action="store_true")
    parser.add_argument(
        "--teacache_thresh",
        type=float,
        default=0.2,
        help="Higher values trade quality for speed: 0.1 gives about a 2.0x speedup, 0.2 about 3.0x.",
    )
    parser.add_argument(
        "--use_ret_steps",
        action="store_true",
        help="Using retention steps results in faster generation and better quality.",
    )
    args = parser.parse_args()

    args.model_id = download_model(args.model_id)
    print("model_id:", args.model_id)
    assert (args.use_usp and args.seed is not None) or (not args.use_usp), "USP mode requires an explicit --seed"
    if args.seed is None:
        random.seed(time.time())
        args.seed = int(random.randrange(4294967294))

    if args.resolution == "540P":
        height = 544
        width = 960
    elif args.resolution == "720P":
        height = 720
        width = 1280
    else:
        raise ValueError(f"Invalid resolution: {args.resolution}")

    num_frames = args.num_frames
    fps = args.fps

    if num_frames > args.base_num_frames:
        assert (
            args.overlap_history is not None
        ), 'You must specify "overlap_history" to enable long video generation; recommended values are 17 and 37.'
    if args.addnoise_condition > 60:
        print(
            f'You have set "addnoise_condition" to {args.addnoise_condition}. Values this large can cause inconsistency in long video generation; 20 is recommended.'
        )

    guidance_scale = args.guidance_scale
    shift = args.shift
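    # Image-to-video: load the conditioning image, swap to portrait dimensions when the
    # image is taller than it is wide, then resize-and-crop to the target resolution.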
    if args.image:
        args.image = load_image(args.image)
        image_width, image_height = args.image.size
        if image_height > image_width:
            height, width = width, height
        args.image = resizecrop(args.image, height, width)
    image = args.image.convert("RGB") if args.image else None
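
    # The negative prompt is left in Chinese to match the model's training distribution.
    # It roughly translates to: "vivid tones, overexposed, static, blurry details, subtitles,
    # style, artwork, painting, still frame, overall gray, worst quality, low quality, JPEG
    # artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face,
    # deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered
    # background, three legs, crowd in background, walking backwards."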
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    save_dir = os.path.join("result", args.outdir)
    os.makedirs(save_dir, exist_ok=True)
    local_rank = 0
    if args.use_usp:
        assert (
            not args.prompt_enhancer
        ), "`--prompt_enhancer` is not allowed together with `--use_usp`. Run the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt, then enable `--use_usp`."
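        # Set up torch.distributed and xFuser sequence parallelism: all ranks join a single
        # Ulysses (attention-head) parallel group; ring attention is disabled (degree 1).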
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        import torch.distributed as dist

        dist.init_process_group("nccl")
        local_rank = dist.get_rank()
        torch.cuda.set_device(dist.get_rank())
        device = "cuda"
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )

    prompt_input = args.prompt
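    # Optionally rewrite the prompt with the LLM-based PromptEnhancer (text-to-video only),
    # then delete it to free GPU memory before the video pipeline is loaded.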
    if args.prompt_enhancer and args.image is None:
        print("init prompt enhancer")
        prompt_enhancer = PromptEnhancer()
        prompt_input = prompt_enhancer(prompt_input)
        print(f"enhanced prompt: {prompt_input}")
        del prompt_enhancer
        gc.collect()
        torch.cuda.empty_cache()
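
    # Build the diffusion-forcing pipeline with bfloat16 weights; --offload moves idle
    # modules to CPU to reduce peak VRAM usage.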
    pipe = DiffusionForcingPipeline(
        args.model_id,
        dit_path=args.model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=args.use_usp,
        offload=args.offload,
    )

    if args.causal_attention:
        pipe.transformer.set_ar_attention(args.causal_block_size)
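
    # TeaCache needs the total number of denoising steps up front. With asynchronous
    # diffusion forcing (ar_step > 0), each causal block beyond the first adds ar_step
    # extra steps; (base_num_frames - 1) // 4 + 1 is the latent frame count.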
    if args.teacache:
        if args.ar_step > 0:
            num_steps = (
                args.inference_steps
                + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
            )
            print("num_steps:", num_steps)
        else:
            num_steps = args.inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=args.teacache_thresh,
            use_ret_steps=args.use_ret_steps,
            ckpt_dir=args.model_id,
        )

    print(f"prompt: {prompt_input}")
    print(f"guidance_scale: {guidance_scale}")
    with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=args.inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(args.seed),
            overlap_history=args.overlap_history,
            addnoise_condition=args.addnoise_condition,
            base_num_frames=args.base_num_frames,
            ar_step=args.ar_step,
            causal_block_size=args.causal_block_size,
            fps=fps,
        )[0]

    if local_rank == 0:
        current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
        output_path = os.path.join(save_dir, video_out_file)
        imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
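
For reference, a typical single-GPU invocation of this script looks like the following (the filename generate_video_df.py is an assumption, following the SkyReels-V2 repository layout):

python generate_video_df.py \
    --model_id Skywork/SkyReels-V2-DF-1.3B-540P \
    --resolution 540P \
    --num_frames 97 \
    --inference_steps 30 \
    --teacache \
    --prompt "A woman rides a vintage motorcycle down a desert highway at sunset."

For videos longer than base_num_frames, also pass --overlap_history 17 (or 37) and a small --addnoise_condition such as 20, in line with the checks in the script above.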