fffiloni committed on
Commit
1d73f5a
·
verified ·
1 Parent(s): a546ef4

Create app_df.py

Browse files
Files changed (1) hide show
  1. app_df.py +159 -0
app_df.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import time
4
+ import random
5
+ import torch
6
+ import imageio
7
+ import gradio as gr
8
+ from diffusers.utils import load_image
9
+
10
+ from skyreels_v2_infer import DiffusionForcingPipeline
11
+ from skyreels_v2_infer.modules import download_model
12
+ from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop
13
+
14
def generate_diffusion_forced_video(
    prompt,
    model_id,
    resolution,
    num_frames,
    image=None,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=None,
    addnoise_condition=0,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=False,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=False,
    teacache_thresh=0.2,
    use_ret_steps=False
):
    """Generate a video with the diffusion-forcing pipeline and save it as MP4.

    All cheap argument validation runs before the model is resolved through
    ``download_model`` so that invalid requests fail immediately.

    Args:
        prompt: Text prompt. Also used (truncated and sanitized) in the
            output file name.
        model_id: Model identifier handed to ``download_model``; the value it
            returns is used as both the pipeline model path and ``dit_path``.
        resolution: ``"540P"`` (544x960) or ``"720P"`` (720x1280).
        num_frames: Number of frames to generate.
        image: Optional conditioning image (anything ``load_image`` accepts).
            If it is taller than wide, height and width are swapped.
        ar_step: Forwarded to the pipeline; when > 0 it also enlarges the
            TeaCache step budget.
        causal_attention: If true, ``set_ar_attention(causal_block_size)`` is
            called on the transformer.
        causal_block_size: Forwarded to the pipeline and used in the TeaCache
            step computation.
        base_num_frames: Forwarded to the pipeline; ``num_frames`` above this
            value requires ``overlap_history``.
        overlap_history: Forwarded to the pipeline; required for long videos.
        addnoise_condition: Forwarded to the pipeline; values above 60 only
            trigger a printed warning.
        guidance_scale: Forwarded to the pipeline.
        shift: Forwarded to the pipeline.
        inference_steps: Passed to the pipeline as ``num_inference_steps``.
        use_usp: Forwarded to the pipeline constructor.
        offload: Forwarded to the pipeline constructor.
        fps: Forwarded to the pipeline and used when writing the video.
        seed: Integer seed for the CUDA generator; a random one is drawn
            when ``None``.
        prompt_enhancer: If true and no image is given, the prompt is first
            rewritten by ``PromptEnhancer``.
        teacache: If true, TeaCache is initialized on the transformer.
        teacache_thresh: Threshold passed to ``initialize_teacache``.
        use_ret_steps: Passed to ``initialize_teacache``.

    Returns:
        Path of the written ``.mp4`` file inside ``gradio_df_videos/``.

    Raises:
        ValueError: If ``resolution`` is not supported, or if
            ``num_frames > base_num_frames`` without ``overlap_history``.
    """
    # --- Cheap validation first: nothing below this section is reached
    # --- (no download, no GPU work) when the request is invalid.
    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    # UI number widgets may deliver floats (e.g. 97.0); these options are
    # used as counts/sizes, so normalize them. None is passed through as-is.
    num_frames = _int_or_none(num_frames)
    ar_step = _int_or_none(ar_step)
    causal_block_size = _int_or_none(causal_block_size)
    base_num_frames = _int_or_none(base_num_frames)
    overlap_history = _int_or_none(overlap_history)
    addnoise_condition = _int_or_none(addnoise_condition)
    inference_steps = _int_or_none(inference_steps)

    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: Large `addnoise_condition` may reduce consistency. Recommended: 20.")

    if seed is None:
        # Draw from the OS entropy source instead of reseeding the
        # process-wide `random` state on every request.
        seed = random.SystemRandom().randrange(4294967294)
    else:
        seed = int(seed)

    model_id = download_model(model_id)

    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
        if image_height > image_width:
            # Portrait input: generate a portrait video.
            height, width = width, height
        image = resizecrop(image, height, width)

    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    prompt_input = prompt
    if prompt_enhancer and image is None:
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        # Release the enhancer before the diffusion pipeline is built.
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

    pipe = DiffusionForcingPipeline(
        model_id,
        dit_path=model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=use_usp,
        offload=offload,
    )

    if causal_attention:
        pipe.transformer.set_ar_attention(causal_block_size)

    if teacache:
        if ar_step > 0:
            # NOTE(review): (base_num_frames - 1) // 4 + 1 looks like the
            # latent-frame count per window - confirm against the pipeline.
            num_steps = (
                inference_steps + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
            )
        else:
            num_steps = inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=teacache_thresh,
            use_ret_steps=use_ret_steps,
            ckpt_dir=model_id,
        )

    with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            overlap_history=overlap_history,
            addnoise_condition=addnoise_condition,
            base_num_frames=base_num_frames,
            ar_step=ar_step,
            causal_block_size=causal_block_size,
            fps=fps,
        )[0]

    os.makedirs("gradio_df_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    # The file name uses the original (non-enhanced) prompt.
    name_prefix = _safe_filename_fragment(prompt)
    output_path = f"gradio_df_videos/{name_prefix}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path


def _int_or_none(value):
    """Return ``int(value)``, passing ``None`` through unchanged."""
    return None if value is None else int(value)


def _safe_filename_fragment(text, max_len=50):
    """Return the first ``max_len`` characters of ``text`` minus filename-hostile ones.

    Path separators, characters reserved on common file systems, and
    non-printable characters (newlines, tabs, control codes) are dropped.
    """
    unsafe = '/\\:*?"<>|'
    return "".join(
        ch for ch in text[:max_len] if ch.isprintable() and ch not in unsafe
    )
125
+
126
+
127
# Gradio UI
resolution_options = ["540P", "720P"]
model_options = ["Skywork/SkyReels-V2-DF-1.3B-540P"]  # Update if there are more

# The entries of `inputs` are passed positionally, so their order must match
# the parameter order of `generate_diffusion_forced_video`.
# Integer-valued options use `precision=0` (like the seed field) so the
# callback receives ints rather than arbitrary floats.
demo = gr.Interface(
    fn=generate_diffusion_forced_video,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(choices=model_options, value=model_options[0], label="Model ID"),
        gr.Radio(choices=resolution_options, value="540P", label="Resolution"),
        gr.Slider(minimum=16, maximum=200, value=97, step=1, label="Number of Frames"),
        gr.Image(type="filepath", label="Input Image (optional)"),
        gr.Number(label="AR Step", value=0, precision=0),
        gr.Checkbox(label="Causal Attention"),
        gr.Number(label="Causal Block Size", value=1, precision=0),
        gr.Number(label="Base Num Frames", value=97, precision=0),
        gr.Number(label="Overlap History (set for long videos)", value=None, precision=0),
        gr.Number(label="AddNoise Condition", value=0, precision=0),
        gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1, label="Guidance Scale"),
        gr.Slider(minimum=0.0, maximum=20.0, value=8.0, step=0.1, label="Shift"),
        gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps"),
        gr.Checkbox(label="Use USP"),
        gr.Checkbox(label="Offload"),
        gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS"),
        gr.Number(label="Seed (optional)", precision=0),
        gr.Checkbox(label="Prompt Enhancer"),
        gr.Checkbox(label="Use TeaCache"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold"),
        gr.Checkbox(label="Use Retention Steps"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="SkyReels V2 Diffusion Forcing Generator",
)
demo.launch()