Lemonator committed
Commit 5751ffd · verified · 1 Parent(s): 5e15674

Update app_lora.py

Files changed (1):
  1. app_lora.py +238 -108
app_lora.py CHANGED
@@ -1,3 +1,4 @@
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
@@ -7,172 +8,301 @@ import tempfile
import os
import subprocess

- # The spaces library IS required for ZeroGPU.
- import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random

import warnings
- warnings.filterwarnings("ignore")

MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"

- # --- Global variable for the pipeline ---
- # We use a global variable to cache the model between calls.
- pipe = None

- # --- Constants and Helper Functions ---
MOD_VALUE = 32
- DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE = 640, 1024
- NEW_FORMULA_MAX_AREA = 640.0 * 1024.0
SLIDER_MIN_H, SLIDER_MAX_H = 128, 1024
SLIDER_MIN_W, SLIDER_MAX_W = 128, 1024
MAX_SEED = np.iinfo(np.int32).max
- FIXED_FPS, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 24, 8, 240
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"


- def get_duration(duration_seconds):
-     """
-     Dynamically set the timeout for the @spaces.GPU decorator based on video length.
-     """
-     if duration_seconds > 7: return 180
-     if duration_seconds > 5: return 120
-     if duration_seconds > 3: return 90
-     return 60
-
- # --- The Main GPU Function ---
- # The @spaces.GPU decorator is ESSENTIAL for ZeroGPU.
- # It tells the platform that this function needs a GPU.
- @spaces.GPU(duration=60) # Default duration, can be updated dynamically
- def generate_video(input_image, prompt, height, width,
-                    negative_prompt, duration_seconds,
-                    guidance_scale, steps, seed, randomize_seed,
-                    progress=gr.Progress(track_tqdm=True)):

-     global pipe

-     # --- LAZY LOADING of the model ---
-     # This block will only run on the very first generation request.
-     if pipe is None:
-         progress(0, desc="Cold start: Initializing model...")
-         print("Cold start: Initializing model pipeline...")
-
-         image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float16)
-         vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
-
-         pipe = WanImageToVideoPipeline.from_pretrained(
-             MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
        )
-         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
-         pipe.enable_model_cpu_offload()
-
-         try:
-             causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
-             pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
-             pipe.set_adapters(["causvid_lora"], adapter_weights=[0.75])
-             pipe.fuse_lora()
-             print("✅ LoRA loaded successfully.")
-         except Exception as e:
-             raise gr.Error(f"Error loading LoRA: {e}")
-
-         print("✅ Pipeline initialized successfully.")

-     # Update the GPU duration based on user input for longer videos.
-     spaces.set_timeout(get_duration(duration_seconds))

    if input_image is None:
        raise gr.Error("Please upload an input image.")

    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
    raw_frames = int(round(duration_seconds * FIXED_FPS))
    num_frames = ((raw_frames - 1) // 4) * 4 + 1
    num_frames = np.clip(num_frames, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
-
-     if num_frames > 120 and max(target_h, target_w) > 768:
-         scale_factor = 768 / max(target_h, target_w)
-         target_h = max(MOD_VALUE, int(target_h * scale_factor) // MOD_VALUE * MOD_VALUE)
-         target_w = max(MOD_VALUE, int(target_w * scale_factor) // MOD_VALUE * MOD_VALUE)
-         gr.Info(f"Reduced resolution to {target_w}x{target_h} for long video.")
-
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)
-
-     try:
        torch.cuda.empty_cache()
-         with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
-             output_frames_list = pipe(
-                 image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
-                 height=target_h, width=target_w, num_frames=num_frames,
-                 guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
-                 generator=torch.Generator(device="cuda").manual_seed(current_seed),
-                 callback_on_step_end=lambda p, s, t: progress(s/int(steps))
-             ).frames[0]
    except torch.cuda.OutOfMemoryError:
-         raise gr.Error("Out of GPU memory. Try reducing duration or resolution.")
-     finally:
        torch.cuda.empty_cache()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
-     # (Video export logic is unchanged)
-     import imageio
-     writer = imageio.get_writer(video_path, fps=FIXED_FPS, codec='libx264', pixelformat='yuv420p', quality=8)
-     for frame in output_frames_list:
-         writer.append_data(np.array(frame))
-     writer.close()

-     return video_path, current_seed

- # --- Gradio UI ---
- # (Helper functions for UI are unchanged)
- def handle_image_upload_for_dims_wan(uploaded_pil_image):
-     if uploaded_pil_image is None: return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
-     try:
-         orig_w, orig_h = uploaded_pil_image.size
-         aspect_ratio = orig_h / orig_w
-         calc_h = round(np.sqrt(NEW_FORMULA_MAX_AREA * aspect_ratio))
-         calc_w = round(np.sqrt(NEW_FORMULA_MAX_AREA / aspect_ratio))
-         calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)
-         calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)
-         new_h = int(np.clip(calc_h, SLIDER_MIN_H, SLIDER_MAX_H))
-         new_w = int(np.clip(calc_w, SLIDER_MIN_W, SLIDER_MAX_W))
-         return gr.update(value=new_h), gr.update(value=new_w)
-     except: return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)

