Update app.py
app.py
CHANGED
@@ -259,6 +259,12 @@ def initialize_models():
         clip_image_encoder=clip_image_encoder,
     )
     pipeline.to(device=device)
+
+    if torch.__version__ >= "2.0":
+        print("Compiling the pipeline with torch.compile()...")
+        pipeline.transformer = torch.compile(pipeline.transformer, mode="reduce-overhead", fullgraph=True)
+        print("Pipeline transformer compiled!")
+
     print("Pipeline created and moved to device")
 
     print("Loading Wav2Vec models...")
@@ -343,6 +349,15 @@ def generate_video(
     audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
     audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)
 
+    progress(0.25, desc="Encoding prompts...")
+    prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
+        prompt,
+        device=device,
+        num_images_per_prompt=1,
+        do_classifier_free_guidance=(guidance_scale > 1.0),
+        negative_prompt=negative_prompt
+    )
+
     video_length = int(audio_clip.duration * fps)
     video_length = (
         int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
@@ -405,7 +420,6 @@ def generate_video(
         audio_start_frame = init_frames * 2
         audio_end_frame = (init_frames + current_partial_length) * 2
 
-        # Ensure audio embeds are long enough
         if audio_embeds.shape[1] < audio_end_frame:
             repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
             audio_embeds = audio_embeds.repeat(1, repeat_times, 1)
@@ -414,9 +428,9 @@ def generate_video(
 
         with torch.no_grad():
             sample = pipeline(
-
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
                 num_frames=current_partial_length,
-                negative_prompt=negative_prompt,
                 audio_embeds=partial_audio_embeds,
                 audio_scale=audio_scale,
                 ip_mask=ip_mask,
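
A side note on the new compilation gate in initialize_models(): `torch.__version__ >= "2.0"` compares version strings lexicographically. That happens to work for current 1.x/2.x releases (and for build suffixes like "2.1.0+cu118"), but it would misorder a hypothetical "10.0" release. A minimal sketch of a more defensive gate, assuming the packaging package is available and using a hypothetical maybe_compile helper, could look like this:

```python
# Hedged sketch only: maybe_compile is a hypothetical helper, not part of app.py.
import torch
from packaging import version  # assumed to be available in the environment

def maybe_compile(pipeline):
    # Parse the version instead of comparing raw strings, so "10.0" or
    # suffixed builds such as "2.1.0+cu118" are ordered correctly.
    if version.parse(torch.__version__).release >= (2, 0):
        pipeline.transformer = torch.compile(
            pipeline.transformer, mode="reduce-overhead", fullgraph=True
        )
    return pipeline
```

Also worth noting: with fullgraph=True, torch.compile raises on any graph break, and because compilation is lazy the error would surface at the first generation call rather than at startup; the default fullgraph=False is the softer choice if a silent fallback is preferred.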
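The generate_video() changes implement an encode-once pattern: the prompt and negative prompt are embedded a single time before the per-segment loop, and the cached prompt_embeds / negative_prompt_embeds are passed to each pipeline call instead of the raw strings. The sketch below isolates that pattern; run_segments, the segments container, and the reduced keyword set are illustrative stand-ins, and only the encode_prompt call and the two embedding kwargs are taken from the diff.

```python
# Illustrative sketch of the encode-once pattern; run_segments and `segments`
# are hypothetical, not code from app.py.
import torch

def run_segments(pipeline, prompt, negative_prompt, segments, device, guidance_scale):
    # Encode the text prompt a single time, outside the per-segment loop.
    prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
        prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=(guidance_scale > 1.0),
        negative_prompt=negative_prompt,
    )
    outputs = []
    for seg in segments:
        with torch.no_grad():
            # Reuse the cached embeddings so the text encoder is not re-run
            # for every segment of the video.
            outputs.append(
                pipeline(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    num_frames=seg["num_frames"],
                    audio_embeds=seg["audio_embeds"],
                )
            )
    return outputs
```

Besides avoiding a repeated text-encoder forward pass per segment, this also moves the prompt/negative_prompt string handling out of the inner loop entirely.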
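For context on the unchanged video_length lines in the second hunk: the raw frame count is rounded down to the nearest value of the form k * temporal_compression_ratio + 1, presumably because the temporally compressed latent decodes back to exactly that many frames. A small worked example, with an assumed ratio of 4 and an assumed 3.2-second clip at 25 fps:

```python
# Worked example of the frame-count rounding in generate_video(); the ratio of 4
# and the clip length are assumed values, not read from the actual VAE config.
temporal_compression_ratio = 4
fps = 25
audio_duration = 3.2  # seconds (hypothetical)

video_length = int(audio_duration * fps)  # 80 raw frames
video_length = (video_length - 1) // temporal_compression_ratio * temporal_compression_ratio + 1
# (80 - 1) // 4 * 4 + 1 == 77, the largest count of the form 4*k + 1 not exceeding 80
```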