Fabrice-TIERCELIN commited on
Commit
ee44bdd
·
verified ·
1 Parent(s): e652791

Use function

Browse files
Files changed (1) hide show
  1. app.py +33 -30
app.py CHANGED
@@ -749,6 +749,38 @@ def worker_start_end(input_image, end_image, image_position, prompts, n_prompt,
749
  def callback(d):
750
  return
751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
  for latent_padding in latent_paddings:
753
  is_last_section = latent_padding == 0
754
  is_first_section = latent_padding == latent_paddings[0]
@@ -815,36 +847,7 @@ def worker_start_end(input_image, end_image, image_position, prompts, n_prompt,
815
  callback=callback,
816
  )
817
 
818
- if is_last_section:
819
- generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
820
-
821
- total_generated_latent_frames += int(generated_latents.shape[2])
822
- history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
823
-
824
- if not high_vram:
825
- offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
826
- load_model_as_complete(vae, target_device=gpu)
827
-
828
- if history_pixels is None:
829
- history_pixels = vae_decode(history_latents[:, :, :total_generated_latent_frames, :, :], vae).cpu()
830
- else:
831
- section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
832
- overlapped_frames = latent_window_size * 4 - 3
833
-
834
- current_pixels = vae_decode(history_latents[:, :, :min(total_generated_latent_frames, section_latent_frames)], vae).cpu()
835
- history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
836
-
837
- if not high_vram:
838
- unload_complete_models(vae)
839
-
840
- if enable_preview or is_last_section:
841
- output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
842
-
843
- save_bcthw_as_mp4(history_pixels, output_filename, fps=fps_number, crf=mp4_crf)
844
-
845
- print(f'Decoded. Pixel shape {history_pixels.shape}')
846
-
847
- stream.output_queue.push(('file', output_filename))
848
 
849
  if is_last_section:
850
  break
 
749
  def callback(d):
750
  return
751
 
752
+ def post_process(job_id, start_latent, generated_latents, total_generated_latent_frames, history_latents, high_vram, transformer, gpu, vae, history_pixels, latent_window_size, enable_preview, outputs_folder, mp4_crf, stream, is_last_section):
753
+ if is_last_section:
754
+ generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
755
+
756
+ total_generated_latent_frames += int(generated_latents.shape[2])
757
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
758
+
759
+ if not high_vram:
760
+ offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
761
+ load_model_as_complete(vae, target_device=gpu)
762
+
763
+ if history_pixels is None:
764
+ history_pixels = vae_decode(history_latents[:, :, :total_generated_latent_frames, :, :], vae).cpu()
765
+ else:
766
+ section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
767
+ overlapped_frames = latent_window_size * 4 - 3
768
+
769
+ current_pixels = vae_decode(history_latents[:, :, :min(total_generated_latent_frames, section_latent_frames)], vae).cpu()
770
+ history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
771
+
772
+ if not high_vram:
773
+ unload_complete_models(vae)
774
+
775
+ if enable_preview or is_last_section:
776
+ output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
777
+
778
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=fps_number, crf=mp4_crf)
779
+
780
+ print(f'Decoded. Pixel shape {history_pixels.shape}')
781
+
782
+ stream.output_queue.push(('file', output_filename))
783
+
784
  for latent_padding in latent_paddings:
785
  is_last_section = latent_padding == 0
786
  is_first_section = latent_padding == latent_paddings[0]
 
847
  callback=callback,
848
  )
849
 
850
+ [total_generated_latent_frames, history_latents, history_pixels] = post_process(job_id, start_latent, generated_latents, total_generated_latent_frames, history_latents, high_vram, transformer, gpu, vae, history_pixels, latent_window_size, enable_preview, outputs_folder, mp4_crf, stream, is_last_section)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851
 
852
  if is_last_section:
853
  break