with gr.Blocks() as demo:
-     gr.Markdown("# Wan 2.1 I2V FusionX-LoRA (ZeroGPU Ready)")
-     gr.Markdown("The first generation will be slow due to a 'cold start'. Subsequent generations will be much faster.")

    with gr.Row():
        with gr.Column():
-             input_image_component = gr.Image(type="pil", label="Input Image")
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
-             duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1), step=0.1, value=2, label="Duration (seconds)")
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
-                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
-                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
-                     height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label="Height")
-                     width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label="Width")
-                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
-             gr.Markdown("### Tips:\n- Longer videos need more memory.\n- 4-8 steps is optimal.")

-     input_image_component.upload(fn=handle_image_upload_for_dims_wan, inputs=input_image_component, outputs=[height_input, width_input])

-     ui_inputs = [input_image_component, prompt_input, height_input, width_input, negative_prompt_input, duration_seconds_input, guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox]
    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

if __name__ == "__main__":
-     demo.queue().launch()
+ import spaces
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
import os
import subprocess

from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random

import warnings
+ warnings.filterwarnings("ignore", message=".*Attempting to use legacy OpenCV backend.*")
+ warnings.filterwarnings("ignore", message=".*num_frames - 1.*")

MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+
LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"

+ # Initialize models with proper dtype handling
+ image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float16)
+ vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
+ pipe = WanImageToVideoPipeline.from_pretrained(
+     MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
+ )
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
+
+ # Enable memory efficient attention and CPU offloading for large videos
+ pipe.enable_model_cpu_offload()
+ pipe.enable_vae_slicing()
+ pipe.enable_vae_tiling()
+
+ try:
+     causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
+     print("✅ LoRA downloaded to:", causvid_path)
+
+     pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
+     pipe.set_adapters(["causvid_lora"], adapter_weights=[0.75])
+     pipe.fuse_lora()
+
+ except Exception as e:
+     import traceback
+     print("❌ Error during LoRA loading:")
+     traceback.print_exc()

MOD_VALUE = 32
+ DEFAULT_H_SLIDER_VALUE = 640
+ DEFAULT_W_SLIDER_VALUE = 1024
+ NEW_FORMULA_MAX_AREA = 640.0 * 1024.0
+
SLIDER_MIN_H, SLIDER_MAX_H = 128, 1024
SLIDER_MIN_W, SLIDER_MAX_W = 128, 1024
MAX_SEED = np.iinfo(np.int32).max
+
+ FIXED_FPS = 24
+ MIN_FRAMES_MODEL = 8 # Minimum 8 frames (~0.33s)
+ MAX_FRAMES_MODEL = 240 # Maximum 240 frames (10 seconds at 24fps)
+
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"

 
68
+ def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
69
+ min_slider_h, max_slider_h,
70
+ min_slider_w, max_slider_w,
71
+ default_h, default_w):
72
+ orig_w, orig_h = pil_image.size
73
+ if orig_w <= 0 or orig_h <= 0:
74
+ return default_h, default_w
75
+
76
+ aspect_ratio = orig_h / orig_w
 
 
 
 
 
 
 
 
77
 
