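# Gradio demo for SkyReels-V2 Diffusion Forcing (DF) video generation:
# text-to-video or image-to-video, with long videos produced by chaining
# overlapping windows (`overlap_history`).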
import os
import gc
import time
import random

import torch
import imageio
import gradio as gr
from diffusers.utils import load_image

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop

def generate_diffusion_forced_video(
    prompt,
    model_id,
    resolution,
    num_frames,
    image=None,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=None,
    addnoise_condition=0,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=False,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=False,
    teacache_thresh=0.2,
    use_ret_steps=False,
):
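    """Generate a video with the SkyReels-V2 diffusion-forcing pipeline.

    Returns the path of the saved .mp4 file. When `num_frames` exceeds
    `base_num_frames`, the video is generated in overlapping windows and
    `overlap_history` must be set.
    """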
    # Resolve the checkpoint path (downloads the model if it is not cached).
    model_id = download_model(model_id)

    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    # Draw a fresh 32-bit seed when none is provided.
    if seed is None:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: Large `addnoise_condition` may reduce consistency. Recommended: 20.")
    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
        # Swap to a portrait canvas when the reference image is taller than wide.
        if image_height > image_width:
            height, width = width, height
        image = resizecrop(image, height, width)
    # Default negative prompt, kept in Chinese as expected by the underlying model.
    # English gloss: "Vivid colors, overexposed, static, blurry details, subtitles,
    # style, artwork, painting, frame, still, overall gray, worst quality, low quality,
    # JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands,
    # poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
    # motionless frame, cluttered background, three legs, crowded background,
    # walking backwards."
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
    prompt_input = prompt
    if prompt_enhancer and image is None:
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        # Free the enhancer's GPU memory before loading the video pipeline.
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()
    pipe = DiffusionForcingPipeline(
        model_id,
        dit_path=model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=use_usp,
        offload=offload,
    )
    if causal_attention:
        pipe.transformer.set_ar_attention(causal_block_size)

    if teacache:
        if ar_step > 0:
            # Asynchronous mode: the (base_num_frames - 1) // 4 + 1 latent frames
            # are grouped into causal blocks, and every block after the first adds
            # `ar_step` denoising steps, so TeaCache must be sized for the full schedule.
            num_steps = (
                inference_steps + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
            )
        else:
            num_steps = inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=teacache_thresh,
            use_ret_steps=use_ret_steps,
            ckpt_dir=model_id,
        )
    with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            overlap_history=overlap_history,
            addnoise_condition=addnoise_condition,
            base_num_frames=base_num_frames,
            ar_step=ar_step,
            causal_block_size=causal_block_size,
            fps=fps,
        )[0]
    os.makedirs("gradio_df_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    # Build a filename from the (sanitized) prompt prefix, seed, and timestamp.
    output_path = f"gradio_df_videos/{prompt[:50].replace('/', '')}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path
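
# A minimal programmatic usage sketch (left as a comment so that launching the
# Gradio app below remains this script's only side effect); the example prompt
# and seed are illustrative, not from the original app:
#
#   video_path = generate_diffusion_forced_video(
#       prompt="A skier glides down a snowy slope at sunset",
#       model_id="Skywork/SkyReels-V2-DF-1.3B-540P",
#       resolution="540P",
#       num_frames=97,
#       seed=42,
#   )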

# Gradio UI
resolution_options = ["540P", "720P"]
model_options = ["Skywork/SkyReels-V2-DF-1.3B-540P"]  # Extend as more DF checkpoints become available.
gr.Interface(
    fn=generate_diffusion_forced_video,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(choices=model_options, value=model_options[0], label="Model ID"),
        gr.Radio(choices=resolution_options, value="540P", label="Resolution"),
        gr.Slider(minimum=16, maximum=200, value=97, step=1, label="Number of Frames"),
        gr.Image(type="filepath", label="Input Image (optional)"),
        gr.Number(label="AR Step", value=0),
        gr.Checkbox(label="Causal Attention"),
        gr.Number(label="Causal Block Size", value=1),
        gr.Number(label="Base Num Frames", value=97),
        gr.Number(label="Overlap History (set for long videos)", value=None),
        gr.Number(label="AddNoise Condition", value=0),
        gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1, label="Guidance Scale"),
        gr.Slider(minimum=0.0, maximum=20.0, value=8.0, step=0.1, label="Shift"),
        gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps"),
        gr.Checkbox(label="Use USP"),
        gr.Checkbox(label="Offload", value=True, interactive=False),
        gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS"),
        gr.Number(label="Seed (optional)", precision=0),
        gr.Checkbox(label="Prompt Enhancer"),
        gr.Checkbox(label="Use TeaCache"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold"),
        gr.Checkbox(label="Use Retention Steps"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="SkyReels V2 Diffusion Forcing",
).launch()