Antoni Bigata committed on
Commit 2dff4e4 · 1 Parent(s): cf0da47

requirements

Files changed (1)
  1. app.py +65 -51
app.py CHANGED
@@ -186,54 +186,17 @@ DEFAULT_AUDIO_PATH = os.path.join(
 #     landmarks_extractor,
 # ) = load_all_models()
 
-with spaces.GPU(duration=60) as gpu:
-    vae_model = VaeWrapper("video")
-
-    vae_model = vae_model.half()  # Convert to half precision
-    try:
-        vae_model = torch.compile(vae_model)
-        print("Successfully compiled vae_model in FP16")
-    except Exception as e:
-        print(f"Warning: Failed to compile vae_model: {e}")
-
-    hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
-    hubert_model = hubert_model.half()  # Convert to half precision
-    try:
-        hubert_model = torch.compile(hubert_model)
-        print("Successfully compiled hubert_model in FP16")
-    except Exception as e:
-        print(f"Warning: Failed to compile hubert_model: {e}")
-
-    wavlm_model = WavLM_wrapper(
-        model_size="Base+",
-        feed_as_frames=False,
-        merge_type="None",
-        model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
-    ).cuda()
-
-    wavlm_model = wavlm_model.half()  # Convert to half precision
-    try:
-        wavlm_model = torch.compile(wavlm_model)
-        print("Successfully compiled wavlm_model in FP16")
-    except Exception as e:
-        print(f"Warning: Failed to compile wavlm_model: {e}")
-
-    landmarks_extractor = LandmarksExtractor()
-    keyframe_model = load_model(
-        config="keyframe.yaml",
-        ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
-    )
-    interpolation_model = load_model(
-        config="interpolation.yaml",
-        ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
-    )
-    keyframe_model.en_and_decode_n_samples_a_time = 2
-    interpolation_model.en_and_decode_n_samples_a_time = 2
+keyframe_model = None
+interpolation_model = None
+vae_model = None
+hubert_model = None
+wavlm_model = None
+landmarks_extractor = None
 
 
 @spaces.GPU(duration=60)
 @torch.no_grad()
-def compute_video_embedding(video_reader, min_len):
+def compute_video_embedding(video_reader, min_len, vae_model):
     """Compute embeddings from video"""
 
     total_frames = min_len
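The eager-loading block removed above becomes module-level None placeholders; construction moves into process_video (last hunk below), so the models are built lazily on first request instead of at import time. A minimal sketch of that pattern, using a stand-in torch.nn.Linear model rather than the repo's wrappers (note that rebinding a module-level name from inside a function needs a global declaration, which sits outside the lines this diff shows):

import torch

_model = None  # module-level placeholder, filled on first use


def get_model():
    """Build the model once in FP16, best-effort compile it, then reuse it."""
    global _model  # required to rebind the module-level placeholder
    if _model is None:
        model = torch.nn.Linear(8, 8).cuda().half()  # stand-in for a real checkpoint
        try:
            model = torch.compile(model)  # optional speedup; keep eager on failure
            print("Successfully compiled model in FP16")
        except Exception as e:
            print(f"Warning: Failed to compile model: {e}")
        _model = model
    return _model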
@@ -283,7 +246,7 @@ def compute_video_embedding(video_reader, min_len):
 
 @spaces.GPU(duration=120)
 @torch.no_grad()
-def compute_hubert_embedding(raw_audio):
+def compute_hubert_embedding(raw_audio, hubert_model):
     """Compute embeddings from audio"""
     print(f"Computing audio embedding from {raw_audio.shape}")
 
@@ -330,7 +293,7 @@ def compute_hubert_embedding(raw_audio):
 
 @spaces.GPU(duration=120)
 @torch.no_grad()
-def compute_wavlm_embedding(raw_audio):
+def compute_wavlm_embedding(raw_audio, wavlm_model):
     """Compute embeddings from audio"""
     audio = rearrange(raw_audio, "(f s) -> f s", s=640)
 
@@ -369,7 +332,7 @@ def compute_wavlm_embedding(raw_audio):
 
 
 @torch.no_grad()
-def extract_video_landmarks(video_frames):
+def extract_video_landmarks(video_frames, landmarks_extractor):
     """Extract landmarks from video frames"""
 
     # Create a progress bar for Gradio
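The three hunks above make the same mechanical change: each embedding function now receives its model as an explicit parameter instead of reading a module-level global. A hedged sketch of the resulting shape, with placeholder names (compute_embedding is not a function from this repo):

import torch
import spaces  # Hugging Face ZeroGPU decorator, as used throughout app.py


@spaces.GPU(duration=120)
@torch.no_grad()
def compute_embedding(raw_audio, model):
    # The model arrives fully constructed; the decorated call only moves data
    # to the GPU for its duration and runs inference.
    return model(raw_audio.cuda().half()).float().cpu()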
@@ -666,6 +629,57 @@ def process_video(video_input, audio_input, max_num_seconds):
             duration=10,
         )
 
+    if vae_model is None:
+        vae_model = VaeWrapper("video")
+        vae_model = vae_model.half()  # Convert to half precision
+        try:
+            vae_model = torch.compile(vae_model)
+            print("Successfully compiled vae_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile vae_model: {e}")
+
+    if hubert_model is None:
+        hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
+        hubert_model = hubert_model.half()  # Convert to half precision
+        try:
+            hubert_model = torch.compile(hubert_model)
+            print("Successfully compiled hubert_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile hubert_model: {e}")
+
+    if wavlm_model is None:
+        wavlm_model = WavLM_wrapper(
+            model_size="Base+",
+            feed_as_frames=False,
+            merge_type="None",
+            model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
+        ).cuda()
+
+        wavlm_model = wavlm_model.half()  # Convert to half precision
+        try:
+            wavlm_model = torch.compile(wavlm_model)
+            print("Successfully compiled wavlm_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile wavlm_model: {e}")
+
+    if landmarks_extractor is None:
+        landmarks_extractor = LandmarksExtractor()
+
+    if keyframe_model is None:
+        keyframe_model = load_model(
+            config="keyframe.yaml",
+            ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
+        )
+
+    if interpolation_model is None:
+        interpolation_model = load_model(
+            config="interpolation.yaml",
+            ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
+        )
+
+    keyframe_model.en_and_decode_n_samples_a_time = 2
+    interpolation_model.en_and_decode_n_samples_a_time = 2
+
     # Use default media if none provided
     if video_input is None:
         video_input = DEFAULT_VIDEO_PATH
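The half()-and-compile stanzas added in this hunk are identical up to the model name and could be folded into one helper; a sketch only, not code from this commit:

import torch


def to_fp16_compiled(model, name):
    """Convert a model to half precision and best-effort compile it."""
    model = model.half()
    try:
        model = torch.compile(model)
        print(f"Successfully compiled {name} in FP16")
    except Exception as e:
        print(f"Warning: Failed to compile {name}: {e}")
    return model

Usage would then read, for example, vae_model = to_fp16_compiled(VaeWrapper("video"), "vae_model").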
@@ -749,9 +763,9 @@ def process_video(video_input, audio_input, max_num_seconds):
 
     # Compute video embeddings and landmarks - store full version in cache
     video_embedding, video_frames = compute_video_embedding(
-        video_reader, len(video_reader)
+        video_reader, len(video_reader), vae_model
     )
-    video_landmarks = extract_video_landmarks(video_frames)
+    video_landmarks = extract_video_landmarks(video_frames, landmarks_extractor)
 
     # Update video cache with full versions
     cache["video"]["path"] = video_path_hash
@@ -807,8 +821,8 @@ def process_video(video_input, audio_input, max_num_seconds):
     print("Computing audio embeddings")
 
     # Compute audio embeddings with the truncated audio
-    hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
-    wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
+    hubert_embedding = compute_hubert_embedding(raw_audio_reshape, hubert_model)
+    wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape, wavlm_model)
 
     # Update audio cache with full embeddings
     # Note: raw_audio was already cached above
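Both embedding hunks feed a cache keyed on a hash of the input (cache["video"]["path"] = video_path_hash in the hunk above), so repeated runs on the same media skip the expensive passes. A minimal sketch of that caching idea, with hypothetical helper names:

import hashlib


def media_hash(path):
    """Content hash used as a cache key for an uploaded file."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


cache = {"video": {"path": None, "embedding": None}}


def cached_embedding(path, compute):
    """Recompute the expensive pass only when the underlying file changes."""
    h = media_hash(path)
    if cache["video"]["path"] != h:
        cache["video"]["embedding"] = compute(path)
        cache["video"]["path"] = h
    return cache["video"]["embedding"]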
 