78
+ calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
79
+ calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
80
+
81
+ calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
82
+ calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
83
 
84
+ new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
85
+ new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
86
+
87
+ return new_h, new_w
88
+
89
+ def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
90
+ if uploaded_pil_image is None:
91
+ return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
92
+ try:
93
+ new_h, new_w = _calculate_new_dimensions_wan(
94
+ uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
95
+ SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
96
+ DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
97
  )
98
+ return gr.update(value=new_h), gr.update(value=new_w)
99
+ except Exception as e:
100
+ gr.Warning("Error attempting to calculate new dimensions")
101
+ return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
102
+
103
+ def get_duration(input_image, prompt, height, width,
104
+ negative_prompt, duration_seconds,
105
+ guidance_scale, steps,
106
+ seed, randomize_seed,
107
+ progress):
108
+ # Adjust timeout based on video length and complexity
109
+ if duration_seconds > 7:
110
+ return 180 # 3 minutes for very long videos
111
+ elif duration_seconds > 5:
112
+ return 120 # 2 minutes for long videos
113
+ elif duration_seconds > 3:
114
+ return 90 # 1.5 minutes for medium videos
115
+ else:
116
+ return 60 # 1 minute for short videos
117
 
118
+ def export_video_with_ffmpeg(frames, output_path, fps=24):
119
+ """Export video using imageio if available, otherwise fall back to OpenCV"""
120
+ try:
121
+ import imageio
122
+ # Use imageio for better quality
123
+ writer = imageio.get_writer(output_path, fps=fps, codec='libx264',
124
+ pixelformat='yuv420p', quality=8)
125
+ for frame in frames:
126
+ writer.append_data(np.array(frame))
127
+ writer.close()
128
+ return True
129
+ except ImportError:
130
+ # Fall back to OpenCV
131
+ export_to_video(frames, output_path, fps=fps)
132
+ return False
133
+
134
+ @spaces.GPU(duration=get_duration)
135
+ def generate_video(input_image, prompt, height, width,
136
+ negative_prompt=default_negative_prompt, duration_seconds=2,
137
+ guidance_scale=1, steps=4,
138
+ seed=42, randomize_seed=False,
139
+ progress=gr.Progress(track_tqdm=True)):

    if input_image is None:
        raise gr.Error("Please upload an input image.")

    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
+     # Calculate frames with proper alignment
    raw_frames = int(round(duration_seconds * FIXED_FPS))
