import gc
import os
import random
import time

import gradio as gr
import imageio
import torch
from diffusers.utils import load_image

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop


def generate_diffusion_forced_video(
    prompt,
    model_id,
    resolution,
    num_frames,
    image=None,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=None,
    addnoise_condition=0,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=False,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=False,
    teacache_thresh=0.2,
    use_ret_steps=False,
):
    model_id = download_model(model_id)

    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    if seed is None:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: Large `addnoise_condition` may reduce consistency. Recommended: 20.")

    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
        # Swap to portrait dimensions when the reference image is portrait.
        if image_height > image_width:
            height, width = width, height
        image = resizecrop(image, height, width)

    # Chinese negative prompt kept verbatim (the model expects it); roughly:
    # "vivid tones, overexposed, static, blurry details, subtitles, style, artwork,
    #  painting, still frame, overall gray, worst quality, low quality, JPEG artifacts,
    #  ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed,
    #  disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
    #  three legs, crowded background, walking backwards".
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    # Optionally rewrite the prompt with the enhancer (text-to-video only), then
    # free its memory before the diffusion pipeline is loaded.
    prompt_input = prompt
    if prompt_enhancer and image is None:
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

    pipe = DiffusionForcingPipeline(
        model_id,
        dit_path=model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=use_usp,
        offload=offload,
    )

    if causal_attention:
        pipe.transformer.set_ar_attention(causal_block_size)

    if teacache:
        if ar_step > 0:
            # Asynchronous mode: each causal block adds `ar_step` denoising steps.
            num_steps = (
                inference_steps
                + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
            )
        else:
            num_steps = inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=teacache_thresh,
            use_ret_steps=use_ret_steps,
            ckpt_dir=model_id,
        )

    with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            overlap_history=overlap_history,
            addnoise_condition=addnoise_condition,
            base_num_frames=base_num_frames,
            ar_step=ar_step,
            causal_block_size=causal_block_size,
            fps=fps,
        )[0]

    os.makedirs("gradio_df_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_path = f"gradio_df_videos/{prompt[:50].replace('/', '')}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path


# Gradio UI
resolution_options = ["540P", "720P"]
model_options = ["Skywork/SkyReels-V2-DF-1.3B-540P"]  # Update if there are more

gr.Interface(
    fn=generate_diffusion_forced_video,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(choices=model_options, value=model_options[0], label="Model ID"),
        gr.Radio(choices=resolution_options, value="540P", label="Resolution"),
        gr.Slider(minimum=16, maximum=200, value=97, step=1, label="Number of Frames"),
        gr.Image(type="filepath", label="Input Image (optional)"),
        gr.Number(label="AR Step", value=0),
        gr.Checkbox(label="Causal Attention"),
        gr.Number(label="Causal Block Size", value=1),
        gr.Number(label="Base Num Frames", value=97),
        gr.Number(label="Overlap History (set for long videos)", value=None),
        gr.Number(label="AddNoise Condition", value=0),
        gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1, label="Guidance Scale"),
        gr.Slider(minimum=0.0, maximum=20.0, value=8.0, step=0.1, label="Shift"),
        gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps"),
        gr.Checkbox(label="Use USP"),
        gr.Checkbox(label="Offload", value=True, interactive=False),
        gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS"),
        gr.Number(label="Seed (optional)", precision=0),
        gr.Checkbox(label="Prompt Enhancer"),
        gr.Checkbox(label="Use TeaCache"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold"),
        gr.Checkbox(label="Use Retention Steps"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="SkyReels V2 Diffusion Forcing",
).launch()
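# Usage sketch (assumptions, not part of the original script: the file is saved
# as app.py inside a checkout of the SkyReels-V2 repo so that `skyreels_v2_infer`
# is importable, its requirements are installed, and a CUDA GPU is available):
#
#   python app.py
#
# Gradio prints a local URL; open it, enter a prompt (optionally an input image
# for image-to-video), and the generated MP4 is written to gradio_df_videos/.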