artificialguybr commited on
Commit
d7950f8
·
verified ·
1 Parent(s): 2c2f38a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -3
app.py CHANGED
@@ -259,6 +259,12 @@ def initialize_models():
259
  clip_image_encoder=clip_image_encoder,
260
  )
261
  pipeline.to(device=device)
 
 
 
 
 
 
262
  print("✅ Pipeline created and moved to device")
263
 
264
  print("🔄 Loading Wav2Vec models...")
@@ -343,6 +349,15 @@ def generate_video(
343
  audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
344
  audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)
345
 
 
 
 
 
 
 
 
 
 
346
  video_length = int(audio_clip.duration * fps)
347
  video_length = (
348
  int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
@@ -405,7 +420,6 @@ def generate_video(
405
  audio_start_frame = init_frames * 2
406
  audio_end_frame = (init_frames + current_partial_length) * 2
407
 
408
- # Ensure audio embeds are long enough
409
  if audio_embeds.shape[1] < audio_end_frame:
410
  repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
411
  audio_embeds = audio_embeds.repeat(1, repeat_times, 1)
@@ -414,9 +428,9 @@ def generate_video(
414
 
415
  with torch.no_grad():
416
  sample = pipeline(
417
- prompt,
 
418
  num_frames=current_partial_length,
419
- negative_prompt=negative_prompt,
420
  audio_embeds=partial_audio_embeds,
421
  audio_scale=audio_scale,
422
  ip_mask=ip_mask,
 
259
  clip_image_encoder=clip_image_encoder,
260
  )
261
  pipeline.to(device=device)
262
+
263
+ if torch.__version__ >= "2.0":
264
+ print("🚀 Compiling the pipeline with torch.compile()...")
265
+ pipeline.transformer = torch.compile(pipeline.transformer, mode="reduce-overhead", fullgraph=True)
266
+ print("✅ Pipeline transformer compiled!")
267
+
268
  print("✅ Pipeline created and moved to device")
269
 
270
  print("🔄 Loading Wav2Vec models...")
 
349
  audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
350
  audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)
351
 
352
+ progress(0.25, desc="Encoding prompts...")
353
+ prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
354
+ prompt,
355
+ device=device,
356
+ num_images_per_prompt=1,
357
+ do_classifier_free_guidance=(guidance_scale > 1.0),
358
+ negative_prompt=negative_prompt
359
+ )
360
+
361
  video_length = int(audio_clip.duration * fps)
362
  video_length = (
363
  int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
 
420
  audio_start_frame = init_frames * 2
421
  audio_end_frame = (init_frames + current_partial_length) * 2
422
 
 
423
  if audio_embeds.shape[1] < audio_end_frame:
424
  repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
425
  audio_embeds = audio_embeds.repeat(1, repeat_times, 1)
 
428
 
429
  with torch.no_grad():
430
  sample = pipeline(
431
+ prompt_embeds=prompt_embeds,
432
+ negative_prompt_embeds=negative_prompt_embeds,
433
  num_frames=current_partial_length,
 
434
  audio_embeds=partial_audio_embeds,
435
  audio_scale=audio_scale,
436
  ip_mask=ip_mask,