File size: 5,907 Bytes
1d73f5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a13086a
1d73f5a
 
 
 
 
 
 
 
a13086a
1d73f5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import gc
import time
import random
import torch
import imageio
import gradio as gr
from diffusers.utils import load_image

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop

def _as_int(value, default=None):
    """Convert a numeric UI value to ``int``; return ``default`` when it is ``None``.

    ``gr.Number`` without ``precision`` (and ``gr.Slider``) can deliver floats
    such as ``97.0``, and an emptied ``gr.Number`` delivers ``None``. Frame
    counts, step counts, block sizes and seeds are used as integers below.
    """
    if value is None:
        return default
    return int(value)


def _filename_fragment(text, max_len=50):
    """Return a filesystem-safe fragment built from the first ``max_len`` chars of ``text``.

    Anything other than letters, digits, ``-`` and ``_`` becomes ``_`` so that
    prompts containing path separators, colons, quotes or newlines cannot leave
    the output directory or yield an invalid filename. Falls back to
    ``"video"`` when nothing usable remains (e.g. an empty prompt).
    """
    fragment = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in (text or "")[:max_len]
    ).strip("_")
    return fragment or "video"


def generate_diffusion_forced_video(
    prompt,
    model_id,
    resolution,
    num_frames,
    image=None,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=None,
    addnoise_condition=0,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=False,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=False,
    teacache_thresh=0.2,
    use_ret_steps=False
):
    """Run the SkyReels-V2 DiffusionForcingPipeline and write the result to an MP4 file.

    Integer-valued arguments may be passed as floats (as Gradio does); they
    are converted to ``int`` before use, and ``None`` falls back to the
    default shown in the signature.

    Args:
        prompt: Text prompt. Its first 50 characters (sanitized) also name the output file.
        model_id: Model identifier; resolved through ``download_model`` and the
            result is used as the pipeline/checkpoint path.
        resolution: ``"540P"`` (544x960) or ``"720P"`` (720x1280).
        num_frames: Number of frames to generate.
        image: Optional image accepted by ``diffusers.utils.load_image``. It is
            converted to RGB, resized/cropped with ``resizecrop`` and passed to
            the pipeline; a portrait image swaps height and width.
        ar_step: Forwarded to the pipeline; when > 0 it also enlarges the
            TeaCache step budget.
        causal_attention: If True, calls ``set_ar_attention(causal_block_size)``
            on the transformer.
        causal_block_size: Forwarded to the pipeline; must be >= 1.
        base_num_frames: Forwarded to the pipeline. If ``num_frames`` exceeds
            it, ``overlap_history`` is required.
        overlap_history: Forwarded to the pipeline; required for long videos.
        addnoise_condition: Forwarded to the pipeline; values above 60 print a warning.
        guidance_scale, shift: Forwarded to the pipeline.
        inference_steps: Forwarded as ``num_inference_steps``.
        use_usp, offload: Forwarded to the ``DiffusionForcingPipeline`` constructor.
        fps: Forwarded to the pipeline and used as the MP4 frame rate.
        seed: Generator seed; a random one is drawn when ``None``.
        prompt_enhancer: If True and no image is given, rewrite the prompt with
            ``PromptEnhancer`` first.
        teacache, teacache_thresh, use_ret_steps: Enable and configure TeaCache
            on the transformer.

    Returns:
        Path of the written MP4 file inside ``gradio_df_videos``.

    Raises:
        ValueError: For an unknown resolution, ``causal_block_size < 1``, or
            ``num_frames > base_num_frames`` without ``overlap_history``.
    """
    # --- Validate and normalize inputs before any download or GPU work. ---
    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    num_frames = _as_int(num_frames)
    ar_step = _as_int(ar_step, 0)
    causal_block_size = _as_int(causal_block_size, 1)
    base_num_frames = _as_int(base_num_frames, 97)
    overlap_history = _as_int(overlap_history)
    addnoise_condition = _as_int(addnoise_condition, 0)
    inference_steps = _as_int(inference_steps, 30)
    fps = _as_int(fps, 24)
    seed = _as_int(seed)

    if causal_block_size < 1:
        # It is used as a divisor in the TeaCache step computation below.
        raise ValueError(f"`causal_block_size` must be >= 1, got {causal_block_size}.")
    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: Large `addnoise_condition` may reduce consistency. Recommended: 20.")

    model_id = download_model(model_id)

    if seed is None:
        # Private OS-entropy RNG: does not reseed the process-wide `random` module.
        seed = random.SystemRandom().randrange(4294967294)

    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
        if image_height > image_width:
            # Portrait input: swap to a portrait output with the same pixel count.
            height, width = width, height
        image = resizecrop(image, height, width)

    # Default negative prompt (kept in Chinese as given to the model). It lists
    # artifacts to avoid: oversaturated colors, overexposure, static/frozen
    # frames, blurry details, subtitles, low quality, JPEG artifacts, deformed
    # hands/faces/limbs, fused fingers, cluttered or crowded backgrounds, extra
    # legs, walking backwards.
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    prompt_input = prompt
    if prompt_enhancer and image is None:
        # The enhancer is released and the CUDA cache emptied before the video
        # pipeline is constructed below.
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

    pipe = None
    try:
        pipe = DiffusionForcingPipeline(
            model_id,
            dit_path=model_id,
            device=torch.device("cuda"),
            weight_dtype=torch.bfloat16,
            use_usp=use_usp,
            offload=offload,
        )

        if causal_attention:
            pipe.transformer.set_ar_attention(causal_block_size)

        if teacache:
            if ar_step > 0:
                # Extra steps: ar_step per additional causal block within the
                # base window; (base_num_frames - 1) // 4 + 1 is presumably the
                # latent frame count — confirm against the pipeline.
                num_steps = (
                    inference_steps
                    + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
                )
            else:
                num_steps = inference_steps
            pipe.transformer.initialize_teacache(
                enable_teacache=True,
                num_steps=num_steps,
                teacache_thresh=teacache_thresh,
                use_ret_steps=use_ret_steps,
                ckpt_dir=model_id,
            )

        with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
            video_frames = pipe(
                prompt=prompt_input,
                negative_prompt=negative_prompt,
                image=image,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=inference_steps,
                shift=shift,
                guidance_scale=guidance_scale,
                generator=torch.Generator(device="cuda").manual_seed(seed),
                overlap_history=overlap_history,
                addnoise_condition=addnoise_condition,
                base_num_frames=base_num_frames,
                ar_step=ar_step,
                causal_block_size=causal_block_size,
                fps=fps,
            )[0]
    finally:
        # A new pipeline is built on every call; release it (even on failure)
        # so repeated UI requests do not accumulate GPU memory.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()

    output_dir = "gradio_df_videos"
    os.makedirs(output_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(
        output_dir, f"{_filename_fragment(prompt)}_{seed}_{timestamp}.mp4"
    )
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path


# Gradio UI
resolution_options = ["540P", "720P"]
model_options = ["Skywork/SkyReels-V2-DF-1.3B-540P"]  # Update if there are more

# Integer-valued numeric fields use `precision=0` so Gradio passes ints to the
# callback instead of floats (gr.Number returns float when precision is unset).
demo = gr.Interface(
    fn=generate_diffusion_forced_video,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(choices=model_options, value=model_options[0], label="Model ID"),
        gr.Radio(choices=resolution_options, value="540P", label="Resolution"),
        gr.Slider(minimum=16, maximum=200, value=97, step=1, label="Number of Frames"),
        gr.Image(type="filepath", label="Input Image (optional)"),
        gr.Number(label="AR Step", value=0, precision=0),
        gr.Checkbox(label="Causal Attention"),
        gr.Number(label="Causal Block Size", value=1, precision=0),
        gr.Number(label="Base Num Frames", value=97, precision=0),
        gr.Number(label="Overlap History (set for long videos)", value=None, precision=0),
        gr.Number(label="AddNoise Condition", value=0, precision=0),
        gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1, label="Guidance Scale"),
        gr.Slider(minimum=0.0, maximum=20.0, value=8.0, step=0.1, label="Shift"),
        gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps"),
        gr.Checkbox(label="Use USP"),
        # Offload is forced on and not user-editable.
        gr.Checkbox(label="Offload", value=True, interactive=False),
        gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS"),
        gr.Number(label="Seed (optional)", precision=0),
        gr.Checkbox(label="Prompt Enhancer"),
        gr.Checkbox(label="Use TeaCache"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold"),
        gr.Checkbox(label="Use Retention Steps"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="SkyReels V2 Diffusion Forcing",
)
demo.launch()