+     # Ensure num_frames-1 is divisible by 4 as required by the model
    num_frames = ((raw_frames - 1) // 4) * 4 + 1
    num_frames = np.clip(num_frames, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
+     # Additional check for very long videos
+     if num_frames > 120:
+         # For videos longer than 5 seconds, reduce resolution to manage memory
+         max_dim = max(target_h, target_w)
+         if max_dim > 768:
+             scale_factor = 768 / max_dim
+             target_h = max(MOD_VALUE, (int(target_h * scale_factor) // MOD_VALUE) * MOD_VALUE)
+             target_w = max(MOD_VALUE, (int(target_w * scale_factor) // MOD_VALUE) * MOD_VALUE)
+             gr.Info(f"Reduced resolution to {target_w}x{target_h} for long video generation")
+
+     print(f"Generating {num_frames} frames (requested {raw_frames}) at {target_w}x{target_h}")
+
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
    resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)
+
+     # Clear GPU cache before generation
+     if torch.cuda.is_available():
        torch.cuda.empty_cache()
+
+     try:
+         with torch.inference_mode():
+             # Generate video with autocast for memory efficiency
+             with torch.autocast("cuda", dtype=torch.float16):
+                 output_frames_list = pipe(
+                     image=resized_image,
+                     prompt=prompt,
+                     negative_prompt=negative_prompt,
+                     height=target_h,
+                     width=target_w,
+                     num_frames=num_frames,
+                     guidance_scale=float(guidance_scale),
+                     num_inference_steps=int(steps),
+                     generator=torch.Generator(device="cuda").manual_seed(current_seed),
+                     return_dict=True
+                 ).frames[0]
    except torch.cuda.OutOfMemoryError:
+         torch.cuda.empty_cache()
+         raise gr.Error("Out of GPU memory. Try reducing the duration or resolution.")
+     except Exception as e:
+         torch.cuda.empty_cache()
+         raise gr.Error(f"Generation failed: {str(e)}")
+
+     # Clear cache after generation
+     if torch.cuda.is_available():
        torch.cuda.empty_cache()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name

+     # Export using imageio if available, otherwise OpenCV
+     used_imageio = export_video_with_ffmpeg(output_frames_list, video_path, fps=FIXED_FPS)
+
+     # Only try FFmpeg optimization if we have a valid video file
+     if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
+         try:
+             # Check if ffmpeg is available
+             subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
+
+             optimized_path = video_path + "_opt.mp4"
+             cmd = [
+                 'ffmpeg',
+                 '-y',                      # Overwrite without asking
+                 '-i', video_path,          # Input file
+                 '-c:v', 'libx264',         # Codec
+                 '-pix_fmt', 'yuv420p',     # Pixel format
+                 '-profile:v', 'main',      # Compatibility profile
+                 '-level', '4.0',           # Support for higher resolutions
+                 '-movflags', '+faststart', # Streaming optimized
+                 '-crf', '23',              # Quality level
+                 '-preset', 'medium',       # Balance between speed and compression
+                 '-maxrate', '10M',         # Max bitrate for large videos
+                 '-bufsize', '20M',         # Buffer size
+                 optimized_path
+             ]
+
+             result = subprocess.run(cmd, capture_output=True, text=True)
+
+             if result.returncode == 0 and os.path.exists(optimized_path) and os.path.getsize(optimized_path) > 0:
+                 os.unlink(video_path)  # Remove original
+                 video_path = optimized_path
+             else:
+                 print(f"FFmpeg optimization failed: {result.stderr}")
+
+         except (subprocess.CalledProcessError, FileNotFoundError):
+             print("FFmpeg not available or optimization failed, using original export")

+     return video_path, current_seed

+ # Gradio Interface
with gr.Blocks() as demo:
+     gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) FusionX-LoRA")
+     gr.Markdown("Generate videos up to 10 seconds long! Longer videos may use reduced resolution for stability.")

    with gr.Row():
        with gr.Column():
+             input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+             duration_seconds_input = gr.Slider(
+                 minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),  # 0.3s (8 frames)
+                 maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),  # 10.0s (240 frames)
+                 step=0.1,
+                 value=2,  # Default 2 seconds
+                 label="Duration (seconds)",
+                 info=f"Video length: {MIN_FRAMES_MODEL/FIXED_FPS:.1f}-{MAX_FRAMES_MODEL/FIXED_FPS:.1f}s. Longer videos may take more time and use more memory."
+             )
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
+                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                with gr.Row():
+                     height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
+                     width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)
+
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
+             gr.Markdown("### Tips for best results:")
+             gr.Markdown("- For videos longer than 5 seconds, consider using lower resolutions (512-768px)")
+             gr.Markdown("- Clear, simple prompts often work better than complex descriptions")
+             gr.Markdown("- The model works best with 4-8 inference steps")

+     input_image_component.upload(
+         fn=handle_image_upload_for_dims_wan,
+         inputs=[input_image_component, height_input, width_input],
+         outputs=[height_input, width_input]
+     )
+
+     input_image_component.clear(
+         fn=handle_image_upload_for_dims_wan,
+         inputs=[input_image_component, height_input, width_input],
+         outputs=[height_input, width_input]
+     )

+     ui_inputs = [
+         input_image_component, prompt_input, height_input, width_input,
+         negative_prompt_input, duration_seconds_input,
+         guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
+     ]
    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

+     gr.Examples(
+         examples=[
+             ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
+             ["forg.jpg", "the frog jumps around", 448, 832],
+         ],
+         inputs=[input_image_component, prompt_input, height_input, width_input],
+         outputs=[video_output, seed_input],
+         fn=generate_video,
+         cache_examples="lazy"
+     )
+
if __name__ == "__main__":
+     demo.queue(max_size=3).launch()