jbilcke-hf (HF Staff) committed · verified
Commit d63bc95
Parent: d8ad2ca

Update app.py

Files changed (1)
app.py · +24 -80
app.py CHANGED
@@ -1,23 +1,20 @@
 import torch
-from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
+from diffusers import AutoencoderKLWan, WanPipeline, UniPCMultistepScheduler
 from diffusers.utils import export_to_video
-from transformers import CLIPVisionModel
 import gradio as gr
 import tempfile
 import spaces
 from huggingface_hub import hf_hub_download
 import numpy as np
-from PIL import Image
 import random

-MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+MODEL_ID = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
 LORA_REPO_ID = "Kijai/WanVideo_comfy"
-LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
+LORA_FILENAME = "Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors"

-image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
 vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
-pipe = WanImageToVideoPipeline.from_pretrained(
-    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+pipe = WanPipeline.from_pretrained(
+    MODEL_ID, vae=vae, torch_dtype=torch.bfloat16
 )
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
 pipe.to("cuda")
@@ -30,7 +27,6 @@ pipe.fuse_lora()
 MOD_VALUE = 32
 DEFAULT_H_SLIDER_VALUE = 512
 DEFAULT_W_SLIDER_VALUE = 896
-NEW_FORMULA_MAX_AREA = 480.0 * 832.0

 SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
@@ -40,46 +36,10 @@ FIXED_FPS = 24
 MIN_FRAMES_MODEL = 8
 MAX_FRAMES_MODEL = 81

-default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+default_prompt_t2v = "a beautiful sunset over mountains, cinematic, smooth camera movement"
 default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"

-
-def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
-                                  min_slider_h, max_slider_h,
-                                  min_slider_w, max_slider_w,
-                                  default_h, default_w):
-    orig_w, orig_h = pil_image.size
-    if orig_w <= 0 or orig_h <= 0:
-        return default_h, default_w
-
-    aspect_ratio = orig_h / orig_w
-
-    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
-    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
-
-    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
-    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
-
-    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
-    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
-
-    return new_h, new_w
-
-def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
-    if uploaded_pil_image is None:
-        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-    try:
-        new_h, new_w = _calculate_new_dimensions_wan(
-            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
-            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
-            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
-        )
-        return gr.update(value=new_h), gr.update(value=new_w)
-    except Exception as e:
-        gr.Warning("Error attempting to calculate new dimensions")
-        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-
-def get_duration(input_image, prompt, height, width,
+def get_duration(prompt, height, width,
                  negative_prompt, duration_seconds,
                  guidance_scale, steps,
                  seed, randomize_seed,
@@ -92,21 +52,20 @@ def get_duration(input_image, prompt, height, width,
     return 60

 @spaces.GPU(duration=get_duration)
-def generate_video(input_image, prompt, height, width,
+def generate_video(prompt, height, width,
                    negative_prompt=default_negative_prompt, duration_seconds = 2,
                    guidance_scale = 1, steps = 4,
                    seed = 42, randomize_seed = False,
                    progress=gr.Progress(track_tqdm=True)):
     """
-    Generate a video from an input image using the Wan 2.1 I2V model with CausVid LoRA.
+    Generate a video from a text prompt using the Wan 2.1 T2V model with CausVid LoRA.

-    This function takes an input image and generates a video animation based on the provided
-    prompt and parameters. It uses the Wan 2.1 14B Image-to-Video model with CausVid LoRA
+    This function takes a text prompt and generates a video based on the provided
+    prompt and parameters. It uses the Wan 2.1 1.3B Text-to-Video model with CausVid LoRA
     for fast generation in 4-8 steps.

     Args:
-        input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
-        prompt (str): Text prompt describing the desired animation or motion.
+        prompt (str): Text prompt describing the desired video content.
         height (int): Target height for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         width (int): Target width for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
@@ -129,17 +88,16 @@ def generate_video(input_image, prompt, height, width,
         - current_seed (int): The seed used for generation (useful when randomize_seed=True)

     Raises:
-        gr.Error: If input_image is None (no image uploaded).
+        gr.Error: If prompt is empty or None.

     Note:
-        - The function automatically resizes the input image to the target dimensions
         - Frame count is calculated as duration_seconds * FIXED_FPS (24)
         - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
         - The function uses GPU acceleration via the @spaces.GPU decorator
         - Generation time varies based on steps and duration (see get_duration function)
     """
-    if input_image is None:
-        raise gr.Error("Please upload an input image.")
+    if not prompt or prompt.strip() == "":
+        raise gr.Error("Please enter a text prompt.")

     target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
@@ -148,11 +106,9 @@ def generate_video(input_image, prompt, height, width,

     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

-    resized_image = input_image.resize((target_w, target_h))
-
     with torch.inference_mode():
         output_frames_list = pipe(
-            image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
+            prompt=prompt, negative_prompt=negative_prompt,
             height=target_h, width=target_w, num_frames=num_frames,
             guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
             generator=torch.Generator(device="cuda").manual_seed(current_seed)
@@ -164,12 +120,11 @@ def generate_video(input_image, prompt, height, width,
     return video_path, current_seed

 with gr.Blocks() as demo:
-    gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA")
-    gr.Markdown("[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan 2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors) and is compatible with 🧨 diffusers")
+    gr.Markdown("# Fast 4 steps Wan 2.1 T2V (1.3B) with CausVid LoRA")
+    gr.Markdown("[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan 2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors) and is compatible with 🧨 diffusers")
     with gr.Row():
         with gr.Column():
-            input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
-            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v, placeholder="Describe the video you want to generate...")
             duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")

             with gr.Accordion("Advanced Settings", open=False):
@@ -185,21 +140,9 @@ with gr.Blocks() as demo:
             generate_button = gr.Button("Generate Video", variant="primary")
         with gr.Column():
             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-
-    input_image_component.upload(
-        fn=handle_image_upload_for_dims_wan,
-        inputs=[input_image_component, height_input, width_input],
-        outputs=[height_input, width_input]
-    )
-
-    input_image_component.clear(
-        fn=handle_image_upload_for_dims_wan,
-        inputs=[input_image_component, height_input, width_input],
-        outputs=[height_input, width_input]
-    )

     ui_inputs = [
-        input_image_component, prompt_input, height_input, width_input,
+        prompt_input, height_input, width_input,
         negative_prompt_input, duration_seconds_input,
         guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
     ]
@@ -207,10 +150,11 @@ with gr.Blocks() as demo:

     gr.Examples(
         examples=[
-            ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
-            ["forg.jpg", "the frog jumps around", 448, 832],
+            ["a majestic eagle soaring through mountain peaks, cinematic aerial view", 896, 512],
+            ["a serene ocean wave crashing on a sandy beach at sunset", 448, 832],
+            ["a field of flowers swaying in the wind, spring morning light", 512, 896],
         ],
-        inputs=[input_image_component, prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
+        inputs=[prompt_input, height_input, width_input], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
     )

 if __name__ == "__main__":
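
The hunks above show only the changed lines; the CausVid LoRA download-and-fuse block referenced by the `pipe.fuse_lora()` hunk context sits between the first two hunks and is untouched by this commit. The sketch below is a hedged reconstruction of how the updated text-to-video pipeline is expected to be wired end to end; the `load_lora_weights`/`fuse_lora` calls and the generation settings are illustrative assumptions, not lines from this commit.

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download

MODEL_ID = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors"

# Text-to-video pipeline setup, matching the new side of the first hunk.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")

# Assumed LoRA block (not shown in the diff): download the CausVid LoRA and fuse it
# into the transformer so fast, low-step sampling works.
lora_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(lora_path, adapter_name="causvid_lora")
pipe.fuse_lora()

# One generation mirroring generate_video's defaults: 2 s * 24 fps = 48 frames,
# 896x512 output (already multiples of MOD_VALUE=32), guidance_scale=1, 4 steps.
frames = pipe(
    prompt="a beautiful sunset over mountains, cinematic, smooth camera movement",
    negative_prompt="static, blurred details, worst quality, low quality, watermark, text",
    height=512,
    width=896,
    num_frames=48,
    guidance_scale=1.0,
    num_inference_steps=4,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(frames, "t2v_causvid_sample.mp4", fps=24)
```

If the Space's actual LoRA block differs (for example, a different adapter weight set via `set_adapters`), that code takes precedence; the sketch only fills the gap the diff does not show.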