Antoni Bigata committed · Commit 2dff4e4 · Parent(s): cf0da47 · "requirements"
app.py CHANGED
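This commit replaces the eager, import-time construction of the six models (vae_model, hubert_model, wavlm_model, landmarks_extractor, keyframe_model, interpolation_model) with module-level None placeholders that are lazily initialized on the first call to process_video; the embedding and landmark helpers now receive their model as an explicit argument.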
@@ -186,54 +186,17 @@ DEFAULT_AUDIO_PATH = os.path.join(
 # landmarks_extractor,
 # ) = load_all_models()
 
-
-
-vae_model = VaeWrapper("video")
-vae_model = vae_model.half()  # Convert to half precision
-try:
-    vae_model = torch.compile(vae_model)
-    print("Successfully compiled vae_model in FP16")
-except Exception as e:
-    print(f"Warning: Failed to compile vae_model: {e}")
-
-hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
-hubert_model = hubert_model.half()  # Convert to half precision
-try:
-    hubert_model = torch.compile(hubert_model)
-    print("Successfully compiled hubert_model in FP16")
-except Exception as e:
-    print(f"Warning: Failed to compile hubert_model: {e}")
-
-wavlm_model = WavLM_wrapper(
-    model_size="Base+",
-    feed_as_frames=False,
-    merge_type="None",
-    model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
-).cuda()
-
-wavlm_model = wavlm_model.half()  # Convert to half precision
-try:
-    wavlm_model = torch.compile(wavlm_model)
-    print("Successfully compiled wavlm_model in FP16")
-except Exception as e:
-    print(f"Warning: Failed to compile wavlm_model: {e}")
-
-landmarks_extractor = LandmarksExtractor()
-keyframe_model = load_model(
-    config="keyframe.yaml",
-    ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
-)
-interpolation_model = load_model(
-    config="interpolation.yaml",
-    ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
-)
-keyframe_model.en_and_decode_n_samples_a_time = 2
-interpolation_model.en_and_decode_n_samples_a_time = 2
+keyframe_model = None
+interpolation_model = None
+vae_model = None
+hubert_model = None
+wavlm_model = None
+landmarks_extractor = None
 
 
 @spaces.GPU(duration=60)
 @torch.no_grad()
-def compute_video_embedding(video_reader, min_len):
+def compute_video_embedding(video_reader, min_len, vae_model):
     """Compute embeddings from video"""
 
     total_frames = min_len

@@ -283,7 +246,7 @@ def compute_video_embedding(video_reader, min_len):
 
 @spaces.GPU(duration=120)
 @torch.no_grad()
-def compute_hubert_embedding(raw_audio):
+def compute_hubert_embedding(raw_audio, hubert_model):
     """Compute embeddings from audio"""
     print(f"Computing audio embedding from {raw_audio.shape}")
 

@@ -330,7 +293,7 @@ def compute_hubert_embedding(raw_audio):
 
 @spaces.GPU(duration=120)
 @torch.no_grad()
-def compute_wavlm_embedding(raw_audio):
+def compute_wavlm_embedding(raw_audio, wavlm_model):
     """Compute embeddings from audio"""
     audio = rearrange(raw_audio, "(f s) -> f s", s=640)
 

@@ -369,7 +332,7 @@ def compute_wavlm_embedding(raw_audio):
 
 
 @torch.no_grad()
-def extract_video_landmarks(video_frames):
+def extract_video_landmarks(video_frames, landmarks_extractor):
     """Extract landmarks from video frames"""
 
     # Create a progress bar for Gradio

@@ -666,6 +629,57 @@ def process_video(video_input, audio_input, max_num_seconds):
         duration=10,
     )
 
+    if vae_model is None:
+        vae_model = VaeWrapper("video")
+        vae_model = vae_model.half()  # Convert to half precision
+        try:
+            vae_model = torch.compile(vae_model)
+            print("Successfully compiled vae_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile vae_model: {e}")
+
+    if hubert_model is None:
+        hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
+        hubert_model = hubert_model.half()  # Convert to half precision
+        try:
+            hubert_model = torch.compile(hubert_model)
+            print("Successfully compiled hubert_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile hubert_model: {e}")
+
+    if wavlm_model is None:
+        wavlm_model = WavLM_wrapper(
+            model_size="Base+",
+            feed_as_frames=False,
+            merge_type="None",
+            model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
+        ).cuda()
+
+        wavlm_model = wavlm_model.half()  # Convert to half precision
+        try:
+            wavlm_model = torch.compile(wavlm_model)
+            print("Successfully compiled wavlm_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile wavlm_model: {e}")
+
+    if landmarks_extractor is None:
+        landmarks_extractor = LandmarksExtractor()
+
+    if keyframe_model is None:
+        keyframe_model = load_model(
+            config="keyframe.yaml",
+            ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
+        )
+
+    if interpolation_model is None:
+        interpolation_model = load_model(
+            config="interpolation.yaml",
+            ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
+        )
+
+    keyframe_model.en_and_decode_n_samples_a_time = 2
+    interpolation_model.en_and_decode_n_samples_a_time = 2
+
     # Use default media if none provided
     if video_input is None:
         video_input = DEFAULT_VIDEO_PATH

@@ -749,9 +763,9 @@ def process_video(video_input, audio_input, max_num_seconds):
 
     # Compute video embeddings and landmarks - store full version in cache
     video_embedding, video_frames = compute_video_embedding(
-        video_reader, len(video_reader)
+        video_reader, len(video_reader), vae_model
     )
-    video_landmarks = extract_video_landmarks(video_frames)
+    video_landmarks = extract_video_landmarks(video_frames, landmarks_extractor)
 
     # Update video cache with full versions
     cache["video"]["path"] = video_path_hash

@@ -807,8 +821,8 @@ def process_video(video_input, audio_input, max_num_seconds):
     print("Computing audio embeddings")
 
     # Compute audio embeddings with the truncated audio
-    hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
-    wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
+    hubert_embedding = compute_hubert_embedding(raw_audio_reshape, hubert_model)
+    wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape, wavlm_model)
 
     # Update audio cache with full embeddings
     # Note: raw_audio was already cached above
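The net effect of the diff is a lazy "load on first request" pattern: every model handle starts as None at module scope and is built inside process_video the first time it is needed, keeping CUDA work out of import time (on ZeroGPU Spaces, the GPU is only attached inside @spaces.GPU-decorated calls, which is likely the motivation here). Below is a minimal, self-contained sketch of that pattern, with stand-in nn.Linear modules in place of the real VaeWrapper/HuBERT/WavLM loaders; note that a function assigning to module-level names needs a global declaration for `if model is None:` to see the module variable, something the hunk context above does not show.

import torch
import torch.nn as nn

# Module-level handles start empty, as in the commit.
vae_model = None
hubert_model = None


def _to_fp16_compiled(model: nn.Module, name: str) -> nn.Module:
    """Convert to half precision and attempt torch.compile, keeping the
    uncompiled model on failure (same try/except pattern as app.py)."""
    model = model.half()
    try:
        model = torch.compile(model)
        print(f"Successfully compiled {name} in FP16")
    except Exception as e:
        print(f"Warning: Failed to compile {name}: {e}")
    return model


def process_request():
    """First call pays the loading cost; later calls reuse the models."""
    global vae_model, hubert_model  # assigning to module globals requires this
    if vae_model is None:
        # Stand-in module; the real code builds VaeWrapper("video") here.
        vae_model = _to_fp16_compiled(nn.Linear(8, 8), "vae_model")
    if hubert_model is None:
        # Stand-in module; the real code loads facebook/hubert-base-ls960.
        hubert_model = _to_fp16_compiled(nn.Linear(8, 8), "hubert_model")
    # ...embeddings would be computed with the lazily built models here...


process_request()  # builds and (if possible) compiles both models
process_request()  # reuses the cached models

The explicit-argument change complements this: each @spaces.GPU helper (compute_video_embedding, compute_hubert_embedding, compute_wavlm_embedding, extract_video_landmarks) now receives the model it needs as a parameter rather than reading module state.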