Update app.py
app.py CHANGED
@@ -1,23 +1,20 @@
 import torch
-from diffusers import AutoencoderKLWan,
 from diffusers.utils import export_to_video
-from transformers import CLIPVisionModel
 import gradio as gr
 import tempfile
 import spaces
 from huggingface_hub import hf_hub_download
 import numpy as np
-from PIL import Image
 import random
 
-MODEL_ID = "Wan-AI/Wan2.1-
 LORA_REPO_ID = "Kijai/WanVideo_comfy"
-LORA_FILENAME = "
 
-image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
 vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
-pipe =
-    MODEL_ID, vae=vae,
 )
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
 pipe.to("cuda")
@@ -30,7 +27,6 @@ pipe.fuse_lora()
 MOD_VALUE = 32
 DEFAULT_H_SLIDER_VALUE = 512
 DEFAULT_W_SLIDER_VALUE = 896
-NEW_FORMULA_MAX_AREA = 480.0 * 832.0
 
 SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
@@ -40,46 +36,10 @@ FIXED_FPS = 24
 MIN_FRAMES_MODEL = 8
 MAX_FRAMES_MODEL = 81
 
-
 default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
 
-
-def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
-                                  min_slider_h, max_slider_h,
-                                  min_slider_w, max_slider_w,
-                                  default_h, default_w):
-    orig_w, orig_h = pil_image.size
-    if orig_w <= 0 or orig_h <= 0:
-        return default_h, default_w
-
-    aspect_ratio = orig_h / orig_w
-
-    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
-    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
-
-    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
-    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
-
-    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
-    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
-
-    return new_h, new_w
-
-def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
-    if uploaded_pil_image is None:
-        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-    try:
-        new_h, new_w = _calculate_new_dimensions_wan(
-            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
-            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
-            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
-        )
-        return gr.update(value=new_h), gr.update(value=new_w)
-    except Exception as e:
-        gr.Warning("Error attempting to calculate new dimensions")
-        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-
-def get_duration(input_image, prompt, height, width,
                  negative_prompt, duration_seconds,
                  guidance_scale, steps,
                  seed, randomize_seed,
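
For context on the sizing logic removed above: `_calculate_new_dimensions_wan` fits an uploaded image into a fixed pixel budget (NEW_FORMULA_MAX_AREA = 480.0 * 832.0) while keeping its aspect ratio, then snaps both sides down to multiples of 32. A minimal standalone sketch of the same arithmetic is below; the 1024x768 input is an illustrative value, not something from the Space.

import numpy as np

MOD_VALUE = 32
NEW_FORMULA_MAX_AREA = 480.0 * 832.0  # target pixel budget

# Hypothetical upload: 1024 px wide, 768 px tall
orig_w, orig_h = 1024, 768
aspect_ratio = orig_h / orig_w                                # 0.75

# Solve h*w ~= max_area with h/w == aspect_ratio
calc_h = round(np.sqrt(NEW_FORMULA_MAX_AREA * aspect_ratio))  # 547
calc_w = round(np.sqrt(NEW_FORMULA_MAX_AREA / aspect_ratio))  # 730

# Snap down to the nearest multiple of 32, but never below 32
calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)    # 544
calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)    # 704

print(calc_h, calc_w)  # 544 704 -> roughly the same area as 480x832
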
@@ -92,21 +52,20 @@ def get_duration(input_image, prompt, height, width,
     return 60
 
 @spaces.GPU(duration=get_duration)
-def generate_video(input_image, prompt, height, width,
                    negative_prompt=default_negative_prompt, duration_seconds = 2,
                    guidance_scale = 1, steps = 4,
                    seed = 42, randomize_seed = False,
                    progress=gr.Progress(track_tqdm=True)):
     """
-    Generate a video from
 
-    This function takes
-    prompt and parameters. It uses the Wan 2.1
     for fast generation in 4-8 steps.
 
     Args:
-
-        prompt (str): Text prompt describing the desired animation or motion.
         height (int): Target height for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         width (int): Target width for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
@@ -129,17 +88,16 @@ def generate_video(input_image, prompt, height, width,
         - current_seed (int): The seed used for generation (useful when randomize_seed=True)
 
     Raises:
-        gr.Error: If
 
     Note:
-        - The function automatically resizes the input image to the target dimensions
         - Frame count is calculated as duration_seconds * FIXED_FPS (24)
         - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
         - The function uses GPU acceleration via the @spaces.GPU decorator
         - Generation time varies based on steps and duration (see get_duration function)
     """
-    if
-        raise gr.Error("Please
 
     target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
@@ -148,11 +106,9 @@ def generate_video(input_image, prompt, height, width,
 
     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
 
-    resized_image = input_image.resize((target_w, target_h))
-
     with torch.inference_mode():
         output_frames_list = pipe(
-
            height=target_h, width=target_w, num_frames=num_frames,
            guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed)
@@ -164,12 +120,11 @@ def generate_video(input_image, prompt, height, width,
     return video_path, current_seed
 
 with gr.Blocks() as demo:
-    gr.Markdown("# Fast 4 steps Wan 2.1
-    gr.Markdown("[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan 2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/
     with gr.Row():
         with gr.Column():
-
-            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
             duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
 
             with gr.Accordion("Advanced Settings", open=False):
@@ -185,21 +140,9 @@ with gr.Blocks() as demo:
             generate_button = gr.Button("Generate Video", variant="primary")
         with gr.Column():
             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-
-    input_image_component.upload(
-        fn=handle_image_upload_for_dims_wan,
-        inputs=[input_image_component, height_input, width_input],
-        outputs=[height_input, width_input]
-    )
-
-    input_image_component.clear(
-        fn=handle_image_upload_for_dims_wan,
-        inputs=[input_image_component, height_input, width_input],
-        outputs=[height_input, width_input]
-    )
 
     ui_inputs = [
-
        negative_prompt_input, duration_seconds_input,
        guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
    ]
@@ -207,10 +150,11 @@ with gr.Blocks() as demo:
 
     gr.Examples(
         examples=[
-            ["
-            ["
        ],
-        inputs=[
    )
 
 if __name__ == "__main__":
@@ -1,23 +1,20 @@
 import torch
+from diffusers import AutoencoderKLWan, WanTextToVideoPipeline, UniPCMultistepScheduler
 from diffusers.utils import export_to_video
 import gradio as gr
 import tempfile
 import spaces
 from huggingface_hub import hf_hub_download
 import numpy as np
 import random
 
+MODEL_ID = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
 LORA_REPO_ID = "Kijai/WanVideo_comfy"
+LORA_FILENAME = "Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors"
 
 vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanTextToVideoPipeline.from_pretrained(
+    MODEL_ID, vae=vae, torch_dtype=torch.bfloat16
 )
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
 pipe.to("cuda")
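
For anyone who wants to reproduce this setup outside the Space, here is a minimal, self-contained sketch of the load-and-fuse step. It mirrors the lines above, but note two assumptions: it uses the `WanPipeline` class name from current 🧨 diffusers releases (the commit imports `WanTextToVideoPipeline`; use whichever your installed version exposes), and the LoRA download/fuse lines are unchanged context not shown in this diff, so they are filled in with the standard diffusers calls.

import torch
from diffusers import AutoencoderKLWan, WanPipeline, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download

MODEL_ID = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors"

# VAE stays in float32; the rest of the pipeline runs in bfloat16.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")

# Download the CausVid LoRA and fuse it into the model weights.
causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(causvid_path)
pipe.fuse_lora()
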
@@ -30,7 +27,6 @@
 MOD_VALUE = 32
 DEFAULT_H_SLIDER_VALUE = 512
 DEFAULT_W_SLIDER_VALUE = 896
 
 SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
@@ -40,46 +36,10 @@
 MIN_FRAMES_MODEL = 8
 MAX_FRAMES_MODEL = 81
 
+default_prompt_t2v = "a beautiful sunset over mountains, cinematic, smooth camera movement"
 default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
 
+def get_duration(prompt, height, width,
                  negative_prompt, duration_seconds,
                  guidance_scale, steps,
                  seed, randomize_seed,
@@ -92,21 +52,20 @@
     return 60
 
 @spaces.GPU(duration=get_duration)
+def generate_video(prompt, height, width,
                    negative_prompt=default_negative_prompt, duration_seconds = 2,
                    guidance_scale = 1, steps = 4,
                    seed = 42, randomize_seed = False,
                    progress=gr.Progress(track_tqdm=True)):
     """
+    Generate a video from a text prompt using the Wan 2.1 T2V model with CausVid LoRA.
 
+    This function takes a text prompt and generates a video based on the provided
+    prompt and parameters. It uses the Wan 2.1 1.3B Text-to-Video model with CausVid LoRA
     for fast generation in 4-8 steps.
 
     Args:
+        prompt (str): Text prompt describing the desired video content.
         height (int): Target height for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         width (int): Target width for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
@@ -129,17 +88,16 @@
         - current_seed (int): The seed used for generation (useful when randomize_seed=True)
 
     Raises:
+        gr.Error: If prompt is empty or None.
 
     Note:
         - Frame count is calculated as duration_seconds * FIXED_FPS (24)
         - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
         - The function uses GPU acceleration via the @spaces.GPU decorator
         - Generation time varies based on steps and duration (see get_duration function)
     """
+    if not prompt or prompt.strip() == "":
+        raise gr.Error("Please enter a text prompt.")
 
     target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
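
The frame count passed to the pipeline below is computed in unchanged lines that this diff omits. Based on the constants above and the docstring note that frames come from duration_seconds * FIXED_FPS, a plausible reconstruction is sketched here; `frames_for` is a hypothetical helper, not a function from the commit.

import numpy as np

FIXED_FPS = 24
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 81

def frames_for(duration_seconds: float) -> int:
    # Convert the requested duration to frames at 24 fps, then clamp to the
    # model's supported 8-81 frame range.
    return int(np.clip(round(duration_seconds * FIXED_FPS), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))

print(frames_for(2.0))  # 48
print(frames_for(5.0))  # 120 -> clamped to 81
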
@@ -148,11 +106,9 @@
 
     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
 
     with torch.inference_mode():
         output_frames_list = pipe(
+            prompt=prompt, negative_prompt=negative_prompt,
            height=target_h, width=target_w, num_frames=num_frames,
            guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed)
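
The unchanged lines that turn `output_frames_list` into the returned `video_path` are also elided from this diff. Given the `tempfile` and `export_to_video` imports at the top of the file, the usual pattern looks like the sketch below; `save_frames_to_tmp_mp4` is a hypothetical helper for illustration, not code from the commit.

import tempfile
from diffusers.utils import export_to_video

FIXED_FPS = 24

def save_frames_to_tmp_mp4(frames):
    # Write a list of generated frames to a temporary .mp4 and return its path.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(frames, video_path, fps=FIXED_FPS)
    return video_path
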
@@ -164,12 +120,11 @@
     return video_path, current_seed
 
 with gr.Blocks() as demo:
+    gr.Markdown("# Fast 4 steps Wan 2.1 T2V (1.3B) with CausVid LoRA")
+    gr.Markdown("[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan 2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors) and is compatible with 🧨 diffusers")
     with gr.Row():
         with gr.Column():
+            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v, placeholder="Describe the video you want to generate...")
             duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
 
             with gr.Accordion("Advanced Settings", open=False):
@@ -185,21 +140,9 @@
             generate_button = gr.Button("Generate Video", variant="primary")
         with gr.Column():
             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
 
     ui_inputs = [
+        prompt_input, height_input, width_input,
        negative_prompt_input, duration_seconds_input,
        guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
    ]
@@ -207,10 +150,11 @@
 
     gr.Examples(
         examples=[
+            ["a majestic eagle soaring through mountain peaks, cinematic aerial view", 896, 512],
+            ["a serene ocean wave crashing on a sandy beach at sunset", 448, 832],
+            ["a field of flowers swaying in the wind, spring morning light", 512, 896],
        ],
+        inputs=[prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
    )
 
 if __name__ == "__main__":
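
The body of the `if __name__ == "__main__":` block is unchanged and therefore not shown in this diff. A typical Gradio entry point for a Space like this would look like the following; this is an assumption for illustration, not taken from the commit.

if __name__ == "__main__":
    # Hypothetical launch block; the actual unchanged lines are not part of this diff.
    demo.queue().launch()
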