multimodalart (HF Staff) committed · verified
Commit 9e195fc · 1 Parent(s): 7b7a87f

Update app.py

Files changed (1): app.py (+124, -114)
app.py CHANGED
@@ -37,7 +37,6 @@ hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/
 print("Downloads complete.")
 
 # --- Boilerplate code from the original script ---
-
 def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
     """Returns the value at the given index of a sequence or mapping.
 
@@ -88,7 +87,6 @@ def add_comfyui_directory_to_sys_path() -> None:
     """
     Add 'ComfyUI' to the sys.path
     """
-    # Use a more robust name to find the ComfyUI directory
     comfyui_path = find_path("ComfyUI")
     if comfyui_path is not None and os.path.isdir(comfyui_path):
         sys.path.append(comfyui_path)
@@ -132,8 +130,6 @@ def import_custom_nodes() -> None:
 
 
 # --- Model Loading and Caching ---
-
-# Dictionary to hold all loaded models and node instances
 MODELS_AND_NODES = {}
 
 print("Setting up ComfyUI paths...")
@@ -215,12 +211,21 @@ print("All models loaded successfully!")
 
 # --- Main Video Generation Logic ---
 @spaces.GPU(duration=120)
-def generate_video(start_image_pil: Image.Image, end_image_pil: Image.Image, prompt: str, negative_prompt: str, progress=gr.Progress(track_tqdm=True)):
+def generate_video(
+    start_image_pil,
+    end_image_pil,
+    prompt,
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
+    duration=2,
+    progress=gr.Progress(track_tqdm=True)
+):
     """
     The main function to generate a video based on user inputs.
     This function is called every time the user clicks the 'Generate' button.
     """
-    # Use pre-loaded models and nodes from the global dictionary
+    FPS = 16
+    num_frames = max(2, int(duration * FPS))
+
     clip = MODELS_AND_NODES["clip"]
     vae = MODELS_AND_NODES["vae"]
     model_low_noise = MODELS_AND_NODES["model_low_noise"]
@@ -246,122 +251,127 @@ def generate_video(start_image_pil: Image.Image, end_image_pil: Image.Image, pro
     start_image_path = start_file.name
     end_image_path = end_file.name
 
-    try:
-        with torch.inference_mode():
-            progress(0.1, desc="Encoding text and images...")
-            # --- Workflow execution ---
-            positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
-            negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))
-
-            start_image_loaded = loadimage.load_image(image=start_image_path)
-            end_image_loaded = loadimage.load_image(image=end_image_path)
-
-            clip_vision_encoded_start = clipvisionencode.encode(
-                crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0)
-            )
-            clip_vision_encoded_end = clipvisionencode.encode(
-                crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0)
-            )
-
-            progress(0.2, desc="Preparing initial latents...")
-            initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
-                width=480, height=480, length=33, batch_size=1,
-                positive=get_value_at_index(positive_conditioning, 0),
-                negative=get_value_at_index(negative_conditioning, 0),
-                vae=get_value_at_index(vae, 0),
-                clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
-                clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
-                start_image=get_value_at_index(start_image_loaded, 0),
-                end_image=get_value_at_index(end_image_loaded, 0),
-            )
-
-            progress(0.3, desc="Patching models...")
-            model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_noise, 0))
-            model_low_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
-
-            model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_noise, 0))
-            model_high_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
-
-            progress(0.5, desc="Running KSampler (Step 1/2)...")
-            latent_step1 = ksampleradvanced.sample(
-                add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
-                sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
-                return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
-                positive=get_value_at_index(initial_latents, 0),
-                negative=get_value_at_index(initial_latents, 1),
-                latent_image=get_value_at_index(initial_latents, 2),
-            )
-
-            progress(0.7, desc="Running KSampler (Step 2/2)...")
-            latent_step2 = ksampleradvanced.sample(
-                add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
-                sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
-                return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
-                positive=get_value_at_index(initial_latents, 0),
-                negative=get_value_at_index(initial_latents, 1),
-                latent_image=get_value_at_index(latent_step1, 0),
-            )
-
-            progress(0.8, desc="Decoding VAE...")
-            decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))
-
-            progress(0.9, desc="Creating and saving video...")
-            video_data = createvideo.create_video(fps=16, images=get_value_at_index(decoded_images, 0))
-
-            # Save the video to ComfyUI's output directory
-            save_result = savevideo.save_video(
-                filename_prefix="GradioVideo", format="mp4", codec="h264",
-                video=get_value_at_index(video_data, 0),
-            )
-
-            progress(1.0, desc="Done!")
-            return f"output/{save_result['ui']['images'][0]['filename']}"
-
-    finally:
-        # Clean up the temporary image files
-        os.unlink(start_image_path)
-        os.unlink(end_image_path)
-
-# --- Gradio UI ---
-
-def create_gradio_app():
-    with gr.Blocks(theme=gr.themes.Soft()) as app:
-        gr.Markdown("# Image-to-Video Generation App")
-        gr.Markdown("Upload a start and end frame, provide a prompt, and let the AI generate a video transitioning between them.")
-
-        with gr.Row():
-            start_image = gr.Image(type="pil", label="Start Frame")
-            end_image = gr.Image(type="pil", label="End Frame")
-
-        prompt = gr.Textbox(label="Prompt", value="the guy turns")
-        negative_prompt = gr.Textbox(
-            label="Negative Prompt",
-            value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
-        )
-
-        generate_button = gr.Button("Generate Video", variant="primary")
-
-        output_video = gr.Video(label="Generated Video")
-
-        generate_button.click(
-            fn=generate_video,
-            inputs=[start_image, end_image, prompt, negative_prompt],
-            outputs=output_video
-        )
-
-        gr.Examples(
-            examples=[
-                ["examples/start.png", "examples/end.png", "a beautiful woman smiling"],
-                ["examples/start.png", "examples/end.png", "a robot walking through a futuristic city"],
-            ],
-            inputs=[start_image, end_image, prompt],
-            outputs=output_video,
-            fn=generate_video,
-            cache_examples=False, # Set to True if you want to pre-compute examples
-        )
-
-    return app
-
+    with torch.inference_mode():
+        progress(0.1, desc="Encoding text and images...")
+        # --- Workflow execution ---
+        positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
+        negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))
+
+        start_image_loaded = loadimage.load_image(image=start_image_path)
+        end_image_loaded = loadimage.load_image(image=end_image_path)
+
+        clip_vision_encoded_start = clipvisionencode.encode(
+            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0)
+        )
+        clip_vision_encoded_end = clipvisionencode.encode(
+            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0)
+        )
+
+        progress(0.2, desc="Preparing initial latents...")
+        initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
+            width=480, height=480, length=num_frames, batch_size=1,
+            positive=get_value_at_index(positive_conditioning, 0),
+            negative=get_value_at_index(negative_conditioning, 0),
+            vae=get_value_at_index(vae, 0),
+            clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
+            clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
+            start_image=get_value_at_index(start_image_loaded, 0),
+            end_image=get_value_at_index(end_image_loaded, 0),
+        )
+
+        progress(0.3, desc="Patching models...")
+        model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_noise, 0))
+        model_low_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
+
+        model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_noise, 0))
+        model_high_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
+
+        progress(0.5, desc="Running KSampler (Step 1/2)...")
+        latent_step1 = ksampleradvanced.sample(
+            add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
+            sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
+            return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
+            positive=get_value_at_index(initial_latents, 0),
+            negative=get_value_at_index(initial_latents, 1),
+            latent_image=get_value_at_index(initial_latents, 2),
+        )
+
+        progress(0.7, desc="Running KSampler (Step 2/2)...")
+        latent_step2 = ksampleradvanced.sample(
+            add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
+            sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
+            return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
+            positive=get_value_at_index(initial_latents, 0),
+            negative=get_value_at_index(initial_latents, 1),
+            latent_image=get_value_at_index(latent_step1, 0),
+        )
+
+        progress(0.8, desc="Decoding VAE...")
+        decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))
+
+        progress(0.9, desc="Creating and saving video...")
+        video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))
+
+        # Save the video to ComfyUI's output directory
+        save_result = savevideo.save_video(
+            filename_prefix="GradioVideo", format="mp4", codec="h264",
+            video=get_value_at_index(video_data, 0),
+        )
+
+        progress(1.0, desc="Done!")
+        return f"output/{save_result['ui']['images'][0]['filename']}"
+
+css = '''
+.fillable{max-width: 980px !important}
+.dark .progress-text {color: white}
+'''
+with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
+    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
+    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) on ZeroGPU")
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                start_image = gr.Image(type="pil", label="Start Frame")
+                end_image = gr.Image(type="pil", label="End Frame")
+
+            prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images", value="transition")
+
+            with gr.Accordion("Advanced Settings", open=False):
+                duration = gr.Slider(
+                    minimum=1.0,
+                    maximum=5.0,
+                    value=2.0,
+                    step=0.1,
+                    label="Video Duration (seconds)",
+                    info="Longer videos take longer to generate"
+                )
+                negative_prompt = gr.Textbox(
+                    label="Negative Prompt",
+                    value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
+                    visible=False
+                )
+
+            generate_button = gr.Button("Generate Video", variant="primary")
+
+        with gr.Column():
+            output_video = gr.Video(label="Generated Video")
+
+    generate_button.click(
+        fn=generate_video,
+        inputs=[start_image, end_image, prompt, negative_prompt, duration],
+        outputs=output_video
+    )
+
+    gr.Examples(
+        examples=[
+            ["poli_tower.png", "tower_takes_off.png", "the man turns"],
+            ["capybara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
+        ],
+        inputs=[start_image, end_image, prompt],
+        outputs=output_video,
+        fn=generate_video,
+        cache_examples="lazy",
+    )
 
 if __name__ == "__main__":
     app = create_gradio_app()
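
For reference, a minimal sketch of the duration-to-frames mapping this diff introduces (the commit inlines it in generate_video; duration_to_frames below is a hypothetical helper, assuming the 16 fps constant and the 1.0-5.0 s slider range shown above):

# Hypothetical helper mirroring the new duration handling in generate_video.
FPS = 16  # frames per second used by CreateVideo in the updated workflow

def duration_to_frames(duration_seconds: float, fps: int = FPS) -> int:
    """Convert a slider value in seconds to a frame count, with a floor of 2 frames."""
    return max(2, int(duration_seconds * fps))

# Example: the default 2.0 s duration yields 32 frames; the 5.0 s maximum yields 80.
assert duration_to_frames(2.0) == 32
assert duration_to_frames(5.0) == 80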