Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gradio/certificate.pem +31 -0
- CogVideoX/__pycache__/pipeline_rgba.cpython-310.pyc +0 -0
- CogVideoX/__pycache__/rgba_utils.cpython-310.pyc +0 -0
- CogVideoX/cli.py +90 -0
- CogVideoX/pipeline_rgba.py +744 -0
- CogVideoX/rgba_utils.py +313 -0
- LICENSE.md +15 -0
- Mochi/README.md +88 -0
- Mochi/args.py +268 -0
- Mochi/cli.py +79 -0
- Mochi/dataset_simple.py +50 -0
- Mochi/embed.py +111 -0
- Mochi/pipeline_mochi_rgba.py +782 -0
- Mochi/prepare_dataset.sh +15 -0
- Mochi/rgba_utils.py +354 -0
- Mochi/train.py +593 -0
- Mochi/train.sh +31 -0
- Mochi/trim_and_crop_videos.py +126 -0
- Mochi/utils.py +260 -0
- README.md +161 -12
- app.py +306 -0
- model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors +3 -0
- requirements.txt +14 -0
- wechat_group.jpg +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
wechat_group.jpg filter=lfs diff=lfs merge=lfs -text
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
CogVideoX/__pycache__/pipeline_rgba.cpython-310.pyc
ADDED
Binary file (25.8 kB). View file
|
|
CogVideoX/__pycache__/rgba_utils.cpython-310.pyc
ADDED
Binary file (9.28 kB). View file
|
|
CogVideoX/cli.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from diffusers import CogVideoXDPMScheduler
|
4 |
+
from pipeline_rgba import CogVideoXPipeline
|
5 |
+
from diffusers.utils import export_to_video
|
6 |
+
import argparse
|
7 |
+
import numpy as np
|
8 |
+
from rgba_utils import *
|
9 |
+
|
10 |
+
def main(args):
|
11 |
+
# 1. load pipeline
|
12 |
+
pipe = CogVideoXPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
|
13 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
14 |
+
pipe.enable_sequential_cpu_offload()
|
15 |
+
pipe.vae.enable_slicing()
|
16 |
+
pipe.vae.enable_tiling()
|
17 |
+
|
18 |
+
|
19 |
+
# 2. define prompt and arguments
|
20 |
+
pipeline_args = {
|
21 |
+
"prompt": args.prompt,
|
22 |
+
"guidance_scale": args.guidance_scale,
|
23 |
+
"num_inference_steps": args.num_inference_steps,
|
24 |
+
"height": args.height,
|
25 |
+
"width": args.width,
|
26 |
+
"num_frames": args.num_frames,
|
27 |
+
"output_type": "latent",
|
28 |
+
"use_dynamic_cfg":True,
|
29 |
+
}
|
30 |
+
|
31 |
+
# 3. prepare rgbx utils
|
32 |
+
# breakpoint()
|
33 |
+
seq_length = 2 * (
|
34 |
+
(args.height // pipe.vae_scale_factor_spatial // 2)
|
35 |
+
* (args.width // pipe.vae_scale_factor_spatial // 2)
|
36 |
+
* ((args.num_frames - 1) // pipe.vae_scale_factor_temporal + 1)
|
37 |
+
)
|
38 |
+
# seq_length = 35100
|
39 |
+
|
40 |
+
prepare_for_rgba_inference(
|
41 |
+
pipe.transformer,
|
42 |
+
rgba_weights_path=args.lora_path,
|
43 |
+
device="cuda",
|
44 |
+
dtype=torch.bfloat16,
|
45 |
+
text_length=226,
|
46 |
+
seq_length=seq_length, # this is for the creation of attention mask.
|
47 |
+
)
|
48 |
+
|
49 |
+
# 4. run inference
|
50 |
+
generator = torch.manual_seed(args.seed) if args.seed else None
|
51 |
+
frames_latents = pipe(**pipeline_args, generator=generator).frames
|
52 |
+
|
53 |
+
frames_latents_rgb, frames_latents_alpha = frames_latents.chunk(2, dim=1)
|
54 |
+
|
55 |
+
frames_rgb = decode_latents(pipe, frames_latents_rgb)
|
56 |
+
frames_alpha = decode_latents(pipe, frames_latents_alpha)
|
57 |
+
|
58 |
+
|
59 |
+
pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
|
60 |
+
frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
|
61 |
+
premultiplied_rgb = frames_rgb * frames_alpha_pooled
|
62 |
+
|
63 |
+
if os.path.exists(args.output_path) == False:
|
64 |
+
os.makedirs(args.output_path)
|
65 |
+
|
66 |
+
export_to_video(premultiplied_rgb[0], os.path.join(args.output_path, "rgb.mp4"), fps=args.fps)
|
67 |
+
export_to_video(frames_alpha_pooled[0], os.path.join(args.output_path, "alpha.mp4"), fps=args.fps)
|
68 |
+
|
69 |
+
|
70 |
+
if __name__ == "__main__":
|
71 |
+
parser = argparse.ArgumentParser(description="Generate a video from a text prompt")
|
72 |
+
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
73 |
+
parser.add_argument("--lora_path", type=str, default="/hpc2hdd/home/lwang592/projects/CogVideo/sat/outputs/training/ckpts-5b-attn_rebias-partial_lora-8gpu-wo_t2a/lora-rgba-12-21-19-11/5000/rgba_lora.safetensors", help="The path of the LoRA weights to be used")
|
74 |
+
|
75 |
+
parser.add_argument(
|
76 |
+
"--model_path", type=str, default="THUDM/CogVideoX-5B", help="Path of the pre-trained model use"
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
parser.add_argument("--output_path", type=str, default="./output", help="The path save generated video")
|
81 |
+
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
82 |
+
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
|
83 |
+
parser.add_argument("--num_frames", type=int, default=49, help="Number of steps for the inference process")
|
84 |
+
parser.add_argument("--width", type=int, default=720, help="Number of steps for the inference process")
|
85 |
+
parser.add_argument("--height", type=int, default=480, help="Number of steps for the inference process")
|
86 |
+
parser.add_argument("--fps", type=int, default=8, help="Number of steps for the inference process")
|
87 |
+
parser.add_argument("--seed", type=int, default=None, help="The seed for reproducibility")
|
88 |
+
args = parser.parse_args()
|
89 |
+
|
90 |
+
main(args)
|
CogVideoX/pipeline_rgba.py
ADDED
@@ -0,0 +1,744 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import inspect
|
17 |
+
import math
|
18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
22 |
+
|
23 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
24 |
+
from diffusers.loaders import CogVideoXLoraLoaderMixin
|
25 |
+
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
26 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
27 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
28 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
29 |
+
from diffusers.utils import logging, replace_example_docstring
|
30 |
+
from diffusers.utils.torch_utils import randn_tensor
|
31 |
+
from diffusers.video_processor import VideoProcessor
|
32 |
+
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
36 |
+
|
37 |
+
|
38 |
+
EXAMPLE_DOC_STRING = """
|
39 |
+
Examples:
|
40 |
+
```python
|
41 |
+
>>> import torch
|
42 |
+
>>> from diffusers import CogVideoXPipeline
|
43 |
+
>>> from diffusers.utils import export_to_video
|
44 |
+
|
45 |
+
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
46 |
+
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
|
47 |
+
>>> prompt = (
|
48 |
+
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
49 |
+
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
50 |
+
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
51 |
+
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
52 |
+
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
53 |
+
... "atmosphere of this unique musical performance."
|
54 |
+
... )
|
55 |
+
>>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
56 |
+
>>> export_to_video(video, "output.mp4", fps=8)
|
57 |
+
```
|
58 |
+
"""
|
59 |
+
|
60 |
+
|
61 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
62 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
63 |
+
tw = tgt_width
|
64 |
+
th = tgt_height
|
65 |
+
h, w = src
|
66 |
+
r = h / w
|
67 |
+
if r > (th / tw):
|
68 |
+
resize_height = th
|
69 |
+
resize_width = int(round(th / h * w))
|
70 |
+
else:
|
71 |
+
resize_width = tw
|
72 |
+
resize_height = int(round(tw / w * h))
|
73 |
+
|
74 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
75 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
76 |
+
|
77 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
78 |
+
|
79 |
+
|
80 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
81 |
+
def retrieve_timesteps(
|
82 |
+
scheduler,
|
83 |
+
num_inference_steps: Optional[int] = None,
|
84 |
+
device: Optional[Union[str, torch.device]] = None,
|
85 |
+
timesteps: Optional[List[int]] = None,
|
86 |
+
sigmas: Optional[List[float]] = None,
|
87 |
+
**kwargs,
|
88 |
+
):
|
89 |
+
r"""
|
90 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
91 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
scheduler (`SchedulerMixin`):
|
95 |
+
The scheduler to get timesteps from.
|
96 |
+
num_inference_steps (`int`):
|
97 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
98 |
+
must be `None`.
|
99 |
+
device (`str` or `torch.device`, *optional*):
|
100 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
101 |
+
timesteps (`List[int]`, *optional*):
|
102 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
103 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
104 |
+
sigmas (`List[float]`, *optional*):
|
105 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
106 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
110 |
+
second element is the number of inference steps.
|
111 |
+
"""
|
112 |
+
if timesteps is not None and sigmas is not None:
|
113 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
114 |
+
if timesteps is not None:
|
115 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
116 |
+
if not accepts_timesteps:
|
117 |
+
raise ValueError(
|
118 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
119 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
120 |
+
)
|
121 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
122 |
+
timesteps = scheduler.timesteps
|
123 |
+
num_inference_steps = len(timesteps)
|
124 |
+
elif sigmas is not None:
|
125 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
126 |
+
if not accept_sigmas:
|
127 |
+
raise ValueError(
|
128 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
129 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
130 |
+
)
|
131 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
132 |
+
timesteps = scheduler.timesteps
|
133 |
+
num_inference_steps = len(timesteps)
|
134 |
+
else:
|
135 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
136 |
+
timesteps = scheduler.timesteps
|
137 |
+
return timesteps, num_inference_steps
|
138 |
+
|
139 |
+
|
140 |
+
class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
141 |
+
r"""
|
142 |
+
Pipeline for text-to-video generation using CogVideoX.
|
143 |
+
|
144 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
145 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
146 |
+
|
147 |
+
Args:
|
148 |
+
vae ([`AutoencoderKL`]):
|
149 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
150 |
+
text_encoder ([`T5EncoderModel`]):
|
151 |
+
Frozen text-encoder. CogVideoX uses
|
152 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
153 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
154 |
+
tokenizer (`T5Tokenizer`):
|
155 |
+
Tokenizer of class
|
156 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
157 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
158 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
159 |
+
scheduler ([`SchedulerMixin`]):
|
160 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
161 |
+
"""
|
162 |
+
|
163 |
+
_optional_components = []
|
164 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
165 |
+
|
166 |
+
_callback_tensor_inputs = [
|
167 |
+
"latents",
|
168 |
+
"prompt_embeds",
|
169 |
+
"negative_prompt_embeds",
|
170 |
+
]
|
171 |
+
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
tokenizer: T5Tokenizer,
|
175 |
+
text_encoder: T5EncoderModel,
|
176 |
+
vae: AutoencoderKLCogVideoX,
|
177 |
+
transformer: CogVideoXTransformer3DModel,
|
178 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
179 |
+
):
|
180 |
+
super().__init__()
|
181 |
+
|
182 |
+
self.register_modules(
|
183 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
184 |
+
)
|
185 |
+
self.vae_scale_factor_spatial = (
|
186 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
187 |
+
)
|
188 |
+
self.vae_scale_factor_temporal = (
|
189 |
+
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
190 |
+
)
|
191 |
+
self.vae_scaling_factor_image = (
|
192 |
+
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
|
193 |
+
)
|
194 |
+
|
195 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
196 |
+
|
197 |
+
def _get_t5_prompt_embeds(
|
198 |
+
self,
|
199 |
+
prompt: Union[str, List[str]] = None,
|
200 |
+
num_videos_per_prompt: int = 1,
|
201 |
+
max_sequence_length: int = 226,
|
202 |
+
device: Optional[torch.device] = None,
|
203 |
+
dtype: Optional[torch.dtype] = None,
|
204 |
+
):
|
205 |
+
device = device or self._execution_device
|
206 |
+
dtype = dtype or self.text_encoder.dtype
|
207 |
+
|
208 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
209 |
+
batch_size = len(prompt)
|
210 |
+
|
211 |
+
text_inputs = self.tokenizer(
|
212 |
+
prompt,
|
213 |
+
padding="max_length",
|
214 |
+
max_length=max_sequence_length,
|
215 |
+
truncation=True,
|
216 |
+
add_special_tokens=True,
|
217 |
+
return_tensors="pt",
|
218 |
+
)
|
219 |
+
text_input_ids = text_inputs.input_ids
|
220 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
221 |
+
|
222 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
223 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
224 |
+
logger.warning(
|
225 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
226 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
227 |
+
)
|
228 |
+
|
229 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
230 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
231 |
+
|
232 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
233 |
+
_, seq_len, _ = prompt_embeds.shape
|
234 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
235 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
236 |
+
|
237 |
+
return prompt_embeds
|
238 |
+
|
239 |
+
def encode_prompt(
|
240 |
+
self,
|
241 |
+
prompt: Union[str, List[str]],
|
242 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
243 |
+
do_classifier_free_guidance: bool = True,
|
244 |
+
num_videos_per_prompt: int = 1,
|
245 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
246 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
247 |
+
max_sequence_length: int = 226,
|
248 |
+
device: Optional[torch.device] = None,
|
249 |
+
dtype: Optional[torch.dtype] = None,
|
250 |
+
):
|
251 |
+
r"""
|
252 |
+
Encodes the prompt into text encoder hidden states.
|
253 |
+
|
254 |
+
Args:
|
255 |
+
prompt (`str` or `List[str]`, *optional*):
|
256 |
+
prompt to be encoded
|
257 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
258 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
259 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
260 |
+
less than `1`).
|
261 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
262 |
+
Whether to use classifier free guidance or not.
|
263 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
264 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
265 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
266 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
267 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
268 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
269 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
270 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
271 |
+
argument.
|
272 |
+
device: (`torch.device`, *optional*):
|
273 |
+
torch device
|
274 |
+
dtype: (`torch.dtype`, *optional*):
|
275 |
+
torch dtype
|
276 |
+
"""
|
277 |
+
device = device or self._execution_device
|
278 |
+
|
279 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
280 |
+
if prompt is not None:
|
281 |
+
batch_size = len(prompt)
|
282 |
+
else:
|
283 |
+
batch_size = prompt_embeds.shape[0]
|
284 |
+
|
285 |
+
if prompt_embeds is None:
|
286 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
287 |
+
prompt=prompt,
|
288 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
289 |
+
max_sequence_length=max_sequence_length,
|
290 |
+
device=device,
|
291 |
+
dtype=dtype,
|
292 |
+
)
|
293 |
+
|
294 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
295 |
+
negative_prompt = negative_prompt or ""
|
296 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
297 |
+
|
298 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
299 |
+
raise TypeError(
|
300 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
301 |
+
f" {type(prompt)}."
|
302 |
+
)
|
303 |
+
elif batch_size != len(negative_prompt):
|
304 |
+
raise ValueError(
|
305 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
306 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
307 |
+
" the batch size of `prompt`."
|
308 |
+
)
|
309 |
+
|
310 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
311 |
+
prompt=negative_prompt,
|
312 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
313 |
+
max_sequence_length=max_sequence_length,
|
314 |
+
device=device,
|
315 |
+
dtype=dtype,
|
316 |
+
)
|
317 |
+
|
318 |
+
return prompt_embeds, negative_prompt_embeds
|
319 |
+
|
320 |
+
def prepare_latents(
|
321 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
322 |
+
):
|
323 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
324 |
+
raise ValueError(
|
325 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
326 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
327 |
+
)
|
328 |
+
|
329 |
+
shape = (
|
330 |
+
batch_size,
|
331 |
+
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
332 |
+
num_channels_latents,
|
333 |
+
height // self.vae_scale_factor_spatial,
|
334 |
+
width // self.vae_scale_factor_spatial,
|
335 |
+
)
|
336 |
+
|
337 |
+
if latents is None:
|
338 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
339 |
+
else:
|
340 |
+
latents = latents.to(device)
|
341 |
+
|
342 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
343 |
+
latents = latents * self.scheduler.init_noise_sigma
|
344 |
+
return latents
|
345 |
+
|
346 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
347 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
348 |
+
latents = 1 / self.vae_scaling_factor_image * latents
|
349 |
+
|
350 |
+
frames = self.vae.decode(latents).sample
|
351 |
+
return frames
|
352 |
+
|
353 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
354 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
355 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
356 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
357 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
358 |
+
# and should be between [0, 1]
|
359 |
+
|
360 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
361 |
+
extra_step_kwargs = {}
|
362 |
+
if accepts_eta:
|
363 |
+
extra_step_kwargs["eta"] = eta
|
364 |
+
|
365 |
+
# check if the scheduler accepts generator
|
366 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
367 |
+
if accepts_generator:
|
368 |
+
extra_step_kwargs["generator"] = generator
|
369 |
+
return extra_step_kwargs
|
370 |
+
|
371 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
372 |
+
def check_inputs(
|
373 |
+
self,
|
374 |
+
prompt,
|
375 |
+
height,
|
376 |
+
width,
|
377 |
+
negative_prompt,
|
378 |
+
callback_on_step_end_tensor_inputs,
|
379 |
+
prompt_embeds=None,
|
380 |
+
negative_prompt_embeds=None,
|
381 |
+
):
|
382 |
+
if height % 8 != 0 or width % 8 != 0:
|
383 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
384 |
+
|
385 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
386 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
387 |
+
):
|
388 |
+
raise ValueError(
|
389 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
390 |
+
)
|
391 |
+
if prompt is not None and prompt_embeds is not None:
|
392 |
+
raise ValueError(
|
393 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
394 |
+
" only forward one of the two."
|
395 |
+
)
|
396 |
+
elif prompt is None and prompt_embeds is None:
|
397 |
+
raise ValueError(
|
398 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
399 |
+
)
|
400 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
401 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
402 |
+
|
403 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
404 |
+
raise ValueError(
|
405 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
406 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
407 |
+
)
|
408 |
+
|
409 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
410 |
+
raise ValueError(
|
411 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
412 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
413 |
+
)
|
414 |
+
|
415 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
416 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
417 |
+
raise ValueError(
|
418 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
419 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
420 |
+
f" {negative_prompt_embeds.shape}."
|
421 |
+
)
|
422 |
+
|
423 |
+
def fuse_qkv_projections(self) -> None:
|
424 |
+
r"""Enables fused QKV projections."""
|
425 |
+
self.fusing_transformer = True
|
426 |
+
self.transformer.fuse_qkv_projections()
|
427 |
+
|
428 |
+
def unfuse_qkv_projections(self) -> None:
|
429 |
+
r"""Disable QKV projection fusion if enabled."""
|
430 |
+
if not self.fusing_transformer:
|
431 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
432 |
+
else:
|
433 |
+
self.transformer.unfuse_qkv_projections()
|
434 |
+
self.fusing_transformer = False
|
435 |
+
|
436 |
+
def _prepare_rotary_positional_embeddings(
|
437 |
+
self,
|
438 |
+
height: int,
|
439 |
+
width: int,
|
440 |
+
num_frames: int,
|
441 |
+
device: torch.device,
|
442 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
443 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
444 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
445 |
+
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
446 |
+
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
447 |
+
|
448 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
449 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
450 |
+
)
|
451 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
452 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
453 |
+
crops_coords=grid_crops_coords,
|
454 |
+
grid_size=(grid_height, grid_width),
|
455 |
+
temporal_size=num_frames,
|
456 |
+
)
|
457 |
+
|
458 |
+
freqs_cos = freqs_cos.to(device=device)
|
459 |
+
freqs_sin = freqs_sin.to(device=device)
|
460 |
+
return freqs_cos, freqs_sin
|
461 |
+
|
462 |
+
@property
|
463 |
+
def guidance_scale(self):
|
464 |
+
return self._guidance_scale
|
465 |
+
|
466 |
+
@property
|
467 |
+
def num_timesteps(self):
|
468 |
+
return self._num_timesteps
|
469 |
+
|
470 |
+
@property
|
471 |
+
def attention_kwargs(self):
|
472 |
+
return self._attention_kwargs
|
473 |
+
|
474 |
+
@property
|
475 |
+
def interrupt(self):
|
476 |
+
return self._interrupt
|
477 |
+
|
478 |
+
@torch.no_grad()
|
479 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
480 |
+
def __call__(
|
481 |
+
self,
|
482 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
483 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
484 |
+
height: int = 480,
|
485 |
+
width: int = 720,
|
486 |
+
num_frames: int = 10, #was 49
|
487 |
+
num_inference_steps: int = 5, #was 50
|
488 |
+
timesteps: Optional[List[int]] = None,
|
489 |
+
guidance_scale: float = 6,
|
490 |
+
use_dynamic_cfg: bool = False,
|
491 |
+
num_videos_per_prompt: int = 1,
|
492 |
+
eta: float = 0.0,
|
493 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
494 |
+
latents: Optional[torch.FloatTensor] = None,
|
495 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
496 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
497 |
+
output_type: str = "pil",
|
498 |
+
return_dict: bool = True,
|
499 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
500 |
+
callback_on_step_end: Optional[
|
501 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
502 |
+
] = None,
|
503 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
504 |
+
max_sequence_length: int = 226,
|
505 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
506 |
+
"""
|
507 |
+
Function invoked when calling the pipeline for generation.
|
508 |
+
|
509 |
+
Args:
|
510 |
+
prompt (`str` or `List[str]`, *optional*):
|
511 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
512 |
+
instead.
|
513 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
514 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
515 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
516 |
+
less than `1`).
|
517 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
518 |
+
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
519 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
520 |
+
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
521 |
+
num_frames (`int`, defaults to `48`):
|
522 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
523 |
+
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
524 |
+
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
|
525 |
+
needs to be satisfied is that of divisibility mentioned above.
|
526 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
527 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
528 |
+
expense of slower inference.
|
529 |
+
timesteps (`List[int]`, *optional*):
|
530 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
531 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
532 |
+
passed will be used. Must be in descending order.
|
533 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
534 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
535 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
536 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
537 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
538 |
+
usually at the expense of lower image quality.
|
539 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
540 |
+
The number of videos to generate per prompt.
|
541 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
542 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
543 |
+
to make generation deterministic.
|
544 |
+
latents (`torch.FloatTensor`, *optional*):
|
545 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
546 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
547 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
548 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
549 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
550 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
551 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
552 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
553 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
554 |
+
argument.
|
555 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
556 |
+
The output format of the generate image. Choose between
|
557 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
558 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
559 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
560 |
+
of a plain tuple.
|
561 |
+
attention_kwargs (`dict`, *optional*):
|
562 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
563 |
+
`self.processor` in
|
564 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
565 |
+
callback_on_step_end (`Callable`, *optional*):
|
566 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
567 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
568 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
569 |
+
`callback_on_step_end_tensor_inputs`.
|
570 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
571 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
572 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
573 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
574 |
+
max_sequence_length (`int`, defaults to `226`):
|
575 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
576 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
577 |
+
|
578 |
+
Examples:
|
579 |
+
|
580 |
+
Returns:
|
581 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
|
582 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
583 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
584 |
+
"""
|
585 |
+
|
586 |
+
if num_frames > 49:
|
587 |
+
raise ValueError(
|
588 |
+
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
589 |
+
)
|
590 |
+
|
591 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
592 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
593 |
+
|
594 |
+
num_videos_per_prompt = 1
|
595 |
+
|
596 |
+
# 1. Check inputs. Raise error if not correct
|
597 |
+
self.check_inputs(
|
598 |
+
prompt,
|
599 |
+
height,
|
600 |
+
width,
|
601 |
+
negative_prompt,
|
602 |
+
callback_on_step_end_tensor_inputs,
|
603 |
+
prompt_embeds,
|
604 |
+
negative_prompt_embeds,
|
605 |
+
)
|
606 |
+
self._guidance_scale = guidance_scale
|
607 |
+
self._attention_kwargs = attention_kwargs
|
608 |
+
self._interrupt = False
|
609 |
+
|
610 |
+
# 2. Default call parameters
|
611 |
+
if prompt is not None and isinstance(prompt, str):
|
612 |
+
batch_size = 1
|
613 |
+
elif prompt is not None and isinstance(prompt, list):
|
614 |
+
batch_size = len(prompt)
|
615 |
+
else:
|
616 |
+
batch_size = prompt_embeds.shape[0]
|
617 |
+
|
618 |
+
device = self._execution_device
|
619 |
+
|
620 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
621 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
622 |
+
# corresponds to doing no classifier free guidance.
|
623 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
624 |
+
|
625 |
+
# 3. Encode input prompt
|
626 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
627 |
+
prompt,
|
628 |
+
negative_prompt,
|
629 |
+
do_classifier_free_guidance,
|
630 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
631 |
+
prompt_embeds=prompt_embeds,
|
632 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
633 |
+
max_sequence_length=max_sequence_length,
|
634 |
+
device=device,
|
635 |
+
)
|
636 |
+
if do_classifier_free_guidance:
|
637 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
638 |
+
|
639 |
+
# 4. Prepare timesteps
|
640 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
641 |
+
self._num_timesteps = len(timesteps)
|
642 |
+
|
643 |
+
# 5. Prepare latents.
|
644 |
+
latent_channels = self.transformer.config.in_channels
|
645 |
+
latents = self.prepare_latents(
|
646 |
+
batch_size * num_videos_per_prompt,
|
647 |
+
latent_channels,
|
648 |
+
num_frames,
|
649 |
+
height,
|
650 |
+
width,
|
651 |
+
prompt_embeds.dtype,
|
652 |
+
device,
|
653 |
+
generator,
|
654 |
+
latents,
|
655 |
+
).repeat(1,2,1,1,1) # Luozhou
|
656 |
+
|
657 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
658 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
659 |
+
|
660 |
+
# 7. Create rotary embeds if required
|
661 |
+
image_rotary_emb = (
|
662 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1) // 2, device) # Luozhou
|
663 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
664 |
+
else None
|
665 |
+
)
|
666 |
+
|
667 |
+
# 8. Denoising loop
|
668 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
669 |
+
|
670 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
671 |
+
# for DPM-solver++
|
672 |
+
old_pred_original_sample = None
|
673 |
+
for i, t in enumerate(timesteps):
|
674 |
+
if self.interrupt:
|
675 |
+
continue
|
676 |
+
|
677 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
678 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
679 |
+
|
680 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
681 |
+
timestep = t.expand(latent_model_input.shape[0])
|
682 |
+
|
683 |
+
# predict noise model_output
|
684 |
+
noise_pred = self.transformer(
|
685 |
+
hidden_states=latent_model_input,
|
686 |
+
encoder_hidden_states=prompt_embeds,
|
687 |
+
timestep=timestep,
|
688 |
+
image_rotary_emb=image_rotary_emb,
|
689 |
+
attention_kwargs=attention_kwargs,
|
690 |
+
return_dict=False,
|
691 |
+
)[0]
|
692 |
+
noise_pred = noise_pred.float()
|
693 |
+
|
694 |
+
# perform guidance
|
695 |
+
if use_dynamic_cfg:
|
696 |
+
self._guidance_scale = 1 + guidance_scale * (
|
697 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
698 |
+
)
|
699 |
+
if do_classifier_free_guidance:
|
700 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
701 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
702 |
+
|
703 |
+
# compute the previous noisy sample x_t -> x_t-1
|
704 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
705 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
706 |
+
else:
|
707 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
708 |
+
noise_pred,
|
709 |
+
old_pred_original_sample,
|
710 |
+
t,
|
711 |
+
timesteps[i - 1] if i > 0 else None,
|
712 |
+
latents,
|
713 |
+
**extra_step_kwargs,
|
714 |
+
return_dict=False,
|
715 |
+
)
|
716 |
+
latents = latents.to(prompt_embeds.dtype)
|
717 |
+
|
718 |
+
# call the callback, if provided
|
719 |
+
if callback_on_step_end is not None:
|
720 |
+
callback_kwargs = {}
|
721 |
+
for k in callback_on_step_end_tensor_inputs:
|
722 |
+
callback_kwargs[k] = locals()[k]
|
723 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
724 |
+
|
725 |
+
latents = callback_outputs.pop("latents", latents)
|
726 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
727 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
728 |
+
|
729 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
730 |
+
progress_bar.update()
|
731 |
+
|
732 |
+
if not output_type == "latent":
|
733 |
+
video = self.decode_latents(latents)
|
734 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
735 |
+
else:
|
736 |
+
video = latents
|
737 |
+
|
738 |
+
# Offload all models
|
739 |
+
self.maybe_free_model_hooks()
|
740 |
+
|
741 |
+
if not return_dict:
|
742 |
+
return (video,)
|
743 |
+
|
744 |
+
return CogVideoXPipelineOutput(frames=video)
|
CogVideoX/rgba_utils.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
5 |
+
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
6 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
7 |
+
from safetensors.torch import load_file
|
8 |
+
|
9 |
+
logger = logging.get_logger(__name__)
|
10 |
+
|
11 |
+
@torch.no_grad()
|
12 |
+
def decode_latents(pipe, latents):
|
13 |
+
video = pipe.decode_latents(latents)
|
14 |
+
video = pipe.video_processor.postprocess_video(video=video, output_type="np")
|
15 |
+
return video
|
16 |
+
|
17 |
+
def create_attention_mask(text_length: int, seq_length: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
18 |
+
"""
|
19 |
+
Create an attention mask to block text from attending to alpha.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
text_length: Length of the text sequence.
|
23 |
+
seq_length: Length of the other sequence.
|
24 |
+
device: The device where the mask will be stored.
|
25 |
+
dtype: The data type of the mask tensor.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
An attention mask tensor.
|
29 |
+
"""
|
30 |
+
total_length = text_length + seq_length
|
31 |
+
dense_mask = torch.ones((total_length, total_length), dtype=torch.bool)
|
32 |
+
dense_mask[:text_length, text_length + seq_length // 2:] = False
|
33 |
+
return dense_mask.to(device=device, dtype=dtype)
|
34 |
+
|
35 |
+
class RGBALoRACogVideoXAttnProcessor:
|
36 |
+
r"""
|
37 |
+
Processor for implementing scaled dot-product attention for the CogVideoX model.
|
38 |
+
It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, device, dtype, attention_mask, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
|
42 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
43 |
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0 or later.")
|
44 |
+
|
45 |
+
# Initialize LoRA layers
|
46 |
+
self.lora_alpha = lora_alpha
|
47 |
+
self.lora_rank = lora_rank
|
48 |
+
|
49 |
+
# Helper function to create LoRA layers
|
50 |
+
def create_lora_layer(in_dim, mid_dim, out_dim):
|
51 |
+
return nn.Sequential(
|
52 |
+
nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype),
|
53 |
+
nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype)
|
54 |
+
)
|
55 |
+
|
56 |
+
self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
|
57 |
+
self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
|
58 |
+
self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
|
59 |
+
self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
|
60 |
+
|
61 |
+
# Store attention mask
|
62 |
+
self.attention_mask = attention_mask
|
63 |
+
|
64 |
+
def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
|
65 |
+
"""Applies LoRA updates to query, key, and value tensors."""
|
66 |
+
query_delta = self.to_q_lora(hidden_states).to(query.device)
|
67 |
+
query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling
|
68 |
+
|
69 |
+
key_delta = self.to_k_lora(hidden_states).to(key.device)
|
70 |
+
key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling
|
71 |
+
|
72 |
+
value_delta = self.to_v_lora(hidden_states).to(value.device)
|
73 |
+
value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling
|
74 |
+
|
75 |
+
return query, key, value
|
76 |
+
|
77 |
+
def _apply_rotary_embedding(self, query, key, image_rotary_emb, seq_len, text_seq_length, attn):
|
78 |
+
"""Applies rotary embeddings to query and key tensors."""
|
79 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
80 |
+
|
81 |
+
# Apply rotary embedding to RGB and alpha sections
|
82 |
+
query[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb(
|
83 |
+
query[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb)
|
84 |
+
query[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb(
|
85 |
+
query[:, :, text_seq_length + seq_len // 2:], image_rotary_emb)
|
86 |
+
|
87 |
+
if not attn.is_cross_attention:
|
88 |
+
key[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb(
|
89 |
+
key[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb)
|
90 |
+
key[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb(
|
91 |
+
key[:, :, text_seq_length + seq_len // 2:], image_rotary_emb)
|
92 |
+
|
93 |
+
return query, key
|
94 |
+
|
95 |
+
def __call__(
|
96 |
+
self,
|
97 |
+
attn,
|
98 |
+
hidden_states: torch.Tensor,
|
99 |
+
encoder_hidden_states: torch.Tensor,
|
100 |
+
attention_mask: Optional[torch.Tensor] = None,
|
101 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
102 |
+
) -> torch.Tensor:
|
103 |
+
# Concatenate encoder and decoder hidden states
|
104 |
+
text_seq_length = encoder_hidden_states.size(1)
|
105 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
106 |
+
|
107 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
108 |
+
seq_len = hidden_states.shape[1] - text_seq_length
|
109 |
+
scaling = self.lora_alpha / self.lora_rank
|
110 |
+
|
111 |
+
# Apply LoRA to query, key, value
|
112 |
+
query = attn.to_q(hidden_states)
|
113 |
+
key = attn.to_k(hidden_states)
|
114 |
+
value = attn.to_v(hidden_states)
|
115 |
+
|
116 |
+
query, key, value = self._apply_lora(hidden_states, seq_len, query, key, value, scaling)
|
117 |
+
|
118 |
+
# Reshape query, key, value for multi-head attention
|
119 |
+
inner_dim = key.shape[-1]
|
120 |
+
head_dim = inner_dim // attn.heads
|
121 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
122 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
123 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
124 |
+
|
125 |
+
# Normalize query and key if required
|
126 |
+
if attn.norm_q is not None:
|
127 |
+
query = attn.norm_q(query)
|
128 |
+
if attn.norm_k is not None:
|
129 |
+
key = attn.norm_k(key)
|
130 |
+
|
131 |
+
# Apply rotary embeddings if provided
|
132 |
+
if image_rotary_emb is not None:
|
133 |
+
query, key = self._apply_rotary_embedding(query, key, image_rotary_emb, seq_len, text_seq_length, attn)
|
134 |
+
|
135 |
+
# Compute scaled dot-product attention
|
136 |
+
hidden_states = F.scaled_dot_product_attention(
|
137 |
+
query, key, value, attn_mask=self.attention_mask, dropout_p=0.0, is_causal=False
|
138 |
+
)
|
139 |
+
|
140 |
+
# Reshape the output tensor back to the original shape
|
141 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
142 |
+
|
143 |
+
# Apply linear projection and LoRA to the output
|
144 |
+
original_hidden_states = attn.to_out[0](hidden_states)
|
145 |
+
hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device)
|
146 |
+
original_hidden_states[:, -seq_len // 2:, :] += hidden_states_delta[:, -seq_len // 2:, :] * scaling
|
147 |
+
|
148 |
+
# Apply dropout
|
149 |
+
hidden_states = attn.to_out[1](original_hidden_states)
|
150 |
+
|
151 |
+
# Split back into encoder and decoder hidden states
|
152 |
+
encoder_hidden_states, hidden_states = hidden_states.split(
|
153 |
+
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
154 |
+
)
|
155 |
+
|
156 |
+
return hidden_states, encoder_hidden_states
|
157 |
+
|
158 |
+
def prepare_for_rgba_inference(
|
159 |
+
model, rgba_weights_path: str, device: torch.device, dtype: torch.dtype,
|
160 |
+
lora_rank: int = 128, lora_alpha: float = 1.0, text_length: int = 226, seq_length: int = 35100
|
161 |
+
):
|
162 |
+
def load_lora_sequential_weights(lora_layer, lora_layers, prefix):
|
163 |
+
lora_layer[0].load_state_dict({'weight': lora_layers[f"{prefix}.lora_A.weight"]})
|
164 |
+
lora_layer[1].load_state_dict({'weight': lora_layers[f"{prefix}.lora_B.weight"]})
|
165 |
+
|
166 |
+
|
167 |
+
rgba_weights = load_file(rgba_weights_path)
|
168 |
+
aux_emb = rgba_weights['domain_emb']
|
169 |
+
|
170 |
+
attention_mask = create_attention_mask(text_length, seq_length, device, dtype)
|
171 |
+
attn_procs = {}
|
172 |
+
|
173 |
+
for name in model.attn_processors.keys():
|
174 |
+
attn_processor = RGBALoRACogVideoXAttnProcessor(
|
175 |
+
device=device, dtype=dtype, attention_mask=attention_mask,
|
176 |
+
lora_rank=lora_rank, lora_alpha=lora_alpha
|
177 |
+
)
|
178 |
+
|
179 |
+
index = name.split('.')[1]
|
180 |
+
base_prefix = f'transformer.transformer_blocks.{index}.attn1'
|
181 |
+
|
182 |
+
for lora_layer, prefix in [
|
183 |
+
(attn_processor.to_q_lora, f'{base_prefix}.to_q'),
|
184 |
+
(attn_processor.to_k_lora, f'{base_prefix}.to_k'),
|
185 |
+
(attn_processor.to_v_lora, f'{base_prefix}.to_v'),
|
186 |
+
(attn_processor.to_out_lora, f'{base_prefix}.to_out.0'),
|
187 |
+
]:
|
188 |
+
load_lora_sequential_weights(lora_layer, rgba_weights, prefix)
|
189 |
+
|
190 |
+
attn_procs[name] = attn_processor
|
191 |
+
|
192 |
+
model.set_attn_processor(attn_procs)
|
193 |
+
|
194 |
+
def custom_forward(self):
|
195 |
+
def forward(
|
196 |
+
hidden_states: torch.Tensor,
|
197 |
+
encoder_hidden_states: torch.Tensor,
|
198 |
+
timestep: Union[int, float, torch.LongTensor],
|
199 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
200 |
+
ofs: Optional[Union[int, float, torch.LongTensor]] = None,
|
201 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
202 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
203 |
+
return_dict: bool = True,
|
204 |
+
):
|
205 |
+
if attention_kwargs is not None:
|
206 |
+
attention_kwargs = attention_kwargs.copy()
|
207 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
208 |
+
else:
|
209 |
+
lora_scale = 1.0
|
210 |
+
|
211 |
+
if USE_PEFT_BACKEND:
|
212 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
213 |
+
scale_lora_layers(self, lora_scale)
|
214 |
+
else:
|
215 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
216 |
+
logger.warning(
|
217 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
218 |
+
)
|
219 |
+
|
220 |
+
batch_size, num_frames, channels, height, width = hidden_states.shape
|
221 |
+
|
222 |
+
# 1. Time embedding
|
223 |
+
timesteps = timestep
|
224 |
+
t_emb = self.time_proj(timesteps)
|
225 |
+
|
226 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
227 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
228 |
+
# there might be better ways to encapsulate this.
|
229 |
+
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
230 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
231 |
+
|
232 |
+
if self.ofs_embedding is not None:
|
233 |
+
ofs_emb = self.ofs_proj(ofs)
|
234 |
+
ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
|
235 |
+
ofs_emb = self.ofs_embedding(ofs_emb)
|
236 |
+
emb = emb + ofs_emb
|
237 |
+
|
238 |
+
# 2. Patch embedding
|
239 |
+
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
240 |
+
hidden_states = self.embedding_dropout(hidden_states)
|
241 |
+
|
242 |
+
text_seq_length = encoder_hidden_states.shape[1]
|
243 |
+
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
244 |
+
hidden_states = hidden_states[:, text_seq_length:]
|
245 |
+
|
246 |
+
hidden_states[:, hidden_states.size(1) // 2:, :] += aux_emb.expand(batch_size, -1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
|
247 |
+
|
248 |
+
# 3. Transformer blocks
|
249 |
+
for i, block in enumerate(self.transformer_blocks):
|
250 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
251 |
+
|
252 |
+
def create_custom_forward(module):
|
253 |
+
def custom_forward(*inputs):
|
254 |
+
return module(*inputs)
|
255 |
+
|
256 |
+
return custom_forward
|
257 |
+
|
258 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
259 |
+
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
260 |
+
create_custom_forward(block),
|
261 |
+
hidden_states,
|
262 |
+
encoder_hidden_states,
|
263 |
+
emb,
|
264 |
+
image_rotary_emb,
|
265 |
+
**ckpt_kwargs,
|
266 |
+
)
|
267 |
+
else:
|
268 |
+
hidden_states, encoder_hidden_states = block(
|
269 |
+
hidden_states=hidden_states,
|
270 |
+
encoder_hidden_states=encoder_hidden_states,
|
271 |
+
temb=emb,
|
272 |
+
image_rotary_emb=image_rotary_emb,
|
273 |
+
)
|
274 |
+
|
275 |
+
if not self.config.use_rotary_positional_embeddings:
|
276 |
+
# CogVideoX-2B
|
277 |
+
hidden_states = self.norm_final(hidden_states)
|
278 |
+
else:
|
279 |
+
# CogVideoX-5B
|
280 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
281 |
+
hidden_states = self.norm_final(hidden_states)
|
282 |
+
hidden_states = hidden_states[:, text_seq_length:]
|
283 |
+
|
284 |
+
# 4. Final block
|
285 |
+
hidden_states = self.norm_out(hidden_states, temb=emb)
|
286 |
+
hidden_states = self.proj_out(hidden_states)
|
287 |
+
|
288 |
+
# 5. Unpatchify
|
289 |
+
p = self.config.patch_size
|
290 |
+
p_t = self.config.patch_size_t
|
291 |
+
|
292 |
+
if p_t is None:
|
293 |
+
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
294 |
+
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
295 |
+
else:
|
296 |
+
output = hidden_states.reshape(
|
297 |
+
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
|
298 |
+
)
|
299 |
+
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
|
300 |
+
|
301 |
+
if USE_PEFT_BACKEND:
|
302 |
+
# remove `lora_scale` from each PEFT layer
|
303 |
+
unscale_lora_layers(self, lora_scale)
|
304 |
+
|
305 |
+
if not return_dict:
|
306 |
+
return (output,)
|
307 |
+
return Transformer2DModelOutput(sample=output)
|
308 |
+
|
309 |
+
|
310 |
+
return forward
|
311 |
+
|
312 |
+
model.forward = custom_forward(model)
|
313 |
+
|
LICENSE.md
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**ADOBE RESEARCH LICENSE**
|
2 |
+
|
3 |
+
This license agreement (the “License”) between Adobe Inc., having a place of business at 345 Park Avenue, San Jose, California 95110-2704 (“Adobe”), and you, the individual or entity exercising rights under this License (“you” or “your”), sets forth the terms for your use of certain research materials that are owned by Adobe (the “Licensed Materials”). By exercising rights under this License, you accept and agree to be bound by its terms. If you are exercising rights under this License on behalf of an entity, then “you” means you and such entity, and you (personally) represent and warrant that you (personally) have all necessary authority to bind that entity to the terms of this License.
|
4 |
+
|
5 |
+
1. **GRANT OF LICENSE.**<br/>
|
6 |
+
1.1 Adobe grants you a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials for noncommercial research purposes only; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, for noncommercial research purposes only, provided that you give recipients a copy of this License upon redistribution.<br/>
|
7 |
+
1.2 You may add your own copyright statement to your modifications and/or provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only.<br/>
|
8 |
+
1.3 For purposes of this License, noncommercial research purposes include academic research and teaching only. Noncommercial research purposes do not include commercial licensing or distribution, development of commercial products, or any other activity that results in commercial gain.<br/>
|
9 |
+
2. **OWNERSHIP AND ATTRIBUTION.** Adobe and its licensors own all right, title, and interest in the Licensed Materials. You must retain all copyright notices and/or disclaimers in the Licensed Materials.
|
10 |
+
3. **DISCLAIMER OF WARRANTIES.** THE LICENSED MATERIALS ARE PROVIDED “AS IS” WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE USE, RESULTS, AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO YOUR USE OF THE LICENSED MATERIALS, INCLUDING, BUT NOT LIMITED TO, NONINFRINGEMENT OF THIRD-PARTY RIGHTS.
|
11 |
+
4. **LIMITATION OF LIABILITY.** IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE LICENSED MATERIALS, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
12 |
+
5. **TERM AND TERMINATION.**<br/>
|
13 |
+
5.1 The License is effective upon acceptance by you and will remain in effect unless terminated earlier in accordance with Section 5.2.<br/>
|
14 |
+
5.2 Any breach of any material provision of this License will automatically terminate the rights granted herein.<br/>
|
15 |
+
5.3 Sections 2 (Ownership and Attribution), 3 (Disclaimer of Warranties), 4 (Limitation of Liability) will survive termination of this License.<br/>
|
Mochi/README.md
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# RGBA LoRA Training Instructions
|
2 |
+
|
3 |
+
<!-- <table align=center>
|
4 |
+
<tr>
|
5 |
+
<th align=center> Dataset Sample </th>
|
6 |
+
<th align=center> Test Sample </th>
|
7 |
+
</tr>
|
8 |
+
<tr>
|
9 |
+
<td align=center><video src="https://github.com/user-attachments/assets/6f906a32-b169-493f-a713-07679e87cd91"> Your browser does not support the video tag. </video></td>
|
10 |
+
<td align=center><video src="https://github.com/user-attachments/assets/d356e70f-ccf4-47f7-be1d-8d21108d8a84"> Your browser does not support the video tag. </video></td>
|
11 |
+
</tr>
|
12 |
+
</table> -->
|
13 |
+
<!--
|
14 |
+
Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨
|
15 |
+
|
16 |
+
We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for things LoRA in our implementation.
|
17 |
+
|
18 |
+
**Updates**
|
19 |
+
|
20 |
+
December 1 2024: Support for checkpoint saving and loading. -->
|
21 |
+
|
22 |
+
We follow the same steps as the original [finetrainers](https://github.com/a-r-r-o-w/finetrainers/blob/main/training/mochi-1/README.md) to prepare the [RGBA dataset](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets).
|
23 |
+
For RGBA dataset, you can follow the instructions above to preprocess the dataset yourself.
|
24 |
+
|
25 |
+
Here are some detailed steps to prepare the dataset for Mochi-1 fine-tuning:
|
26 |
+
|
27 |
+
1. Download our preprocessed [Video RGBA dataset](https://hkustgz-my.sharepoint.com/:u:/g/personal/lwang592_connect_hkust-gz_edu_cn/EezKQoum3IVJiJ9c8GebNfYBe-xN0OS5mVUvAwyL_rQLuw?e=1obdbA), which has undergone preprocessing operations such as color decontamination and background blur.
|
28 |
+
2. Use `trim_and_crop_videos.py` to crop and trim the RGB and Alpha videos as needed.
|
29 |
+
3. Use `embed.py` to encode the RGB videos into latent representations and embed the video captions into embeddings.
|
30 |
+
4. Use `embed.py` to encode the Alpha videos into latent representations.
|
31 |
+
5. Concatenate the RGB and Alpha latent representations along the frames dimension.
|
32 |
+
|
33 |
+
Finally, the dataset should be in the following format:
|
34 |
+
```
|
35 |
+
<video_1_concatenated>.latent.pt
|
36 |
+
<video_1_captions>.embed.pt
|
37 |
+
<video_2_concatenated>.latent.pt
|
38 |
+
<video_2_captions>.embed.pt
|
39 |
+
```
|
40 |
+
|
41 |
+
|
42 |
+
Now, we're ready to fine-tune. To launch, run:
|
43 |
+
|
44 |
+
```bash
|
45 |
+
bash train.sh
|
46 |
+
```
|
47 |
+
**Note:**
|
48 |
+
|
49 |
+
The arg `--num_frames` is used to specify the number of frames of generated **RGB** video. During generation, we will actually double the number of frames to generate the **RGB** video and **Alpha** video jointly. This double operation is automatically handled by our implementation.
|
50 |
+
|
51 |
+
For an 80GB GPU, we support processing RGB videos with dimensions of 480 × 848 × 79 (Height × Width × Frames) at a batch size of 1 using bfloat16 precision for training. However, the training is relatively slow (over one minute per iteration) because the model processes a total of 79 × 2 frames as input.
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
~~We haven't rigorously tested but without validation enabled, this script should run under 40GBs of GPU VRAM.~~
|
58 |
+
|
59 |
+
## Inference
|
60 |
+
|
61 |
+
To generate the RGBA video, run:
|
62 |
+
|
63 |
+
```bash
|
64 |
+
python cli.py \
|
65 |
+
--lora_path /path/to/lora \
|
66 |
+
--prompt "..." \
|
67 |
+
```
|
68 |
+
|
69 |
+
This command generates the RGB and Alpha videos simultaneously and saves them. Specifically, the RGB video is saved in its premultiplied form. To blend this video with any background image, you can simply use the following formula:
|
70 |
+
|
71 |
+
```python
|
72 |
+
com = rgb + (1 - alpha) * bgr
|
73 |
+
```
|
74 |
+
|
75 |
+
## Known limitations
|
76 |
+
|
77 |
+
(Contributions are welcome 🤗)
|
78 |
+
|
79 |
+
Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below:
|
80 |
+
|
81 |
+
* No support for distributed training.
|
82 |
+
* `train_batch_size > 1` are supported but can potentially lead to OOMs because we currently don't have gradient accumulation support.
|
83 |
+
* No support for 8bit optimizers (but should be relatively easy to add).
|
84 |
+
|
85 |
+
**Misc**:
|
86 |
+
|
87 |
+
* We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033).
|
88 |
+
* `embed.py` script is non-batched.
|
Mochi/args.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Default values taken from
|
3 |
+
https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml
|
4 |
+
when applicable.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
|
10 |
+
def _get_model_args(parser: argparse.ArgumentParser) -> None:
|
11 |
+
parser.add_argument(
|
12 |
+
"--pretrained_model_name_or_path",
|
13 |
+
type=str,
|
14 |
+
default=None,
|
15 |
+
required=True,
|
16 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--revision",
|
20 |
+
type=str,
|
21 |
+
default=None,
|
22 |
+
required=False,
|
23 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--variant",
|
27 |
+
type=str,
|
28 |
+
default=None,
|
29 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--cache_dir",
|
33 |
+
type=str,
|
34 |
+
default=None,
|
35 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
36 |
+
)
|
37 |
+
parser.add_argument(
|
38 |
+
"--cast_dit",
|
39 |
+
action="store_true",
|
40 |
+
help="If we should cast DiT params to a lower precision.",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--compile_dit",
|
44 |
+
action="store_true",
|
45 |
+
help="If we should compile the DiT.",
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
|
50 |
+
parser.add_argument(
|
51 |
+
"--data_root",
|
52 |
+
type=str,
|
53 |
+
default=None,
|
54 |
+
help=("A folder containing the training data."),
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--caption_dropout",
|
58 |
+
type=float,
|
59 |
+
default=None,
|
60 |
+
help=("Probability to drop out captions randomly."),
|
61 |
+
)
|
62 |
+
|
63 |
+
parser.add_argument(
|
64 |
+
"--dataloader_num_workers",
|
65 |
+
type=int,
|
66 |
+
default=0,
|
67 |
+
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--pin_memory",
|
71 |
+
action="store_true",
|
72 |
+
help="Whether or not to use the pinned memory setting in pytorch dataloader.",
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def _get_validation_args(parser: argparse.ArgumentParser) -> None:
|
77 |
+
parser.add_argument(
|
78 |
+
"--validation_prompt",
|
79 |
+
type=str,
|
80 |
+
default=None,
|
81 |
+
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--validation_images",
|
85 |
+
type=str,
|
86 |
+
default=None,
|
87 |
+
help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"--validation_prompt_separator",
|
91 |
+
type=str,
|
92 |
+
default=":::",
|
93 |
+
help="String that separates multiple validation prompts",
|
94 |
+
)
|
95 |
+
parser.add_argument(
|
96 |
+
"--num_validation_videos",
|
97 |
+
type=int,
|
98 |
+
default=1,
|
99 |
+
help="Number of videos that should be generated during validation per `validation_prompt`.",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--validation_epochs",
|
103 |
+
type=int,
|
104 |
+
default=50,
|
105 |
+
help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--enable_slicing",
|
109 |
+
action="store_true",
|
110 |
+
default=False,
|
111 |
+
help="Whether or not to use VAE slicing for saving memory.",
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--enable_tiling",
|
115 |
+
action="store_true",
|
116 |
+
default=False,
|
117 |
+
help="Whether or not to use VAE tiling for saving memory.",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--enable_model_cpu_offload",
|
121 |
+
action="store_true",
|
122 |
+
default=False,
|
123 |
+
help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--fps",
|
127 |
+
type=int,
|
128 |
+
default=30,
|
129 |
+
help="FPS to use when serializing the output videos.",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--height",
|
133 |
+
type=int,
|
134 |
+
default=480,
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--width",
|
138 |
+
type=int,
|
139 |
+
default=848,
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
def _get_training_args(parser: argparse.ArgumentParser) -> None:
|
144 |
+
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
145 |
+
parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.")
|
146 |
+
parser.add_argument(
|
147 |
+
"--lora_alpha",
|
148 |
+
type=int,
|
149 |
+
default=16,
|
150 |
+
help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
|
151 |
+
)
|
152 |
+
parser.add_argument(
|
153 |
+
"--target_modules",
|
154 |
+
nargs="+",
|
155 |
+
type=str,
|
156 |
+
default=["to_k", "to_q", "to_v", "to_out.0"],
|
157 |
+
help="Target modules to train LoRA for.",
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--output_dir",
|
161 |
+
type=str,
|
162 |
+
default="mochi-lora",
|
163 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--train_batch_size",
|
167 |
+
type=int,
|
168 |
+
default=4,
|
169 |
+
help="Batch size (per device) for the training dataloader.",
|
170 |
+
)
|
171 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
172 |
+
parser.add_argument(
|
173 |
+
"--max_train_steps",
|
174 |
+
type=int,
|
175 |
+
default=None,
|
176 |
+
help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
|
177 |
+
)
|
178 |
+
parser.add_argument(
|
179 |
+
"--gradient_checkpointing",
|
180 |
+
action="store_true",
|
181 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
182 |
+
)
|
183 |
+
parser.add_argument(
|
184 |
+
"--learning_rate",
|
185 |
+
type=float,
|
186 |
+
default=2e-4,
|
187 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
188 |
+
)
|
189 |
+
parser.add_argument(
|
190 |
+
"--scale_lr",
|
191 |
+
action="store_true",
|
192 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
193 |
+
)
|
194 |
+
parser.add_argument(
|
195 |
+
"--lr_warmup_steps",
|
196 |
+
type=int,
|
197 |
+
default=200,
|
198 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
199 |
+
)
|
200 |
+
parser.add_argument(
|
201 |
+
"--checkpointing_steps",
|
202 |
+
type=int,
|
203 |
+
default=1000,
|
204 |
+
)
|
205 |
+
parser.add_argument(
|
206 |
+
"--resume_from_checkpoint",
|
207 |
+
type=str,
|
208 |
+
default=None,
|
209 |
+
)
|
210 |
+
|
211 |
+
|
212 |
+
def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
|
213 |
+
parser.add_argument(
|
214 |
+
"--optimizer",
|
215 |
+
type=lambda s: s.lower(),
|
216 |
+
default="adam",
|
217 |
+
choices=["adam", "adamw"],
|
218 |
+
help=("The optimizer type to use."),
|
219 |
+
)
|
220 |
+
parser.add_argument(
|
221 |
+
"--weight_decay",
|
222 |
+
type=float,
|
223 |
+
default=0.01,
|
224 |
+
help="Weight decay to use for optimizer.",
|
225 |
+
)
|
226 |
+
|
227 |
+
|
228 |
+
def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
|
229 |
+
parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
|
230 |
+
parser.add_argument(
|
231 |
+
"--push_to_hub",
|
232 |
+
action="store_true",
|
233 |
+
help="Whether or not to push the model to the Hub.",
|
234 |
+
)
|
235 |
+
parser.add_argument(
|
236 |
+
"--hub_token",
|
237 |
+
type=str,
|
238 |
+
default=None,
|
239 |
+
help="The token to use to push to the Model Hub.",
|
240 |
+
)
|
241 |
+
parser.add_argument(
|
242 |
+
"--hub_model_id",
|
243 |
+
type=str,
|
244 |
+
default=None,
|
245 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
246 |
+
)
|
247 |
+
parser.add_argument(
|
248 |
+
"--allow_tf32",
|
249 |
+
action="store_true",
|
250 |
+
help=(
|
251 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
252 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
253 |
+
),
|
254 |
+
)
|
255 |
+
parser.add_argument("--report_to", type=str, default=None, help="If logging to wandb.")
|
256 |
+
|
257 |
+
|
258 |
+
def get_args():
|
259 |
+
parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.")
|
260 |
+
|
261 |
+
_get_model_args(parser)
|
262 |
+
_get_dataset_args(parser)
|
263 |
+
_get_training_args(parser)
|
264 |
+
_get_validation_args(parser)
|
265 |
+
_get_optimizer_args(parser)
|
266 |
+
_get_configuration_args(parser)
|
267 |
+
|
268 |
+
return parser.parse_args()
|
Mochi/cli.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
# from diffusers import MochiPipeline
|
4 |
+
from pipeline_mochi_rgba import MochiPipeline
|
5 |
+
from diffusers.utils import export_to_video
|
6 |
+
import argparse
|
7 |
+
from rgba_utils import *
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
# 1. load pipeline
|
13 |
+
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16).to("cuda")
|
14 |
+
pipe.enable_vae_tiling()
|
15 |
+
|
16 |
+
# 2. define prompt and arguments
|
17 |
+
pipeline_args = {
|
18 |
+
"prompt": args.prompt,
|
19 |
+
"guidance_scale": args.guidance_scale,
|
20 |
+
"num_inference_steps": args.num_inference_steps,
|
21 |
+
"height": args.height,
|
22 |
+
"width": args.width,
|
23 |
+
"num_frames": args.num_frames,
|
24 |
+
"max_sequence_length": 256,
|
25 |
+
"output_type": "latent",
|
26 |
+
}
|
27 |
+
|
28 |
+
# 3. prepare rgbx utils
|
29 |
+
prepare_for_rgba_inference(
|
30 |
+
pipe.transformer,
|
31 |
+
device="cuda",
|
32 |
+
dtype=torch.bfloat16,
|
33 |
+
)
|
34 |
+
|
35 |
+
if args.lora_path is not None:
|
36 |
+
checkpoint = torch.load(args.lora_path, map_location="cpu")
|
37 |
+
processor_state_dict = checkpoint["state_dict"]
|
38 |
+
load_processor_state_dict(pipe.transformer, processor_state_dict)
|
39 |
+
|
40 |
+
|
41 |
+
# 4. inference
|
42 |
+
generator = torch.manual_seed(args.seed) if args.seed else None
|
43 |
+
frames_latents = pipe(**pipeline_args, generator=generator).frames
|
44 |
+
|
45 |
+
frames_latents_rgb, frames_latents_alpha = frames_latents.chunk(2, dim=2)
|
46 |
+
|
47 |
+
frames_rgb = decode_latents(pipe, frames_latents_rgb)
|
48 |
+
frames_alpha = decode_latents(pipe, frames_latents_alpha)
|
49 |
+
|
50 |
+
pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
|
51 |
+
frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
|
52 |
+
premultiplied_rgb = frames_rgb * frames_alpha_pooled
|
53 |
+
|
54 |
+
if os.path.exists(args.output_path) == False:
|
55 |
+
os.makedirs(args.output_path)
|
56 |
+
|
57 |
+
export_to_video(premultiplied_rgb[0], os.path.join(args.output_path, "rgb.mp4"), fps=args.fps)
|
58 |
+
export_to_video(frames_alpha_pooled[0], os.path.join(args.output_path, "alpha.mp4"), fps=args.fps)
|
59 |
+
export_to_video(frames_rgb[0], os.path.join(args.output_path, "original_rgb.mp4"), fps=args.fps)
|
60 |
+
|
61 |
+
if __name__ == "__main__":
|
62 |
+
parser = argparse.ArgumentParser(description="Generate a video from a text prompt")
|
63 |
+
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
64 |
+
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
|
65 |
+
|
66 |
+
parser.add_argument(
|
67 |
+
"--model_path", type=str, default="genmo/mochi-1-preview", help="Path of the pre-trained model use"
|
68 |
+
)
|
69 |
+
parser.add_argument("--output_path", type=str, default="./output", help="The path save generated video")
|
70 |
+
parser.add_argument("--guidance_scale", type=float, default=6, help="The scale for classifier-free guidance")
|
71 |
+
parser.add_argument("--num_inference_steps", type=int, default=64, help="Inference steps")
|
72 |
+
parser.add_argument("--num_frames", type=int, default=79, help="Number of steps for the inference process")
|
73 |
+
parser.add_argument("--width", type=int, default=848, help="Number of steps for the inference process")
|
74 |
+
parser.add_argument("--height", type=int, default=480, help="Number of steps for the inference process")
|
75 |
+
parser.add_argument("--fps", type=int, default=30, help="Number of steps for the inference process")
|
76 |
+
parser.add_argument("--seed", type=int, default=None, help="The seed for reproducibility")
|
77 |
+
args = parser.parse_args()
|
78 |
+
|
79 |
+
main(args)
|
Mochi/dataset_simple.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Taken from
|
3 |
+
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import click
|
9 |
+
import torch
|
10 |
+
from torch.utils.data import DataLoader, Dataset
|
11 |
+
|
12 |
+
|
13 |
+
def load_to_cpu(x):
|
14 |
+
return torch.load(x, map_location=torch.device("cpu"), weights_only=True)
|
15 |
+
|
16 |
+
|
17 |
+
class LatentEmbedDataset(Dataset):
|
18 |
+
def __init__(self, file_paths, repeat=1):
|
19 |
+
self.items = [
|
20 |
+
(Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
|
21 |
+
for p in file_paths
|
22 |
+
if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file()
|
23 |
+
]
|
24 |
+
self.items = self.items * repeat
|
25 |
+
print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.")
|
26 |
+
|
27 |
+
def __len__(self):
|
28 |
+
return len(self.items)
|
29 |
+
|
30 |
+
def __getitem__(self, idx):
|
31 |
+
latent_path, embed_path = self.items[idx]
|
32 |
+
return load_to_cpu(latent_path), load_to_cpu(embed_path)
|
33 |
+
|
34 |
+
|
35 |
+
@click.command()
|
36 |
+
@click.argument("directory", type=click.Path(exists=True, file_okay=False))
|
37 |
+
def process_videos(directory):
|
38 |
+
dir_path = Path(directory)
|
39 |
+
mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
|
40 |
+
assert mp4_files, f"No mp4 files found"
|
41 |
+
|
42 |
+
dataset = LatentEmbedDataset(mp4_files)
|
43 |
+
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
|
44 |
+
|
45 |
+
for latents, embeds in dataloader:
|
46 |
+
print([(k, v.shape) for k, v in latents.items()])
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == "__main__":
|
50 |
+
process_videos()
|
Mochi/embed.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adapted from:
|
3 |
+
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py
|
4 |
+
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py
|
5 |
+
"""
|
6 |
+
|
7 |
+
import click
|
8 |
+
import torch
|
9 |
+
import torchvision
|
10 |
+
from pathlib import Path
|
11 |
+
from diffusers import AutoencoderKLMochi, MochiPipeline
|
12 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
13 |
+
from tqdm.auto import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
|
17 |
+
T, H, W = [int(s) for s in shape.split("x")]
|
18 |
+
assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"
|
19 |
+
video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs")
|
20 |
+
fps = metadata["video_fps"]
|
21 |
+
video = video.permute(3, 0, 1, 2)
|
22 |
+
og_shape = video.shape
|
23 |
+
assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
|
24 |
+
assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
|
25 |
+
assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
|
26 |
+
if video.shape[1] > T:
|
27 |
+
video = video[:, :T]
|
28 |
+
print(f"Trimmed video from {og_shape[1]} to first {T} frames")
|
29 |
+
video = video.unsqueeze(0)
|
30 |
+
video = video.float() / 127.5 - 1.0
|
31 |
+
video = video.to(model.device)
|
32 |
+
|
33 |
+
assert video.ndim == 5
|
34 |
+
|
35 |
+
with torch.inference_mode():
|
36 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
37 |
+
ldist = model._encode(video)
|
38 |
+
|
39 |
+
torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt"))
|
40 |
+
|
41 |
+
|
42 |
+
@click.command()
|
43 |
+
@click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
|
44 |
+
@click.option(
|
45 |
+
"--model_id",
|
46 |
+
type=str,
|
47 |
+
help="Repo id. Should be genmo/mochi-1-preview",
|
48 |
+
default="genmo/mochi-1-preview",
|
49 |
+
)
|
50 |
+
@click.option("--shape", default="163x480x848", help="Shape of the video to encode")
|
51 |
+
@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.")
|
52 |
+
def batch_process(output_dir: Path, model_id: Path, shape: str, overwrite: bool) -> None:
|
53 |
+
"""Process all videos and captions in a directory using a single GPU."""
|
54 |
+
# comment out when running on unsupported hardware
|
55 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
56 |
+
torch.backends.cudnn.allow_tf32 = True
|
57 |
+
|
58 |
+
# Get all video paths
|
59 |
+
video_paths = list(output_dir.glob("**/*.mp4"))
|
60 |
+
if not video_paths:
|
61 |
+
print(f"No MP4 files found in {output_dir}")
|
62 |
+
return
|
63 |
+
|
64 |
+
text_paths = list(output_dir.glob("**/*.txt"))
|
65 |
+
if not text_paths:
|
66 |
+
print(f"No text files found in {output_dir}")
|
67 |
+
return
|
68 |
+
|
69 |
+
# load the models
|
70 |
+
vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
71 |
+
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
|
72 |
+
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
|
73 |
+
pipeline = MochiPipeline.from_pretrained(
|
74 |
+
model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None
|
75 |
+
).to("cuda")
|
76 |
+
|
77 |
+
for idx, video_path in tqdm(enumerate(sorted(video_paths))):
|
78 |
+
print(f"Processing {video_path}")
|
79 |
+
try:
|
80 |
+
if video_path.with_suffix(".latent.pt").exists() and not overwrite:
|
81 |
+
print(f"Skipping {video_path}")
|
82 |
+
continue
|
83 |
+
|
84 |
+
# encode videos.
|
85 |
+
encode_videos(vae, vid_path=video_path, shape=shape)
|
86 |
+
|
87 |
+
# embed captions.
|
88 |
+
prompt_path = Path("/".join(str(video_path).split(".")[:-1]) + ".txt")
|
89 |
+
embed_path = prompt_path.with_suffix(".embed.pt")
|
90 |
+
|
91 |
+
if embed_path.exists() and not overwrite:
|
92 |
+
print(f"Skipping {prompt_path} - embeddings already exist")
|
93 |
+
continue
|
94 |
+
|
95 |
+
with open(prompt_path) as f:
|
96 |
+
text = f.read().strip()
|
97 |
+
with torch.inference_mode():
|
98 |
+
conditioning = pipeline.encode_prompt(prompt=[text])
|
99 |
+
|
100 |
+
conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]}
|
101 |
+
torch.save(conditioning, embed_path)
|
102 |
+
|
103 |
+
except Exception as e:
|
104 |
+
import traceback
|
105 |
+
|
106 |
+
traceback.print_exc()
|
107 |
+
print(f"Error processing {video_path}: {str(e)}")
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
batch_process()
|
Mochi/pipeline_mochi_rgba.py
ADDED
@@ -0,0 +1,782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from transformers import T5EncoderModel, T5TokenizerFast
|
21 |
+
|
22 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
23 |
+
from diffusers.loaders import Mochi1LoraLoaderMixin
|
24 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
25 |
+
from diffusers.models.transformers import MochiTransformer3DModel
|
26 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
27 |
+
from diffusers.utils import (
|
28 |
+
is_torch_xla_available,
|
29 |
+
logging,
|
30 |
+
replace_example_docstring,
|
31 |
+
)
|
32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
33 |
+
from diffusers.video_processor import VideoProcessor
|
34 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
35 |
+
from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
|
36 |
+
|
37 |
+
|
38 |
+
if is_torch_xla_available():
|
39 |
+
import torch_xla.core.xla_model as xm
|
40 |
+
|
41 |
+
XLA_AVAILABLE = True
|
42 |
+
else:
|
43 |
+
XLA_AVAILABLE = False
|
44 |
+
|
45 |
+
|
46 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
47 |
+
|
48 |
+
EXAMPLE_DOC_STRING = """
|
49 |
+
Examples:
|
50 |
+
```py
|
51 |
+
>>> import torch
|
52 |
+
>>> from diffusers import MochiPipeline
|
53 |
+
>>> from diffusers.utils import export_to_video
|
54 |
+
|
55 |
+
>>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
|
56 |
+
>>> pipe.enable_model_cpu_offload()
|
57 |
+
>>> pipe.enable_vae_tiling()
|
58 |
+
>>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
59 |
+
>>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
|
60 |
+
>>> export_to_video(frames, "mochi.mp4")
|
61 |
+
```
|
62 |
+
"""
|
63 |
+
|
64 |
+
|
65 |
+
def calculate_shift(
|
66 |
+
image_seq_len,
|
67 |
+
base_seq_len: int = 256,
|
68 |
+
max_seq_len: int = 4096,
|
69 |
+
base_shift: float = 0.5,
|
70 |
+
max_shift: float = 1.16,
|
71 |
+
):
|
72 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
73 |
+
b = base_shift - m * base_seq_len
|
74 |
+
mu = image_seq_len * m + b
|
75 |
+
return mu
|
76 |
+
|
77 |
+
|
78 |
+
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
|
79 |
+
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
|
80 |
+
if linear_steps is None:
|
81 |
+
linear_steps = num_steps // 2
|
82 |
+
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
|
83 |
+
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
|
84 |
+
quadratic_steps = num_steps - linear_steps
|
85 |
+
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
|
86 |
+
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
|
87 |
+
const = quadratic_coef * (linear_steps**2)
|
88 |
+
quadratic_sigma_schedule = [
|
89 |
+
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
|
90 |
+
]
|
91 |
+
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
|
92 |
+
sigma_schedule = [1.0 - x for x in sigma_schedule]
|
93 |
+
return sigma_schedule
|
94 |
+
|
95 |
+
|
96 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
97 |
+
def retrieve_timesteps(
|
98 |
+
scheduler,
|
99 |
+
num_inference_steps: Optional[int] = None,
|
100 |
+
device: Optional[Union[str, torch.device]] = None,
|
101 |
+
timesteps: Optional[List[int]] = None,
|
102 |
+
sigmas: Optional[List[float]] = None,
|
103 |
+
**kwargs,
|
104 |
+
):
|
105 |
+
r"""
|
106 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
107 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
scheduler (`SchedulerMixin`):
|
111 |
+
The scheduler to get timesteps from.
|
112 |
+
num_inference_steps (`int`):
|
113 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
114 |
+
must be `None`.
|
115 |
+
device (`str` or `torch.device`, *optional*):
|
116 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
117 |
+
timesteps (`List[int]`, *optional*):
|
118 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
119 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
120 |
+
sigmas (`List[float]`, *optional*):
|
121 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
122 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
126 |
+
second element is the number of inference steps.
|
127 |
+
"""
|
128 |
+
if timesteps is not None and sigmas is not None:
|
129 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
130 |
+
if timesteps is not None:
|
131 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
132 |
+
if not accepts_timesteps:
|
133 |
+
raise ValueError(
|
134 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
135 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
136 |
+
)
|
137 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
138 |
+
timesteps = scheduler.timesteps
|
139 |
+
num_inference_steps = len(timesteps)
|
140 |
+
elif sigmas is not None:
|
141 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
142 |
+
if not accept_sigmas:
|
143 |
+
raise ValueError(
|
144 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
145 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
146 |
+
)
|
147 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
148 |
+
timesteps = scheduler.timesteps
|
149 |
+
num_inference_steps = len(timesteps)
|
150 |
+
else:
|
151 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
152 |
+
timesteps = scheduler.timesteps
|
153 |
+
return timesteps, num_inference_steps
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
def prepare_attention_mask(prompt_attention_mask, latents):
|
161 |
+
|
162 |
+
device = prompt_attention_mask.device
|
163 |
+
|
164 |
+
(_, _, num_frames, height, width) = latents.shape # shape of two modalities
|
165 |
+
seq_length = (height // 2) * (width // 2) * num_frames
|
166 |
+
|
167 |
+
rect_attention_mask = []
|
168 |
+
for prompt_attention_mask_i in prompt_attention_mask:
|
169 |
+
text_length = torch.sum(prompt_attention_mask_i).item()
|
170 |
+
total_length = text_length + seq_length
|
171 |
+
|
172 |
+
if text_length == 0:
|
173 |
+
rect_attention_mask.append(None)
|
174 |
+
else:
|
175 |
+
dense_mask = torch.ones((total_length, total_length), dtype=torch.bool)
|
176 |
+
dense_mask[seq_length:, seq_length // 2: seq_length] = False
|
177 |
+
rect_attention_mask.append(dense_mask.to(device))
|
178 |
+
|
179 |
+
return {
|
180 |
+
"prompt_attention_mask": prompt_attention_mask,
|
181 |
+
"rect_attention_mask": rect_attention_mask,
|
182 |
+
}
|
183 |
+
|
184 |
+
|
185 |
+
class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
|
186 |
+
r"""
|
187 |
+
The mochi pipeline for text-to-video generation.
|
188 |
+
|
189 |
+
Reference: https://github.com/genmoai/models
|
190 |
+
|
191 |
+
Args:
|
192 |
+
transformer ([`MochiTransformer3DModel`]):
|
193 |
+
Conditional Transformer architecture to denoise the encoded video latents.
|
194 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
195 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
196 |
+
vae ([`AutoencoderKL`]):
|
197 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
198 |
+
text_encoder ([`T5EncoderModel`]):
|
199 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
200 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
201 |
+
tokenizer (`CLIPTokenizer`):
|
202 |
+
Tokenizer of class
|
203 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
204 |
+
tokenizer (`T5TokenizerFast`):
|
205 |
+
Second Tokenizer of class
|
206 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
207 |
+
"""
|
208 |
+
|
209 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
210 |
+
_optional_components = []
|
211 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
212 |
+
|
213 |
+
def __init__(
|
214 |
+
self,
|
215 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
216 |
+
vae: AutoencoderKL,
|
217 |
+
text_encoder: T5EncoderModel,
|
218 |
+
tokenizer: T5TokenizerFast,
|
219 |
+
transformer: MochiTransformer3DModel,
|
220 |
+
force_zeros_for_empty_prompt: bool = False,
|
221 |
+
):
|
222 |
+
super().__init__()
|
223 |
+
|
224 |
+
self.register_modules(
|
225 |
+
vae=vae,
|
226 |
+
text_encoder=text_encoder,
|
227 |
+
tokenizer=tokenizer,
|
228 |
+
transformer=transformer,
|
229 |
+
scheduler=scheduler,
|
230 |
+
)
|
231 |
+
# TODO: determine these scaling factors from model parameters
|
232 |
+
self.vae_spatial_scale_factor = 8
|
233 |
+
self.vae_temporal_scale_factor = 6
|
234 |
+
self.patch_size = 2
|
235 |
+
|
236 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
|
237 |
+
self.tokenizer_max_length = (
|
238 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
|
239 |
+
)
|
240 |
+
self.default_height = 480
|
241 |
+
self.default_width = 848
|
242 |
+
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
243 |
+
|
244 |
+
def _get_t5_prompt_embeds(
|
245 |
+
self,
|
246 |
+
prompt: Union[str, List[str]] = None,
|
247 |
+
num_videos_per_prompt: int = 1,
|
248 |
+
max_sequence_length: int = 256,
|
249 |
+
device: Optional[torch.device] = None,
|
250 |
+
dtype: Optional[torch.dtype] = None,
|
251 |
+
):
|
252 |
+
device = device or self._execution_device
|
253 |
+
dtype = dtype or self.text_encoder.dtype
|
254 |
+
|
255 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
256 |
+
batch_size = len(prompt)
|
257 |
+
|
258 |
+
text_inputs = self.tokenizer(
|
259 |
+
prompt,
|
260 |
+
padding="max_length",
|
261 |
+
max_length=max_sequence_length,
|
262 |
+
truncation=True,
|
263 |
+
add_special_tokens=True,
|
264 |
+
return_tensors="pt",
|
265 |
+
)
|
266 |
+
|
267 |
+
text_input_ids = text_inputs.input_ids
|
268 |
+
prompt_attention_mask = text_inputs.attention_mask
|
269 |
+
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
270 |
+
|
271 |
+
# The original Mochi implementation zeros out empty negative prompts
|
272 |
+
# but this can lead to overflow when placing the entire pipeline under the autocast context
|
273 |
+
# adding this here so that we can enable zeroing prompts if necessary
|
274 |
+
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
|
275 |
+
text_input_ids = torch.zeros_like(text_input_ids, device=device)
|
276 |
+
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
|
277 |
+
|
278 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
279 |
+
|
280 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
281 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
282 |
+
logger.warning(
|
283 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
284 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
285 |
+
)
|
286 |
+
|
287 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
|
288 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
289 |
+
|
290 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
291 |
+
_, seq_len, _ = prompt_embeds.shape
|
292 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
293 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
294 |
+
|
295 |
+
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
296 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
297 |
+
|
298 |
+
return prompt_embeds, prompt_attention_mask
|
299 |
+
|
300 |
+
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
301 |
+
def encode_prompt(
|
302 |
+
self,
|
303 |
+
prompt: Union[str, List[str]],
|
304 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
305 |
+
do_classifier_free_guidance: bool = True,
|
306 |
+
num_videos_per_prompt: int = 1,
|
307 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
308 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
309 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
310 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
311 |
+
max_sequence_length: int = 256,
|
312 |
+
device: Optional[torch.device] = None,
|
313 |
+
dtype: Optional[torch.dtype] = None,
|
314 |
+
):
|
315 |
+
r"""
|
316 |
+
Encodes the prompt into text encoder hidden states.
|
317 |
+
|
318 |
+
Args:
|
319 |
+
prompt (`str` or `List[str]`, *optional*):
|
320 |
+
prompt to be encoded
|
321 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
322 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
323 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
324 |
+
less than `1`).
|
325 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
326 |
+
Whether to use classifier free guidance or not.
|
327 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
328 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
329 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
330 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
331 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
332 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
333 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
334 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
335 |
+
argument.
|
336 |
+
device: (`torch.device`, *optional*):
|
337 |
+
torch device
|
338 |
+
dtype: (`torch.dtype`, *optional*):
|
339 |
+
torch dtype
|
340 |
+
"""
|
341 |
+
device = device or self._execution_device
|
342 |
+
|
343 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
344 |
+
if prompt is not None:
|
345 |
+
batch_size = len(prompt)
|
346 |
+
else:
|
347 |
+
batch_size = prompt_embeds.shape[0]
|
348 |
+
|
349 |
+
if prompt_embeds is None:
|
350 |
+
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
351 |
+
prompt=prompt,
|
352 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
353 |
+
max_sequence_length=max_sequence_length,
|
354 |
+
device=device,
|
355 |
+
dtype=dtype,
|
356 |
+
)
|
357 |
+
|
358 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
359 |
+
negative_prompt = negative_prompt or ""
|
360 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
361 |
+
|
362 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
363 |
+
raise TypeError(
|
364 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
365 |
+
f" {type(prompt)}."
|
366 |
+
)
|
367 |
+
elif batch_size != len(negative_prompt):
|
368 |
+
raise ValueError(
|
369 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
370 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
371 |
+
" the batch size of `prompt`."
|
372 |
+
)
|
373 |
+
|
374 |
+
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
375 |
+
prompt=negative_prompt,
|
376 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
377 |
+
max_sequence_length=max_sequence_length,
|
378 |
+
device=device,
|
379 |
+
dtype=dtype,
|
380 |
+
)
|
381 |
+
|
382 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
383 |
+
|
384 |
+
def check_inputs(
|
385 |
+
self,
|
386 |
+
prompt,
|
387 |
+
height,
|
388 |
+
width,
|
389 |
+
callback_on_step_end_tensor_inputs=None,
|
390 |
+
prompt_embeds=None,
|
391 |
+
negative_prompt_embeds=None,
|
392 |
+
prompt_attention_mask=None,
|
393 |
+
negative_prompt_attention_mask=None,
|
394 |
+
):
|
395 |
+
if height % 8 != 0 or width % 8 != 0:
|
396 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
397 |
+
|
398 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
399 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
400 |
+
):
|
401 |
+
raise ValueError(
|
402 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
403 |
+
)
|
404 |
+
|
405 |
+
if prompt is not None and prompt_embeds is not None:
|
406 |
+
raise ValueError(
|
407 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
408 |
+
" only forward one of the two."
|
409 |
+
)
|
410 |
+
elif prompt is None and prompt_embeds is None:
|
411 |
+
raise ValueError(
|
412 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
413 |
+
)
|
414 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
415 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
416 |
+
|
417 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
418 |
+
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
419 |
+
|
420 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
421 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
422 |
+
|
423 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
424 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
425 |
+
raise ValueError(
|
426 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
427 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
428 |
+
f" {negative_prompt_embeds.shape}."
|
429 |
+
)
|
430 |
+
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
431 |
+
raise ValueError(
|
432 |
+
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
433 |
+
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
434 |
+
f" {negative_prompt_attention_mask.shape}."
|
435 |
+
)
|
436 |
+
|
437 |
+
def enable_vae_slicing(self):
|
438 |
+
r"""
|
439 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
440 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
441 |
+
"""
|
442 |
+
self.vae.enable_slicing()
|
443 |
+
|
444 |
+
def disable_vae_slicing(self):
|
445 |
+
r"""
|
446 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
447 |
+
computing decoding in one step.
|
448 |
+
"""
|
449 |
+
self.vae.disable_slicing()
|
450 |
+
|
451 |
+
def enable_vae_tiling(self):
|
452 |
+
r"""
|
453 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
454 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
455 |
+
processing larger images.
|
456 |
+
"""
|
457 |
+
self.vae.enable_tiling()
|
458 |
+
|
459 |
+
def disable_vae_tiling(self):
|
460 |
+
r"""
|
461 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
462 |
+
computing decoding in one step.
|
463 |
+
"""
|
464 |
+
self.vae.disable_tiling()
|
465 |
+
|
466 |
+
def prepare_latents(
|
467 |
+
self,
|
468 |
+
batch_size,
|
469 |
+
num_channels_latents,
|
470 |
+
height,
|
471 |
+
width,
|
472 |
+
num_frames,
|
473 |
+
dtype,
|
474 |
+
device,
|
475 |
+
generator,
|
476 |
+
latents=None,
|
477 |
+
):
|
478 |
+
height = height // self.vae_spatial_scale_factor
|
479 |
+
width = width // self.vae_spatial_scale_factor
|
480 |
+
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
|
481 |
+
|
482 |
+
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
483 |
+
|
484 |
+
if latents is not None:
|
485 |
+
return latents.to(device=device, dtype=dtype)
|
486 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
487 |
+
raise ValueError(
|
488 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
489 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
490 |
+
)
|
491 |
+
|
492 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
|
493 |
+
latents = latents.to(dtype)
|
494 |
+
return latents
|
495 |
+
|
496 |
+
@property
|
497 |
+
def guidance_scale(self):
|
498 |
+
return self._guidance_scale
|
499 |
+
|
500 |
+
@property
|
501 |
+
def do_classifier_free_guidance(self):
|
502 |
+
return self._guidance_scale > 1.0
|
503 |
+
|
504 |
+
@property
|
505 |
+
def num_timesteps(self):
|
506 |
+
return self._num_timesteps
|
507 |
+
|
508 |
+
@property
|
509 |
+
def attention_kwargs(self):
|
510 |
+
return self._attention_kwargs
|
511 |
+
|
512 |
+
@property
|
513 |
+
def interrupt(self):
|
514 |
+
return self._interrupt
|
515 |
+
|
516 |
+
|
517 |
+
@torch.no_grad()
|
518 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
519 |
+
def __call__(
|
520 |
+
self,
|
521 |
+
prompt: Union[str, List[str]] = None,
|
522 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
523 |
+
height: Optional[int] = None,
|
524 |
+
width: Optional[int] = None,
|
525 |
+
num_frames: int = 19,
|
526 |
+
num_inference_steps: int = 64,
|
527 |
+
timesteps: List[int] = None,
|
528 |
+
guidance_scale: float = 4.5,
|
529 |
+
num_videos_per_prompt: Optional[int] = 1,
|
530 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
531 |
+
latents: Optional[torch.Tensor] = None,
|
532 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
533 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
534 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
535 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
536 |
+
output_type: Optional[str] = "pil",
|
537 |
+
return_dict: bool = True,
|
538 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
539 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
540 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
541 |
+
max_sequence_length: int = 256,
|
542 |
+
):
|
543 |
+
r"""
|
544 |
+
Function invoked when calling the pipeline for generation.
|
545 |
+
|
546 |
+
Args:
|
547 |
+
prompt (`str` or `List[str]`, *optional*):
|
548 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
549 |
+
instead.
|
550 |
+
height (`int`, *optional*, defaults to `self.default_height`):
|
551 |
+
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
552 |
+
width (`int`, *optional*, defaults to `self.default_width`):
|
553 |
+
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
554 |
+
num_frames (`int`, defaults to `19`):
|
555 |
+
The number of video frames to generate
|
556 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
557 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
558 |
+
expense of slower inference.
|
559 |
+
timesteps (`List[int]`, *optional*):
|
560 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
561 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
562 |
+
passed will be used. Must be in descending order.
|
563 |
+
guidance_scale (`float`, defaults to `4.5`):
|
564 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
565 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
566 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
567 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
568 |
+
usually at the expense of lower image quality.
|
569 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
570 |
+
The number of videos to generate per prompt.
|
571 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
572 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
573 |
+
to make generation deterministic.
|
574 |
+
latents (`torch.Tensor`, *optional*):
|
575 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
576 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
577 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
578 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
579 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
580 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
581 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
582 |
+
Pre-generated attention mask for text embeddings.
|
583 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
584 |
+
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
585 |
+
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
586 |
+
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
587 |
+
Pre-generated attention mask for negative text embeddings.
|
588 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
589 |
+
The output format of the generate image. Choose between
|
590 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
591 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
592 |
+
Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
|
593 |
+
attention_kwargs (`dict`, *optional*):
|
594 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
595 |
+
`self.processor` in
|
596 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
597 |
+
callback_on_step_end (`Callable`, *optional*):
|
598 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
599 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
600 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
601 |
+
`callback_on_step_end_tensor_inputs`.
|
602 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
603 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
604 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
605 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
606 |
+
max_sequence_length (`int` defaults to `256`):
|
607 |
+
Maximum sequence length to use with the `prompt`.
|
608 |
+
|
609 |
+
Examples:
|
610 |
+
|
611 |
+
Returns:
|
612 |
+
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
|
613 |
+
If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
|
614 |
+
is returned where the first element is a list with the generated images.
|
615 |
+
"""
|
616 |
+
|
617 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
618 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
619 |
+
|
620 |
+
height = height or self.default_height
|
621 |
+
width = width or self.default_width
|
622 |
+
|
623 |
+
# 1. Check inputs. Raise error if not correct
|
624 |
+
self.check_inputs(
|
625 |
+
prompt=prompt,
|
626 |
+
height=height,
|
627 |
+
width=width,
|
628 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
629 |
+
prompt_embeds=prompt_embeds,
|
630 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
631 |
+
prompt_attention_mask=prompt_attention_mask,
|
632 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
633 |
+
)
|
634 |
+
|
635 |
+
self._guidance_scale = guidance_scale
|
636 |
+
self._attention_kwargs = attention_kwargs
|
637 |
+
self._interrupt = False
|
638 |
+
|
639 |
+
# 2. Define call parameters
|
640 |
+
if prompt is not None and isinstance(prompt, str):
|
641 |
+
batch_size = 1
|
642 |
+
elif prompt is not None and isinstance(prompt, list):
|
643 |
+
batch_size = len(prompt)
|
644 |
+
else:
|
645 |
+
batch_size = prompt_embeds.shape[0]
|
646 |
+
|
647 |
+
device = self._execution_device
|
648 |
+
# 3. Prepare text embeddings
|
649 |
+
(
|
650 |
+
prompt_embeds,
|
651 |
+
prompt_attention_mask,
|
652 |
+
negative_prompt_embeds,
|
653 |
+
negative_prompt_attention_mask,
|
654 |
+
) = self.encode_prompt(
|
655 |
+
prompt=prompt,
|
656 |
+
negative_prompt=negative_prompt,
|
657 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
658 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
659 |
+
prompt_embeds=prompt_embeds,
|
660 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
661 |
+
prompt_attention_mask=prompt_attention_mask,
|
662 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
663 |
+
max_sequence_length=max_sequence_length,
|
664 |
+
device=device,
|
665 |
+
)
|
666 |
+
# 4. Prepare latent variables
|
667 |
+
num_channels_latents = self.transformer.config.in_channels
|
668 |
+
latents = self.prepare_latents(
|
669 |
+
batch_size * num_videos_per_prompt,
|
670 |
+
num_channels_latents,
|
671 |
+
height,
|
672 |
+
width,
|
673 |
+
num_frames,
|
674 |
+
prompt_embeds.dtype,
|
675 |
+
device,
|
676 |
+
generator,
|
677 |
+
latents,
|
678 |
+
).repeat(1,1,2,1,1)
|
679 |
+
|
680 |
+
if self.do_classifier_free_guidance:
|
681 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
682 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
683 |
+
|
684 |
+
|
685 |
+
# 5.5 Prepare attention rectification masks
|
686 |
+
all_attention_mask = prepare_attention_mask(prompt_attention_mask, latents)
|
687 |
+
|
688 |
+
# 5. Prepare timestep
|
689 |
+
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
|
690 |
+
threshold_noise = 0.025
|
691 |
+
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
|
692 |
+
sigmas = np.array(sigmas)
|
693 |
+
|
694 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
695 |
+
self.scheduler,
|
696 |
+
num_inference_steps,
|
697 |
+
device,
|
698 |
+
timesteps,
|
699 |
+
sigmas,
|
700 |
+
)
|
701 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
702 |
+
self._num_timesteps = len(timesteps)
|
703 |
+
|
704 |
+
# 6. Denoising loop
|
705 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
706 |
+
for i, t in enumerate(timesteps):
|
707 |
+
if self.interrupt:
|
708 |
+
continue
|
709 |
+
|
710 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
711 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
712 |
+
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
713 |
+
|
714 |
+
noise_pred = self.transformer(
|
715 |
+
hidden_states=latent_model_input,
|
716 |
+
encoder_hidden_states=prompt_embeds,
|
717 |
+
timestep=timestep,
|
718 |
+
encoder_attention_mask=all_attention_mask,
|
719 |
+
attention_kwargs=attention_kwargs,
|
720 |
+
return_dict=False,
|
721 |
+
)[0]
|
722 |
+
# Mochi CFG + Sampling runs in FP32
|
723 |
+
noise_pred = noise_pred.to(torch.float32)
|
724 |
+
|
725 |
+
if self.do_classifier_free_guidance:
|
726 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
727 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
728 |
+
|
729 |
+
# compute the previous noisy sample x_t -> x_t-1
|
730 |
+
latents_dtype = latents.dtype
|
731 |
+
latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
|
732 |
+
latents = latents.to(latents_dtype)
|
733 |
+
|
734 |
+
if latents.dtype != latents_dtype:
|
735 |
+
if torch.backends.mps.is_available():
|
736 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
737 |
+
latents = latents.to(latents_dtype)
|
738 |
+
|
739 |
+
if callback_on_step_end is not None:
|
740 |
+
callback_kwargs = {}
|
741 |
+
for k in callback_on_step_end_tensor_inputs:
|
742 |
+
callback_kwargs[k] = locals()[k]
|
743 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
744 |
+
|
745 |
+
latents = callback_outputs.pop("latents", latents)
|
746 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
747 |
+
|
748 |
+
# call the callback, if provided
|
749 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
750 |
+
progress_bar.update()
|
751 |
+
|
752 |
+
if XLA_AVAILABLE:
|
753 |
+
xm.mark_step()
|
754 |
+
|
755 |
+
if output_type == "latent":
|
756 |
+
video = latents
|
757 |
+
else:
|
758 |
+
# unscale/denormalize the latents
|
759 |
+
# denormalize with the mean and std if available and not None
|
760 |
+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
761 |
+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
762 |
+
if has_latents_mean and has_latents_std:
|
763 |
+
latents_mean = (
|
764 |
+
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
765 |
+
)
|
766 |
+
latents_std = (
|
767 |
+
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
768 |
+
)
|
769 |
+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
770 |
+
else:
|
771 |
+
latents = latents / self.vae.config.scaling_factor
|
772 |
+
|
773 |
+
video = self.vae.decode(latents, return_dict=False)[0]
|
774 |
+
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
775 |
+
|
776 |
+
# Offload all models
|
777 |
+
self.maybe_free_model_hooks()
|
778 |
+
|
779 |
+
if not return_dict:
|
780 |
+
return (video,)
|
781 |
+
|
782 |
+
return MochiPipelineOutput(frames=video)
|
Mochi/prepare_dataset.sh
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
GPU_ID=0
|
4 |
+
VIDEO_DIR=video-dataset-disney-organized
|
5 |
+
OUTPUT_DIR=videos_prepared
|
6 |
+
NUM_FRAMES=37
|
7 |
+
RESOLUTION=480x848
|
8 |
+
|
9 |
+
# Extract width and height from RESOLUTION
|
10 |
+
WIDTH=$(echo $RESOLUTION | cut -dx -f1)
|
11 |
+
HEIGHT=$(echo $RESOLUTION | cut -dx -f2)
|
12 |
+
|
13 |
+
python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample
|
14 |
+
|
15 |
+
CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT}
|
Mochi/rgba_utils.py
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
6 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
7 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
8 |
+
|
9 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
10 |
+
|
11 |
+
@torch.no_grad()
|
12 |
+
def decode_latents(pipe, latents):
|
13 |
+
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
|
14 |
+
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
|
15 |
+
if has_latents_mean and has_latents_std:
|
16 |
+
latents_mean = (
|
17 |
+
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
18 |
+
)
|
19 |
+
latents_std = (
|
20 |
+
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
21 |
+
)
|
22 |
+
latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
|
23 |
+
else:
|
24 |
+
latents = latents / pipe.vae.config.scaling_factor
|
25 |
+
|
26 |
+
video = pipe.vae.decode(latents, return_dict=False)[0]
|
27 |
+
video = pipe.video_processor.postprocess_video(video, output_type='np')
|
28 |
+
|
29 |
+
return video
|
30 |
+
|
31 |
+
|
32 |
+
class RGBALoRAMochiAttnProcessor:
|
33 |
+
"""Attention processor used in Mochi."""
|
34 |
+
def __init__(self, device, dtype, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
|
35 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
36 |
+
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
37 |
+
|
38 |
+
|
39 |
+
# Initialize LoRA layers
|
40 |
+
self.lora_alpha = lora_alpha
|
41 |
+
self.lora_rank = lora_rank
|
42 |
+
|
43 |
+
# Helper function to create LoRA layers
|
44 |
+
def create_lora_layer(in_dim, mid_dim, out_dim, device=device, dtype=dtype):
|
45 |
+
# Define the LoRA layers
|
46 |
+
lora_a = nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype)
|
47 |
+
lora_b = nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype)
|
48 |
+
|
49 |
+
# Initialize lora_a with random parameters (default initialization)
|
50 |
+
nn.init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5)) # or another suitable initialization
|
51 |
+
|
52 |
+
# Initialize lora_b with zero values
|
53 |
+
nn.init.zeros_(lora_b.weight)
|
54 |
+
|
55 |
+
lora_a.weight.requires_grad = True
|
56 |
+
lora_b.weight.requires_grad = True
|
57 |
+
|
58 |
+
# Combine the layers into a sequential module
|
59 |
+
return nn.Sequential(lora_a, lora_b)
|
60 |
+
|
61 |
+
self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
|
62 |
+
self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
|
63 |
+
self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
|
64 |
+
self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
|
65 |
+
|
66 |
+
def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
|
67 |
+
"""Applies LoRA updates to query, key, and value tensors."""
|
68 |
+
query_delta = self.to_q_lora(hidden_states).to(query.device)
|
69 |
+
query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling
|
70 |
+
|
71 |
+
key_delta = self.to_k_lora(hidden_states).to(key.device)
|
72 |
+
key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling
|
73 |
+
|
74 |
+
value_delta = self.to_v_lora(hidden_states).to(value.device)
|
75 |
+
value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling
|
76 |
+
|
77 |
+
return query, key, value
|
78 |
+
|
79 |
+
def __call__(
|
80 |
+
self,
|
81 |
+
attn,
|
82 |
+
hidden_states: torch.Tensor,
|
83 |
+
encoder_hidden_states: torch.Tensor,
|
84 |
+
attention_mask: Optional[torch.Tensor] = None,
|
85 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
86 |
+
) -> torch.Tensor:
|
87 |
+
query = attn.to_q(hidden_states)
|
88 |
+
key = attn.to_k(hidden_states)
|
89 |
+
value = attn.to_v(hidden_states)
|
90 |
+
|
91 |
+
scaling = self.lora_alpha / self.lora_rank
|
92 |
+
sequence_length = query.size(1)
|
93 |
+
query, key, value = self._apply_lora(hidden_states, sequence_length, query, key, value, scaling)
|
94 |
+
|
95 |
+
query = query.unflatten(2, (attn.heads, -1))
|
96 |
+
key = key.unflatten(2, (attn.heads, -1))
|
97 |
+
value = value.unflatten(2, (attn.heads, -1))
|
98 |
+
|
99 |
+
if attn.norm_q is not None:
|
100 |
+
query = attn.norm_q(query)
|
101 |
+
if attn.norm_k is not None:
|
102 |
+
key = attn.norm_k(key)
|
103 |
+
|
104 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
105 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
106 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
107 |
+
|
108 |
+
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
109 |
+
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
110 |
+
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
111 |
+
|
112 |
+
if attn.norm_added_q is not None:
|
113 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
114 |
+
if attn.norm_added_k is not None:
|
115 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
116 |
+
|
117 |
+
if image_rotary_emb is not None:
|
118 |
+
|
119 |
+
def apply_rotary_emb(x, freqs_cos, freqs_sin):
|
120 |
+
x_even = x[..., 0::2].float()
|
121 |
+
x_odd = x[..., 1::2].float()
|
122 |
+
|
123 |
+
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
|
124 |
+
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
|
125 |
+
|
126 |
+
return torch.stack([cos, sin], dim=-1).flatten(-2)
|
127 |
+
|
128 |
+
query[:,sequence_length//2:] = apply_rotary_emb(query[:,sequence_length//2:], *image_rotary_emb)
|
129 |
+
query[:,:sequence_length//2] = apply_rotary_emb(query[:,:sequence_length//2], *image_rotary_emb)
|
130 |
+
|
131 |
+
key[:,sequence_length//2:] = apply_rotary_emb(key[:,sequence_length//2:], *image_rotary_emb)
|
132 |
+
key[:,:sequence_length//2] = apply_rotary_emb(key[:,:sequence_length//2], *image_rotary_emb)
|
133 |
+
# query = apply_rotary_emb(query, *image_rotary_emb)
|
134 |
+
# key = apply_rotary_emb(key, *image_rotary_emb)
|
135 |
+
|
136 |
+
|
137 |
+
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
138 |
+
encoder_query, encoder_key, encoder_value = (
|
139 |
+
encoder_query.transpose(1, 2),
|
140 |
+
encoder_key.transpose(1, 2),
|
141 |
+
encoder_value.transpose(1, 2),
|
142 |
+
)
|
143 |
+
|
144 |
+
sequence_length = query.size(2)
|
145 |
+
encoder_sequence_length = encoder_query.size(2)
|
146 |
+
total_length = sequence_length + encoder_sequence_length
|
147 |
+
|
148 |
+
batch_size, heads, _, dim = query.shape
|
149 |
+
|
150 |
+
attn_outputs = []
|
151 |
+
prompt_attention_mask = attention_mask["prompt_attention_mask"]
|
152 |
+
rect_attention_mask = attention_mask["rect_attention_mask"]
|
153 |
+
for idx in range(batch_size):
|
154 |
+
mask = prompt_attention_mask[idx][None, :] # two components: attention mask and prompt mask
|
155 |
+
valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
|
156 |
+
|
157 |
+
valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
|
158 |
+
valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
|
159 |
+
valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
|
160 |
+
|
161 |
+
valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
|
162 |
+
valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
|
163 |
+
valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
|
164 |
+
|
165 |
+
attn_output = F.scaled_dot_product_attention(
|
166 |
+
valid_query,
|
167 |
+
valid_key,
|
168 |
+
valid_value,
|
169 |
+
dropout_p=0.0,
|
170 |
+
attn_mask=rect_attention_mask[idx],
|
171 |
+
is_causal=False
|
172 |
+
)
|
173 |
+
valid_sequence_length = attn_output.size(2)
|
174 |
+
attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
|
175 |
+
attn_outputs.append(attn_output)
|
176 |
+
|
177 |
+
hidden_states = torch.cat(attn_outputs, dim=0)
|
178 |
+
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
179 |
+
|
180 |
+
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
|
181 |
+
(sequence_length, encoder_sequence_length), dim=1
|
182 |
+
)
|
183 |
+
|
184 |
+
# linear proj
|
185 |
+
original_hidden_states = attn.to_out[0](hidden_states)
|
186 |
+
hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device)
|
187 |
+
original_hidden_states[:, -sequence_length // 2:, :] += hidden_states_delta[:, -sequence_length // 2:, :] * scaling
|
188 |
+
# dropout
|
189 |
+
hidden_states = attn.to_out[1](original_hidden_states)
|
190 |
+
|
191 |
+
if hasattr(attn, "to_add_out"):
|
192 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
193 |
+
|
194 |
+
return hidden_states, encoder_hidden_states
|
195 |
+
|
196 |
+
def prepare_for_rgba_inference(
|
197 |
+
model, device: torch.device, dtype: torch.dtype,
|
198 |
+
lora_rank: int = 128, lora_alpha: float = 1.0
|
199 |
+
):
|
200 |
+
|
201 |
+
def custom_forward(self):
|
202 |
+
def forward(
|
203 |
+
hidden_states: torch.Tensor,
|
204 |
+
encoder_hidden_states: torch.Tensor,
|
205 |
+
timestep: torch.LongTensor,
|
206 |
+
encoder_attention_mask: torch.Tensor,
|
207 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
208 |
+
return_dict: bool = True,
|
209 |
+
) -> torch.Tensor:
|
210 |
+
if attention_kwargs is not None:
|
211 |
+
attention_kwargs = attention_kwargs.copy()
|
212 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
213 |
+
else:
|
214 |
+
lora_scale = 1.0
|
215 |
+
|
216 |
+
if USE_PEFT_BACKEND:
|
217 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
218 |
+
scale_lora_layers(self, lora_scale)
|
219 |
+
else:
|
220 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
221 |
+
logger.warning(
|
222 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
223 |
+
)
|
224 |
+
|
225 |
+
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
226 |
+
p = self.config.patch_size
|
227 |
+
|
228 |
+
post_patch_height = height // p
|
229 |
+
post_patch_width = width // p
|
230 |
+
|
231 |
+
temb, encoder_hidden_states = self.time_embed(
|
232 |
+
timestep,
|
233 |
+
encoder_hidden_states,
|
234 |
+
encoder_attention_mask["prompt_attention_mask"],
|
235 |
+
hidden_dtype=hidden_states.dtype,
|
236 |
+
)
|
237 |
+
|
238 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
239 |
+
hidden_states = self.patch_embed(hidden_states)
|
240 |
+
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
|
241 |
+
|
242 |
+
image_rotary_emb = self.rope(
|
243 |
+
self.pos_frequencies,
|
244 |
+
num_frames // 2, # Identitical PE for RGB and Alpha
|
245 |
+
post_patch_height,
|
246 |
+
post_patch_width,
|
247 |
+
device=hidden_states.device,
|
248 |
+
dtype=torch.float32,
|
249 |
+
)
|
250 |
+
|
251 |
+
for i, block in enumerate(self.transformer_blocks):
|
252 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
253 |
+
|
254 |
+
def create_custom_forward(module):
|
255 |
+
def custom_forward(*inputs):
|
256 |
+
return module(*inputs)
|
257 |
+
|
258 |
+
return custom_forward
|
259 |
+
|
260 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
261 |
+
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
262 |
+
create_custom_forward(block),
|
263 |
+
hidden_states,
|
264 |
+
encoder_hidden_states,
|
265 |
+
temb,
|
266 |
+
encoder_attention_mask,
|
267 |
+
image_rotary_emb,
|
268 |
+
**ckpt_kwargs,
|
269 |
+
)
|
270 |
+
else:
|
271 |
+
hidden_states, encoder_hidden_states = block(
|
272 |
+
hidden_states=hidden_states,
|
273 |
+
encoder_hidden_states=encoder_hidden_states,
|
274 |
+
temb=temb,
|
275 |
+
encoder_attention_mask=encoder_attention_mask,
|
276 |
+
image_rotary_emb=image_rotary_emb,
|
277 |
+
)
|
278 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
279 |
+
hidden_states = self.proj_out(hidden_states)
|
280 |
+
|
281 |
+
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
|
282 |
+
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
|
283 |
+
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
|
284 |
+
|
285 |
+
if USE_PEFT_BACKEND:
|
286 |
+
# remove `lora_scale` from each PEFT layer
|
287 |
+
unscale_lora_layers(self, lora_scale)
|
288 |
+
|
289 |
+
if not return_dict:
|
290 |
+
return (output,)
|
291 |
+
return Transformer2DModelOutput(sample=output)
|
292 |
+
return forward
|
293 |
+
|
294 |
+
for _, block in enumerate(model.transformer_blocks):
|
295 |
+
attn_processor = RGBALoRAMochiAttnProcessor(
|
296 |
+
device=device,
|
297 |
+
dtype=dtype,
|
298 |
+
lora_rank=lora_rank,
|
299 |
+
lora_alpha=lora_alpha
|
300 |
+
)
|
301 |
+
# block.attn1.set_processor(attn_processor)
|
302 |
+
block.attn1.processor = attn_processor
|
303 |
+
|
304 |
+
model.forward = custom_forward(model)
|
305 |
+
|
306 |
+
def get_processor_state_dict(model):
|
307 |
+
"""Save trainable parameters of processors to a checkpoint."""
|
308 |
+
processor_state_dict = {}
|
309 |
+
|
310 |
+
for index, block in enumerate(model.transformer_blocks):
|
311 |
+
if hasattr(block.attn1, "processor"):
|
312 |
+
processor = block.attn1.processor
|
313 |
+
for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
|
314 |
+
if hasattr(processor, attr_name):
|
315 |
+
lora_layer = getattr(processor, attr_name)
|
316 |
+
for param_name, param in lora_layer.named_parameters():
|
317 |
+
key = f"block_{index}.{attr_name}.{param_name}"
|
318 |
+
processor_state_dict[key] = param.data.clone()
|
319 |
+
|
320 |
+
# torch.save({"processor_state_dict": processor_state_dict}, checkpoint_path)
|
321 |
+
# print(f"Processor state_dict saved to {checkpoint_path}")
|
322 |
+
return processor_state_dict
|
323 |
+
|
324 |
+
def load_processor_state_dict(model, processor_state_dict):
|
325 |
+
"""Load trainable parameters of processors from a checkpoint."""
|
326 |
+
for index, block in enumerate(model.transformer_blocks):
|
327 |
+
if hasattr(block.attn1, "processor"):
|
328 |
+
processor = block.attn1.processor
|
329 |
+
for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
|
330 |
+
if hasattr(processor, attr_name):
|
331 |
+
lora_layer = getattr(processor, attr_name)
|
332 |
+
for param_name, param in lora_layer.named_parameters():
|
333 |
+
key = f"block_{index}.{attr_name}.{param_name}"
|
334 |
+
if key in processor_state_dict:
|
335 |
+
param.data.copy_(processor_state_dict[key])
|
336 |
+
else:
|
337 |
+
raise KeyError(f"Missing key {key} in checkpoint.")
|
338 |
+
|
339 |
+
# Prepare training parameters
|
340 |
+
def get_processor_params(processor):
|
341 |
+
params = []
|
342 |
+
for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
|
343 |
+
if hasattr(processor, attr_name):
|
344 |
+
lora_layer = getattr(processor, attr_name)
|
345 |
+
params.extend(p for p in lora_layer.parameters() if p.requires_grad)
|
346 |
+
return params
|
347 |
+
|
348 |
+
def get_all_processor_params(transformer):
|
349 |
+
all_params = []
|
350 |
+
for block in transformer.transformer_blocks:
|
351 |
+
if hasattr(block.attn1, "processor"):
|
352 |
+
processor = block.attn1.processor
|
353 |
+
all_params.extend(get_processor_params(processor))
|
354 |
+
return all_params
|
Mochi/train.py
ADDED
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import gc
|
17 |
+
import random
|
18 |
+
from glob import glob
|
19 |
+
import math
|
20 |
+
import os
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import numpy as np
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Any, Dict, Tuple, List
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import wandb
|
28 |
+
from pipeline_mochi_rgba import *
|
29 |
+
from diffusers import FlowMatchEulerDiscreteScheduler, MochiTransformer3DModel
|
30 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
31 |
+
from diffusers.training_utils import cast_training_params
|
32 |
+
from diffusers.utils import export_to_video
|
33 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
34 |
+
from huggingface_hub import create_repo, upload_folder
|
35 |
+
from torch.utils.data import DataLoader
|
36 |
+
from tqdm.auto import tqdm
|
37 |
+
|
38 |
+
|
39 |
+
from args import get_args # isort:skip
|
40 |
+
from dataset_simple import LatentEmbedDataset
|
41 |
+
|
42 |
+
from utils import print_memory, reset_memory # isort:skip
|
43 |
+
from rgba_utils import *
|
44 |
+
|
45 |
+
|
46 |
+
# Taken from
|
47 |
+
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139
|
48 |
+
def get_cosine_annealing_lr_scheduler(
|
49 |
+
optimizer: torch.optim.Optimizer,
|
50 |
+
warmup_steps: int,
|
51 |
+
total_steps: int,
|
52 |
+
):
|
53 |
+
def lr_lambda(step):
|
54 |
+
if step < warmup_steps:
|
55 |
+
return float(step) / float(max(1, warmup_steps))
|
56 |
+
else:
|
57 |
+
return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
|
58 |
+
|
59 |
+
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
60 |
+
|
61 |
+
|
62 |
+
def save_model_card(
|
63 |
+
repo_id: str,
|
64 |
+
videos=None,
|
65 |
+
base_model: str = None,
|
66 |
+
validation_prompt=None,
|
67 |
+
repo_folder=None,
|
68 |
+
fps=30,
|
69 |
+
):
|
70 |
+
widget_dict = []
|
71 |
+
if videos is not None and len(videos) > 0:
|
72 |
+
for i, video in enumerate(videos):
|
73 |
+
export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
|
74 |
+
widget_dict.append(
|
75 |
+
{
|
76 |
+
"text": validation_prompt if validation_prompt else " ",
|
77 |
+
"output": {"url": f"final_video_{i}.mp4"},
|
78 |
+
}
|
79 |
+
)
|
80 |
+
|
81 |
+
model_description = f"""
|
82 |
+
# Mochi-1 Preview LoRA Finetune
|
83 |
+
|
84 |
+
<Gallery />
|
85 |
+
|
86 |
+
## Model description
|
87 |
+
|
88 |
+
This is a lora finetune of the Mochi-1 preview model `{base_model}`.
|
89 |
+
|
90 |
+
The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
|
91 |
+
|
92 |
+
## Download model
|
93 |
+
|
94 |
+
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
|
95 |
+
|
96 |
+
## Usage
|
97 |
+
|
98 |
+
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
|
99 |
+
|
100 |
+
```py
|
101 |
+
from diffusers import MochiPipeline
|
102 |
+
from diffusers.utils import export_to_video
|
103 |
+
import torch
|
104 |
+
|
105 |
+
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
|
106 |
+
pipe.load_lora_weights("CHANGE_ME")
|
107 |
+
pipe.enable_model_cpu_offload()
|
108 |
+
|
109 |
+
with torch.autocast("cuda", torch.bfloat16):
|
110 |
+
video = pipe(
|
111 |
+
prompt="CHANGE_ME",
|
112 |
+
guidance_scale=6.0,
|
113 |
+
num_inference_steps=64,
|
114 |
+
height=480,
|
115 |
+
width=848,
|
116 |
+
max_sequence_length=256,
|
117 |
+
output_type="np"
|
118 |
+
).frames[0]
|
119 |
+
export_to_video(video)
|
120 |
+
```
|
121 |
+
|
122 |
+
For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
|
123 |
+
|
124 |
+
"""
|
125 |
+
model_card = load_or_create_model_card(
|
126 |
+
repo_id_or_path=repo_id,
|
127 |
+
from_training=True,
|
128 |
+
license="apache-2.0",
|
129 |
+
base_model=base_model,
|
130 |
+
prompt=validation_prompt,
|
131 |
+
model_description=model_description,
|
132 |
+
widget=widget_dict,
|
133 |
+
)
|
134 |
+
tags = [
|
135 |
+
"text-to-video",
|
136 |
+
"diffusers-training",
|
137 |
+
"diffusers",
|
138 |
+
"lora",
|
139 |
+
"mochi-1-preview",
|
140 |
+
"mochi-1-preview-diffusers",
|
141 |
+
"template:sd-lora",
|
142 |
+
]
|
143 |
+
|
144 |
+
model_card = populate_model_card(model_card, tags=tags)
|
145 |
+
model_card.save(os.path.join(repo_folder, "README.md"))
|
146 |
+
|
147 |
+
|
148 |
+
def log_validation(
|
149 |
+
pipe: MochiPipeline,
|
150 |
+
args: Dict[str, Any],
|
151 |
+
pipeline_args: Dict[str, Any],
|
152 |
+
step: int,
|
153 |
+
wandb_run: str = None,
|
154 |
+
is_final_validation: bool = False,
|
155 |
+
):
|
156 |
+
print(
|
157 |
+
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
158 |
+
)
|
159 |
+
phase_name = "test" if is_final_validation else "validation"
|
160 |
+
|
161 |
+
if not args.enable_model_cpu_offload:
|
162 |
+
pipe = pipe.to("cuda")
|
163 |
+
|
164 |
+
# run inference
|
165 |
+
generator = torch.manual_seed(args.seed) if args.seed else None
|
166 |
+
|
167 |
+
videos = []
|
168 |
+
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
|
169 |
+
for _ in range(args.num_validation_videos):
|
170 |
+
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
171 |
+
videos.append(video)
|
172 |
+
|
173 |
+
video_filenames = []
|
174 |
+
for i, video in enumerate(videos):
|
175 |
+
prompt = (
|
176 |
+
pipeline_args["prompt"][:25]
|
177 |
+
.replace(" ", "_")
|
178 |
+
.replace(" ", "_")
|
179 |
+
.replace("'", "_")
|
180 |
+
.replace('"', "_")
|
181 |
+
.replace("/", "_")
|
182 |
+
)
|
183 |
+
filename = os.path.join(args.output_dir, f"{phase_name}_{str(step)}_video_{i}_{prompt}.mp4")
|
184 |
+
export_to_video(video, filename, fps=30)
|
185 |
+
video_filenames.append(filename)
|
186 |
+
|
187 |
+
if wandb_run:
|
188 |
+
wandb.log(
|
189 |
+
{
|
190 |
+
phase_name: [
|
191 |
+
wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30)
|
192 |
+
for i, filename in enumerate(video_filenames)
|
193 |
+
]
|
194 |
+
}
|
195 |
+
)
|
196 |
+
|
197 |
+
return videos
|
198 |
+
|
199 |
+
|
200 |
+
# Adapted from the original code:
|
201 |
+
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578
|
202 |
+
def cast_dit(model, dtype):
|
203 |
+
for name, module in model.named_modules():
|
204 |
+
if isinstance(module, torch.nn.Linear):
|
205 |
+
assert any(
|
206 |
+
n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"]
|
207 |
+
), f"Unexpected linear layer: {name}"
|
208 |
+
module.to(dtype=dtype)
|
209 |
+
elif isinstance(module, torch.nn.Conv2d):
|
210 |
+
module.to(dtype=dtype)
|
211 |
+
return model
|
212 |
+
|
213 |
+
|
214 |
+
def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path):
|
215 |
+
# lora_state_dict = get_peft_model_state_dict(model)
|
216 |
+
processor_state_dict = get_processor_state_dict(model)
|
217 |
+
torch.save(
|
218 |
+
{
|
219 |
+
"state_dict": processor_state_dict,
|
220 |
+
"optimizer": optimizer.state_dict(),
|
221 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
222 |
+
"global_step": global_step,
|
223 |
+
},
|
224 |
+
checkpoint_path,
|
225 |
+
)
|
226 |
+
|
227 |
+
|
228 |
+
class CollateFunction:
|
229 |
+
def __init__(self, caption_dropout: float = None) -> None:
|
230 |
+
self.caption_dropout = caption_dropout
|
231 |
+
|
232 |
+
def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
233 |
+
ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0)
|
234 |
+
z = DiagonalGaussianDistribution(ldists).sample()
|
235 |
+
assert torch.isfinite(z).all()
|
236 |
+
|
237 |
+
# Sample noise which we will add to the samples.
|
238 |
+
eps = torch.randn_like(z)
|
239 |
+
sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)
|
240 |
+
|
241 |
+
prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0)
|
242 |
+
prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0)
|
243 |
+
if self.caption_dropout and random.random() < self.caption_dropout:
|
244 |
+
prompt_embeds.zero_()
|
245 |
+
prompt_attention_mask = prompt_attention_mask.long()
|
246 |
+
prompt_attention_mask.zero_()
|
247 |
+
prompt_attention_mask = prompt_attention_mask.bool()
|
248 |
+
|
249 |
+
return dict(
|
250 |
+
z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
def main(args):
|
255 |
+
if not torch.cuda.is_available():
|
256 |
+
raise ValueError("Not supported without CUDA.")
|
257 |
+
|
258 |
+
if args.report_to == "wandb" and args.hub_token is not None:
|
259 |
+
raise ValueError(
|
260 |
+
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
261 |
+
" Please use `huggingface-cli login` to authenticate with the Hub."
|
262 |
+
)
|
263 |
+
|
264 |
+
# Handle the repository creation
|
265 |
+
if args.output_dir is not None:
|
266 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
267 |
+
|
268 |
+
# Prepare models and scheduler
|
269 |
+
transformer = MochiTransformer3DModel.from_pretrained(
|
270 |
+
args.pretrained_model_name_or_path,
|
271 |
+
subfolder="transformer",
|
272 |
+
revision=args.revision,
|
273 |
+
variant=args.variant,
|
274 |
+
)
|
275 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
276 |
+
args.pretrained_model_name_or_path, subfolder="scheduler"
|
277 |
+
)
|
278 |
+
|
279 |
+
transformer.requires_grad_(False)
|
280 |
+
transformer.to("cuda")
|
281 |
+
if args.gradient_checkpointing:
|
282 |
+
transformer.enable_gradient_checkpointing()
|
283 |
+
if args.cast_dit:
|
284 |
+
transformer = cast_dit(transformer, torch.bfloat16)
|
285 |
+
if args.compile_dit:
|
286 |
+
transformer.compile()
|
287 |
+
|
288 |
+
prepare_for_rgba_inference(
|
289 |
+
model=transformer,
|
290 |
+
device=torch.device("cuda"),
|
291 |
+
dtype=torch.bfloat16,
|
292 |
+
# seq_length=seq_length,
|
293 |
+
)
|
294 |
+
processor_params = get_all_processor_params(transformer)
|
295 |
+
|
296 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
297 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
298 |
+
if args.allow_tf32 and torch.cuda.is_available():
|
299 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
300 |
+
|
301 |
+
if args.scale_lr:
|
302 |
+
args.learning_rate = args.learning_rate * args.train_batch_size
|
303 |
+
# only upcast trainable parameters (LoRA) into fp32
|
304 |
+
|
305 |
+
if not isinstance(processor_params, list):
|
306 |
+
processor_params = [processor_params]
|
307 |
+
for m in processor_params:
|
308 |
+
for param in m:
|
309 |
+
# only upcast trainable parameters into fp32
|
310 |
+
if param.requires_grad:
|
311 |
+
param.data = param.to(torch.float32)
|
312 |
+
|
313 |
+
# Prepare optimizer
|
314 |
+
transformer_lora_parameters = processor_params # list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
315 |
+
num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters)
|
316 |
+
optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
|
317 |
+
|
318 |
+
# Dataset and DataLoader
|
319 |
+
train_vids = list(sorted(glob(f"{args.data_root}/*.mp4")))
|
320 |
+
train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
|
321 |
+
print(f"Found {len(train_vids)} training videos in {args.data_root}")
|
322 |
+
assert len(train_vids) > 0, f"No training data found in {args.data_root}"
|
323 |
+
|
324 |
+
collate_fn = CollateFunction(caption_dropout=args.caption_dropout)
|
325 |
+
train_dataset = LatentEmbedDataset(train_vids, repeat=1)
|
326 |
+
train_dataloader = DataLoader(
|
327 |
+
train_dataset,
|
328 |
+
collate_fn=collate_fn,
|
329 |
+
batch_size=args.train_batch_size,
|
330 |
+
num_workers=args.dataloader_num_workers,
|
331 |
+
pin_memory=args.pin_memory,
|
332 |
+
)
|
333 |
+
|
334 |
+
# LR scheduler and math around the number of training steps.
|
335 |
+
overrode_max_train_steps = False
|
336 |
+
num_update_steps_per_epoch = len(train_dataloader)
|
337 |
+
if args.max_train_steps is None:
|
338 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
339 |
+
overrode_max_train_steps = True
|
340 |
+
|
341 |
+
lr_scheduler = get_cosine_annealing_lr_scheduler(
|
342 |
+
optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps
|
343 |
+
)
|
344 |
+
|
345 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
346 |
+
num_update_steps_per_epoch = len(train_dataloader)
|
347 |
+
if overrode_max_train_steps:
|
348 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
349 |
+
# Afterwards we recalculate our number of training epochs
|
350 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
351 |
+
|
352 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
353 |
+
# The trackers initializes automatically on the main process.
|
354 |
+
wandb_run = None
|
355 |
+
if args.report_to == "wandb":
|
356 |
+
tracker_name = args.tracker_name or "mochi-1-rgba-lora"
|
357 |
+
wandb_run = wandb.init(project=tracker_name, config=vars(args))
|
358 |
+
|
359 |
+
# Resume from checkpoint if specified
|
360 |
+
if args.resume_from_checkpoint:
|
361 |
+
checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
|
362 |
+
if "global_step" in checkpoint:
|
363 |
+
global_step = checkpoint["global_step"]
|
364 |
+
if "optimizer" in checkpoint:
|
365 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
366 |
+
if "lr_scheduler" in checkpoint:
|
367 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
368 |
+
|
369 |
+
# set_peft_model_state_dict(transformer, checkpoint["state_dict"]) # Luozhou: modify this line
|
370 |
+
|
371 |
+
processor_state_dict = checkpoint["state_dict"]
|
372 |
+
load_processor_state_dict(transformer, processor_state_dict)
|
373 |
+
|
374 |
+
print(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
|
375 |
+
print(f"Resuming from global step: {global_step}")
|
376 |
+
else:
|
377 |
+
global_step = 0
|
378 |
+
|
379 |
+
print("===== Memory before training =====")
|
380 |
+
reset_memory("cuda")
|
381 |
+
print_memory("cuda")
|
382 |
+
|
383 |
+
# Train!
|
384 |
+
total_batch_size = args.train_batch_size
|
385 |
+
print("***** Running training *****")
|
386 |
+
print(f" Num trainable parameters = {num_trainable_parameters}")
|
387 |
+
print(f" Num examples = {len(train_dataset)}")
|
388 |
+
print(f" Num batches each epoch = {len(train_dataloader)}")
|
389 |
+
print(f" Num epochs = {args.num_train_epochs}")
|
390 |
+
print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
391 |
+
print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
392 |
+
print(f" Total optimization steps = {args.max_train_steps}")
|
393 |
+
|
394 |
+
first_epoch = 0
|
395 |
+
progress_bar = tqdm(
|
396 |
+
range(0, args.max_train_steps),
|
397 |
+
initial=global_step,
|
398 |
+
desc="Steps",
|
399 |
+
)
|
400 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
401 |
+
transformer.train()
|
402 |
+
|
403 |
+
for step, batch in enumerate(train_dataloader):
|
404 |
+
with torch.no_grad():
|
405 |
+
z = batch["z"].to("cuda")
|
406 |
+
eps = batch["eps"].to("cuda")
|
407 |
+
sigma = batch["sigma"].to("cuda")
|
408 |
+
prompt_embeds = batch["prompt_embeds"].to("cuda")
|
409 |
+
prompt_attention_mask = batch["prompt_attention_mask"].to("cuda")
|
410 |
+
|
411 |
+
all_attention_mask = prepare_attention_mask(
|
412 |
+
prompt_attention_mask=prompt_attention_mask,
|
413 |
+
latents=z
|
414 |
+
)
|
415 |
+
|
416 |
+
sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1]
|
417 |
+
# Add noise according to flow matching.
|
418 |
+
# zt = (1 - texp) * x + texp * z1
|
419 |
+
z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
|
420 |
+
ut = z - eps
|
421 |
+
|
422 |
+
# (1 - sigma) because of
|
423 |
+
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
|
424 |
+
# Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation.
|
425 |
+
timesteps = (1 - sigma) * scheduler.config.num_train_timesteps
|
426 |
+
|
427 |
+
with torch.autocast("cuda", torch.bfloat16):
|
428 |
+
model_pred = transformer(
|
429 |
+
hidden_states=z_sigma,
|
430 |
+
encoder_hidden_states=prompt_embeds,
|
431 |
+
encoder_attention_mask=all_attention_mask,
|
432 |
+
timestep=timesteps,
|
433 |
+
return_dict=False,
|
434 |
+
)[0]
|
435 |
+
assert model_pred.shape == z.shape
|
436 |
+
loss = F.mse_loss(model_pred.float(), ut.float())
|
437 |
+
loss.backward()
|
438 |
+
|
439 |
+
optimizer.step()
|
440 |
+
optimizer.zero_grad()
|
441 |
+
lr_scheduler.step()
|
442 |
+
|
443 |
+
progress_bar.update(1)
|
444 |
+
global_step += 1
|
445 |
+
|
446 |
+
last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
|
447 |
+
logs = {"loss": loss.detach().item(), "lr": last_lr}
|
448 |
+
progress_bar.set_postfix(**logs)
|
449 |
+
if wandb_run:
|
450 |
+
wandb_run.log(logs, step=global_step)
|
451 |
+
|
452 |
+
if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0:
|
453 |
+
print(f"Saving checkpoint at step {global_step}")
|
454 |
+
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
|
455 |
+
save_checkpoint(
|
456 |
+
transformer,
|
457 |
+
optimizer,
|
458 |
+
lr_scheduler,
|
459 |
+
global_step,
|
460 |
+
checkpoint_path,
|
461 |
+
)
|
462 |
+
|
463 |
+
# if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
|
464 |
+
print("===== Memory before validation =====")
|
465 |
+
print_memory("cuda")
|
466 |
+
|
467 |
+
transformer.eval()
|
468 |
+
pipe = MochiPipeline.from_pretrained(
|
469 |
+
args.pretrained_model_name_or_path,
|
470 |
+
transformer=transformer,
|
471 |
+
scheduler=scheduler,
|
472 |
+
revision=args.revision,
|
473 |
+
variant=args.variant,
|
474 |
+
)
|
475 |
+
|
476 |
+
if args.enable_slicing:
|
477 |
+
pipe.vae.enable_slicing()
|
478 |
+
if args.enable_tiling:
|
479 |
+
pipe.vae.enable_tiling()
|
480 |
+
if args.enable_model_cpu_offload:
|
481 |
+
pipe.enable_model_cpu_offload()
|
482 |
+
|
483 |
+
# validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
484 |
+
validation_prompts = [
|
485 |
+
"A boy in a white shirt and shorts is seen bouncing a ball, isolated background",
|
486 |
+
]
|
487 |
+
for validation_prompt in validation_prompts:
|
488 |
+
pipeline_args = {
|
489 |
+
"prompt": validation_prompt,
|
490 |
+
"guidance_scale": 6.0,
|
491 |
+
"num_frames": 37,
|
492 |
+
"num_inference_steps": 64,
|
493 |
+
"height": args.height,
|
494 |
+
"width": args.width,
|
495 |
+
"max_sequence_length": 256,
|
496 |
+
}
|
497 |
+
log_validation(
|
498 |
+
pipe=pipe,
|
499 |
+
args=args,
|
500 |
+
pipeline_args=pipeline_args,
|
501 |
+
step=global_step,
|
502 |
+
wandb_run=wandb_run,
|
503 |
+
)
|
504 |
+
|
505 |
+
print("===== Memory after validation =====")
|
506 |
+
print_memory("cuda")
|
507 |
+
reset_memory("cuda")
|
508 |
+
|
509 |
+
del pipe.text_encoder
|
510 |
+
del pipe.vae
|
511 |
+
del pipe
|
512 |
+
gc.collect()
|
513 |
+
torch.cuda.empty_cache()
|
514 |
+
|
515 |
+
transformer.train()
|
516 |
+
|
517 |
+
if global_step >= args.max_train_steps:
|
518 |
+
break
|
519 |
+
|
520 |
+
if global_step >= args.max_train_steps:
|
521 |
+
break
|
522 |
+
|
523 |
+
transformer.eval()
|
524 |
+
|
525 |
+
# saving lora weights
|
526 |
+
# transformer_lora_layers = get_peft_model_state_dict(transformer)
|
527 |
+
# MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers)
|
528 |
+
|
529 |
+
# Cleanup trained models to save memory
|
530 |
+
del transformer
|
531 |
+
|
532 |
+
gc.collect()
|
533 |
+
torch.cuda.empty_cache()
|
534 |
+
|
535 |
+
# Final test inference
|
536 |
+
# validation_outputs = []
|
537 |
+
# if args.validation_prompt and args.num_validation_videos > 0:
|
538 |
+
# print("===== Memory before testing =====")
|
539 |
+
# print_memory("cuda")
|
540 |
+
# reset_memory("cuda")
|
541 |
+
|
542 |
+
# pipe = MochiPipeline.from_pretrained(
|
543 |
+
# args.pretrained_model_name_or_path,
|
544 |
+
# revision=args.revision,
|
545 |
+
# variant=args.variant,
|
546 |
+
# )
|
547 |
+
|
548 |
+
|
549 |
+
|
550 |
+
# if args.enable_slicing:
|
551 |
+
# pipe.vae.enable_slicing()
|
552 |
+
# if args.enable_tiling:
|
553 |
+
# pipe.vae.enable_tiling()
|
554 |
+
# if args.enable_model_cpu_offload:
|
555 |
+
# pipe.enable_model_cpu_offload()
|
556 |
+
|
557 |
+
# # Load LoRA weights
|
558 |
+
# # lora_scaling = args.lora_alpha / args.rank
|
559 |
+
# # pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora")
|
560 |
+
# # pipe.set_adapters(["mochi-lora"], [lora_scaling])
|
561 |
+
|
562 |
+
# # Run inference
|
563 |
+
# validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
|
564 |
+
# for validation_prompt in validation_prompts:
|
565 |
+
# pipeline_args = {
|
566 |
+
# "prompt": validation_prompt,
|
567 |
+
# "guidance_scale": 6.0,
|
568 |
+
# "num_inference_steps": 64,
|
569 |
+
# "height": args.height,
|
570 |
+
# "width": args.width,
|
571 |
+
# "max_sequence_length": 256,
|
572 |
+
# }
|
573 |
+
|
574 |
+
# video = log_validation(
|
575 |
+
# pipe=pipe,
|
576 |
+
# args=args,
|
577 |
+
# pipeline_args=pipeline_args,
|
578 |
+
# epoch=epoch,
|
579 |
+
# wandb_run=wandb_run,
|
580 |
+
# is_final_validation=True,
|
581 |
+
# )
|
582 |
+
# validation_outputs.extend(video)
|
583 |
+
|
584 |
+
# print("===== Memory after testing =====")
|
585 |
+
# print_memory("cuda")
|
586 |
+
# reset_memory("cuda")
|
587 |
+
# torch.cuda.synchronize("cuda")
|
588 |
+
|
589 |
+
|
590 |
+
|
591 |
+
if __name__ == "__main__":
|
592 |
+
args = get_args()
|
593 |
+
main(args)
|
Mochi/train.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
export NCCL_P2P_DISABLE=1
|
3 |
+
export TORCH_NCCL_ENABLE_MONITORING=0
|
4 |
+
|
5 |
+
GPU_IDS="3"
|
6 |
+
|
7 |
+
DATA_ROOT="/hpc2hdd/home/lwang592/projects/finetrainers/training/data/video-matte-240k-rgb-prepared-f37"
|
8 |
+
MODEL="genmo/mochi-1-preview"
|
9 |
+
OUTPUT_PATH="mochi-rgba-lora-f37"
|
10 |
+
|
11 |
+
cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python train.py \
|
12 |
+
--pretrained_model_name_or_path $MODEL \
|
13 |
+
--cast_dit \
|
14 |
+
--data_root $DATA_ROOT \
|
15 |
+
--seed 42 \
|
16 |
+
--output_dir $OUTPUT_PATH \
|
17 |
+
--train_batch_size 2 \
|
18 |
+
--dataloader_num_workers 4 \
|
19 |
+
--pin_memory \
|
20 |
+
--caption_dropout 0.0 \
|
21 |
+
--max_train_steps 5000 \
|
22 |
+
--gradient_checkpointing \
|
23 |
+
--enable_slicing \
|
24 |
+
--enable_tiling \
|
25 |
+
--enable_model_cpu_offload \
|
26 |
+
--optimizer adamw \
|
27 |
+
--allow_tf32"
|
28 |
+
|
29 |
+
echo "Running command: $cmd"
|
30 |
+
eval $cmd
|
31 |
+
echo -ne "-------------------- Finished executing script --------------------\n\n"
|
Mochi/trim_and_crop_videos.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adapted from:
|
3 |
+
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
from pathlib import Path
|
7 |
+
import shutil
|
8 |
+
|
9 |
+
import click
|
10 |
+
from moviepy.editor import VideoFileClip
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
@click.command()
|
15 |
+
@click.argument("folder", type=click.Path(exists=True, dir_okay=True))
|
16 |
+
@click.argument("output_folder", type=click.Path(dir_okay=True))
|
17 |
+
@click.option("--num_frames", "-f", type=float, default=30, help="Number of frames")
|
18 |
+
@click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution")
|
19 |
+
@click.option("--force_upsample", is_flag=True, help="Force upsample.")
|
20 |
+
def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample):
|
21 |
+
"""Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution"""
|
22 |
+
input_path = Path(folder)
|
23 |
+
output_path = Path(output_folder)
|
24 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
25 |
+
|
26 |
+
# Parse target resolution
|
27 |
+
target_height, target_width = map(int, resolution.split("x"))
|
28 |
+
|
29 |
+
# Calculate duration
|
30 |
+
duration = (num_frames / 30) + 0.09
|
31 |
+
|
32 |
+
# Find all MP4 and MOV files
|
33 |
+
video_files = (
|
34 |
+
list(input_path.rglob("*.mp4"))
|
35 |
+
+ list(input_path.rglob("*.MOV"))
|
36 |
+
+ list(input_path.rglob("*.mov"))
|
37 |
+
+ list(input_path.rglob("*.MP4"))
|
38 |
+
)
|
39 |
+
|
40 |
+
for file_path in tqdm(video_files):
|
41 |
+
try:
|
42 |
+
relative_path = file_path.relative_to(input_path)
|
43 |
+
output_file = output_path / relative_path.with_suffix(".mp4")
|
44 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
45 |
+
|
46 |
+
click.echo(f"Processing: {file_path}")
|
47 |
+
video = VideoFileClip(str(file_path))
|
48 |
+
|
49 |
+
# Skip if video is too short
|
50 |
+
if video.duration < duration:
|
51 |
+
click.echo(f"Skipping {file_path} as it is too short")
|
52 |
+
continue
|
53 |
+
|
54 |
+
# Skip if target resolution is larger than input
|
55 |
+
if target_width > video.w or target_height > video.h:
|
56 |
+
if force_upsample:
|
57 |
+
click.echo(
|
58 |
+
f"{file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}. So, upsampling the video."
|
59 |
+
)
|
60 |
+
video = video.resize(width=target_width, height=target_height)
|
61 |
+
else:
|
62 |
+
click.echo(
|
63 |
+
f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}"
|
64 |
+
)
|
65 |
+
continue
|
66 |
+
|
67 |
+
# First truncate duration
|
68 |
+
truncated = video.subclip(0, duration)
|
69 |
+
|
70 |
+
# Calculate crop dimensions to maintain aspect ratio
|
71 |
+
target_ratio = target_width / target_height
|
72 |
+
current_ratio = truncated.w / truncated.h
|
73 |
+
|
74 |
+
if current_ratio > target_ratio:
|
75 |
+
# Video is wider than target ratio - crop width
|
76 |
+
new_width = int(truncated.h * target_ratio)
|
77 |
+
x1 = (truncated.w - new_width) // 2
|
78 |
+
final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height))
|
79 |
+
else:
|
80 |
+
# Video is taller than target ratio - crop height
|
81 |
+
new_height = int(truncated.w / target_ratio)
|
82 |
+
y1 = (truncated.h - new_height) // 2
|
83 |
+
final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height))
|
84 |
+
|
85 |
+
# Set output parameters for consistent MP4 encoding
|
86 |
+
output_params = {
|
87 |
+
"codec": "libx264",
|
88 |
+
"audio": False, # Disable audio
|
89 |
+
"preset": "medium", # Balance between speed and quality
|
90 |
+
"bitrate": "5000k", # Adjust as needed
|
91 |
+
}
|
92 |
+
|
93 |
+
# Set FPS to 30
|
94 |
+
final = final.set_fps(30)
|
95 |
+
|
96 |
+
# Check for a corresponding .txt file
|
97 |
+
txt_file_path = file_path.with_suffix(".txt")
|
98 |
+
if txt_file_path.exists():
|
99 |
+
output_txt_file = output_path / relative_path.with_suffix(".txt")
|
100 |
+
output_txt_file.parent.mkdir(parents=True, exist_ok=True)
|
101 |
+
shutil.copy(txt_file_path, output_txt_file)
|
102 |
+
click.echo(f"Copied {txt_file_path} to {output_txt_file}")
|
103 |
+
else:
|
104 |
+
# Print warning in bold yellow with a warning emoji
|
105 |
+
click.echo(
|
106 |
+
f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m"
|
107 |
+
)
|
108 |
+
output_txt_file = output_path / relative_path.with_suffix(".txt")
|
109 |
+
output_txt_file.parent.mkdir(parents=True, exist_ok=True)
|
110 |
+
output_txt_file.touch()
|
111 |
+
|
112 |
+
# Write the output file
|
113 |
+
final.write_videofile(str(output_file), **output_params)
|
114 |
+
|
115 |
+
# Clean up
|
116 |
+
video.close()
|
117 |
+
truncated.close()
|
118 |
+
final.close()
|
119 |
+
|
120 |
+
except Exception as e:
|
121 |
+
click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True)
|
122 |
+
raise
|
123 |
+
|
124 |
+
|
125 |
+
if __name__ == "__main__":
|
126 |
+
truncate_videos()
|
Mochi/utils.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import inspect
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from accelerate import Accelerator
|
7 |
+
from accelerate.logging import get_logger
|
8 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
9 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
10 |
+
|
11 |
+
|
12 |
+
logger = get_logger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
def get_optimizer(
|
16 |
+
params_to_optimize,
|
17 |
+
optimizer_name: str = "adam",
|
18 |
+
learning_rate: float = 1e-3,
|
19 |
+
beta1: float = 0.9,
|
20 |
+
beta2: float = 0.95,
|
21 |
+
beta3: float = 0.98,
|
22 |
+
epsilon: float = 1e-8,
|
23 |
+
weight_decay: float = 1e-4,
|
24 |
+
prodigy_decouple: bool = False,
|
25 |
+
prodigy_use_bias_correction: bool = False,
|
26 |
+
prodigy_safeguard_warmup: bool = False,
|
27 |
+
use_8bit: bool = False,
|
28 |
+
use_4bit: bool = False,
|
29 |
+
use_torchao: bool = False,
|
30 |
+
use_deepspeed: bool = False,
|
31 |
+
use_cpu_offload_optimizer: bool = False,
|
32 |
+
offload_gradients: bool = False,
|
33 |
+
) -> torch.optim.Optimizer:
|
34 |
+
optimizer_name = optimizer_name.lower()
|
35 |
+
|
36 |
+
# Use DeepSpeed optimzer
|
37 |
+
if use_deepspeed:
|
38 |
+
from accelerate.utils import DummyOptim
|
39 |
+
|
40 |
+
return DummyOptim(
|
41 |
+
params_to_optimize,
|
42 |
+
lr=learning_rate,
|
43 |
+
betas=(beta1, beta2),
|
44 |
+
eps=epsilon,
|
45 |
+
weight_decay=weight_decay,
|
46 |
+
)
|
47 |
+
|
48 |
+
if use_8bit and use_4bit:
|
49 |
+
raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
|
50 |
+
|
51 |
+
if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
|
52 |
+
try:
|
53 |
+
import torchao
|
54 |
+
|
55 |
+
torchao.__version__
|
56 |
+
except ImportError:
|
57 |
+
raise ImportError(
|
58 |
+
"To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
|
59 |
+
)
|
60 |
+
|
61 |
+
if not use_torchao and use_4bit:
|
62 |
+
raise ValueError("4-bit Optimizers are only supported with torchao.")
|
63 |
+
|
64 |
+
# Optimizer creation
|
65 |
+
supported_optimizers = ["adam", "adamw", "prodigy", "came"]
|
66 |
+
if optimizer_name not in supported_optimizers:
|
67 |
+
logger.warning(
|
68 |
+
f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
|
69 |
+
)
|
70 |
+
optimizer_name = "adamw"
|
71 |
+
|
72 |
+
if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
|
73 |
+
raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
|
74 |
+
|
75 |
+
if use_8bit:
|
76 |
+
try:
|
77 |
+
import bitsandbytes as bnb
|
78 |
+
except ImportError:
|
79 |
+
raise ImportError(
|
80 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
81 |
+
)
|
82 |
+
|
83 |
+
if optimizer_name == "adamw":
|
84 |
+
if use_torchao:
|
85 |
+
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
|
86 |
+
|
87 |
+
optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
|
88 |
+
else:
|
89 |
+
optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
|
90 |
+
|
91 |
+
init_kwargs = {
|
92 |
+
"betas": (beta1, beta2),
|
93 |
+
"eps": epsilon,
|
94 |
+
"weight_decay": weight_decay,
|
95 |
+
}
|
96 |
+
|
97 |
+
elif optimizer_name == "adam":
|
98 |
+
if use_torchao:
|
99 |
+
from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
|
100 |
+
|
101 |
+
optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
|
102 |
+
else:
|
103 |
+
optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
|
104 |
+
|
105 |
+
init_kwargs = {
|
106 |
+
"betas": (beta1, beta2),
|
107 |
+
"eps": epsilon,
|
108 |
+
"weight_decay": weight_decay,
|
109 |
+
}
|
110 |
+
|
111 |
+
elif optimizer_name == "prodigy":
|
112 |
+
try:
|
113 |
+
import prodigyopt
|
114 |
+
except ImportError:
|
115 |
+
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
|
116 |
+
|
117 |
+
optimizer_class = prodigyopt.Prodigy
|
118 |
+
|
119 |
+
if learning_rate <= 0.1:
|
120 |
+
logger.warning(
|
121 |
+
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
122 |
+
)
|
123 |
+
|
124 |
+
init_kwargs = {
|
125 |
+
"lr": learning_rate,
|
126 |
+
"betas": (beta1, beta2),
|
127 |
+
"beta3": beta3,
|
128 |
+
"eps": epsilon,
|
129 |
+
"weight_decay": weight_decay,
|
130 |
+
"decouple": prodigy_decouple,
|
131 |
+
"use_bias_correction": prodigy_use_bias_correction,
|
132 |
+
"safeguard_warmup": prodigy_safeguard_warmup,
|
133 |
+
}
|
134 |
+
|
135 |
+
elif optimizer_name == "came":
|
136 |
+
try:
|
137 |
+
import came_pytorch
|
138 |
+
except ImportError:
|
139 |
+
raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
|
140 |
+
|
141 |
+
optimizer_class = came_pytorch.CAME
|
142 |
+
|
143 |
+
init_kwargs = {
|
144 |
+
"lr": learning_rate,
|
145 |
+
"eps": (1e-30, 1e-16),
|
146 |
+
"betas": (beta1, beta2, beta3),
|
147 |
+
"weight_decay": weight_decay,
|
148 |
+
}
|
149 |
+
|
150 |
+
if use_cpu_offload_optimizer:
|
151 |
+
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
|
152 |
+
|
153 |
+
if "fused" in inspect.signature(optimizer_class.__init__).parameters:
|
154 |
+
init_kwargs.update({"fused": True})
|
155 |
+
|
156 |
+
optimizer = CPUOffloadOptimizer(
|
157 |
+
params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
optimizer = optimizer_class(params_to_optimize, **init_kwargs)
|
161 |
+
|
162 |
+
return optimizer
|
163 |
+
|
164 |
+
|
165 |
+
def get_gradient_norm(parameters):
|
166 |
+
norm = 0
|
167 |
+
for param in parameters:
|
168 |
+
if param.grad is None:
|
169 |
+
continue
|
170 |
+
local_norm = param.grad.detach().data.norm(2)
|
171 |
+
norm += local_norm.item() ** 2
|
172 |
+
norm = norm**0.5
|
173 |
+
return norm
|
174 |
+
|
175 |
+
|
176 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
177 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
178 |
+
tw = tgt_width
|
179 |
+
th = tgt_height
|
180 |
+
h, w = src
|
181 |
+
r = h / w
|
182 |
+
if r > (th / tw):
|
183 |
+
resize_height = th
|
184 |
+
resize_width = int(round(th / h * w))
|
185 |
+
else:
|
186 |
+
resize_width = tw
|
187 |
+
resize_height = int(round(tw / w * h))
|
188 |
+
|
189 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
190 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
191 |
+
|
192 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
193 |
+
|
194 |
+
|
195 |
+
def prepare_rotary_positional_embeddings(
|
196 |
+
height: int,
|
197 |
+
width: int,
|
198 |
+
num_frames: int,
|
199 |
+
vae_scale_factor_spatial: int = 8,
|
200 |
+
patch_size: int = 2,
|
201 |
+
patch_size_t: int = None,
|
202 |
+
attention_head_dim: int = 64,
|
203 |
+
device: Optional[torch.device] = None,
|
204 |
+
base_height: int = 480,
|
205 |
+
base_width: int = 720,
|
206 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
207 |
+
grid_height = height // (vae_scale_factor_spatial * patch_size)
|
208 |
+
grid_width = width // (vae_scale_factor_spatial * patch_size)
|
209 |
+
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
|
210 |
+
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
|
211 |
+
|
212 |
+
if patch_size_t is None:
|
213 |
+
# CogVideoX 1.0
|
214 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
215 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
216 |
+
)
|
217 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
218 |
+
embed_dim=attention_head_dim,
|
219 |
+
crops_coords=grid_crops_coords,
|
220 |
+
grid_size=(grid_height, grid_width),
|
221 |
+
temporal_size=num_frames,
|
222 |
+
)
|
223 |
+
else:
|
224 |
+
# CogVideoX 1.5
|
225 |
+
base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
|
226 |
+
|
227 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
228 |
+
embed_dim=attention_head_dim,
|
229 |
+
crops_coords=None,
|
230 |
+
grid_size=(grid_height, grid_width),
|
231 |
+
temporal_size=base_num_frames,
|
232 |
+
grid_type="slice",
|
233 |
+
max_size=(base_size_height, base_size_width),
|
234 |
+
)
|
235 |
+
|
236 |
+
freqs_cos = freqs_cos.to(device=device)
|
237 |
+
freqs_sin = freqs_sin.to(device=device)
|
238 |
+
return freqs_cos, freqs_sin
|
239 |
+
|
240 |
+
|
241 |
+
def reset_memory(device: Union[str, torch.device]) -> None:
|
242 |
+
gc.collect()
|
243 |
+
torch.cuda.empty_cache()
|
244 |
+
torch.cuda.reset_peak_memory_stats(device)
|
245 |
+
torch.cuda.reset_accumulated_memory_stats(device)
|
246 |
+
|
247 |
+
|
248 |
+
def print_memory(device: Union[str, torch.device]) -> None:
|
249 |
+
memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
|
250 |
+
max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
|
251 |
+
max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
|
252 |
+
print(f"{memory_allocated=:.3f} GB")
|
253 |
+
print(f"{max_memory_allocated=:.3f} GB")
|
254 |
+
print(f"{max_memory_reserved=:.3f} GB")
|
255 |
+
|
256 |
+
|
257 |
+
def unwrap_model(accelerator: Accelerator, model):
|
258 |
+
model = accelerator.unwrap_model(model)
|
259 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
260 |
+
return model
|
README.md
CHANGED
@@ -1,12 +1,161 @@
|
|
1 |
-
---
|
2 |
-
title: TransPixelerTest
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: TransPixelerTest
|
3 |
+
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 5.35.0
|
6 |
+
---
|
7 |
+
## TransPixeler: Advancing Text-to-Video Generation with Transparency (CVPR2025)
|
8 |
+
<br>
|
9 |
+
<a href="https://arxiv.org/abs/2501.03006"><img src='https://img.shields.io/badge/arXiv-2501.03006-b31b1b.svg'></a>
|
10 |
+
<a href='https://wileewang.github.io/TransPixeler'><img src='https://img.shields.io/badge/Project_Page-TransPixeler-blue'></a>
|
11 |
+
<a href='https://huggingface.co/spaces/wileewang/TransPixar'><img src='https://img.shields.io/badge/HuggingFace-TransPixeler-yellow'></a>
|
12 |
+
<a href="https://discord.gg/7Xds3Qjr"><img src="https://img.shields.io/badge/Discord-join-blueviolet?logo=discord&"></a>
|
13 |
+
<a href="https://github.com/wileewang/TransPixar/blob/main/wechat_group.jpg"><img src="https://img.shields.io/badge/Wechat-Join-green?logo=wechat&"></a>
|
14 |
+
<a href='https://openbayes.com/console/public/tutorials/tKhPalKrDb9'><img src='https://img.shields.io/badge/Demo-OpenBayes贝式计算-blue'></a>
|
15 |
+
<br>
|
16 |
+
|
17 |
+
[Luozhou Wang*](https://wileewang.github.io/),
|
18 |
+
[Yijun Li**](https://yijunmaverick.github.io/),
|
19 |
+
[Zhifei Chen](),
|
20 |
+
[Jui-Hsien Wang](http://juiwang.com/),
|
21 |
+
[Zhifei Zhang](https://zzutk.github.io/),
|
22 |
+
[He Zhang](https://sites.google.com/site/hezhangsprinter),
|
23 |
+
[Zhe Lin](https://sites.google.com/site/zhelin625/home),
|
24 |
+
[Ying-Cong Chen†](https://www.yingcong.me)
|
25 |
+
|
26 |
+
HKUST(GZ), HKUST, Adobe Research.
|
27 |
+
|
28 |
+
\* Internship Project
|
29 |
+
\** Project Lead
|
30 |
+
† Corresponding Author
|
31 |
+
|
32 |
+
Text-to-video generative models have made significant strides, enabling diverse applications in entertainment, advertising, and education. However, generating RGBA video, which includes alpha channels for transparency, remains a challenge due to limited datasets and the difficulty of adapting existing models. Alpha channels are crucial for visual effects (VFX), allowing transparent elements like smoke and reflections to blend seamlessly into scenes.
|
33 |
+
We introduce TransPixar, a method to extend pretrained video models for RGBA generation while retaining the original RGB capabilities. TransPixar leverages a diffusion transformer (DiT) architecture, incorporating alpha-specific tokens and using LoRA-based fine-tuning to jointly generate RGB and alpha channels with high consistency. By optimizing attention mechanisms, TransPixeler preserves the strengths of the original RGB model and achieves strong alignment between RGB and alpha channels despite limited training data.
|
34 |
+
Our approach effectively generates diverse and consistent RGBA videos, advancing the possibilities for VFX and interactive content creation.
|
35 |
+
|
36 |
+
<!-- insert a teaser gif -->
|
37 |
+
<!-- <img src="assets/mi.gif" width="640" /> -->
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
## 📰 News
|
42 |
+
|
43 |
+
- **[2025.04.28]** We have introduced a new development branch [`wan`](https://github.com/wileewang/TransPixar/tree/wan) that integrates the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model to support **joint generation** tasks. This branch includes training code tailored for generating both RGB and associated modalities (e.g., segmentation maps, alpha masks) from a shared text prompt.
|
44 |
+
|
45 |
+
- **[2025.02.26]** **TransPixeler** is accepted by CVPR 2025! See you in Nashville!
|
46 |
+
|
47 |
+
- **[2025.01.19]** We've renamed our project from **TransPixar** to **TransPixeler**!!
|
48 |
+
|
49 |
+
- **[2025.01.17]** We’ve created a [Discord group](https://discord.gg/7Xds3Qjr) and a [WeChat group](https://github.com/wileewang/TransPixar/blob/main/wechat_group.jpg)! Everyone is welcome to join for discussions and collaborations.
|
50 |
+
|
51 |
+
- **[2025.01.14]** Added new tasks to the repository's roadmap, including support for Hunyuan and LTX video models, and ComfyUI integration.
|
52 |
+
|
53 |
+
- **[2025.01.07]** Released project page, arXiv paper, inference code, and Hugging Face demo.
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
## 🔥 New Branch for Joint Generation with Wan2.1
|
59 |
+
|
60 |
+
We have introduced a new development branch [`wan`](https://github.com/wileewang/TransPixar/tree/wan) that integrates the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model to support **joint generation** tasks.
|
61 |
+
|
62 |
+
In the `wan` branch, we have developed and released training code tailored for joint generation scenarios, enabling the simultaneous generation of RGB videos and associated modalities (e.g., segmentation maps, alpha masks) from a shared text prompt.
|
63 |
+
|
64 |
+
**Key features of the `wan` branch:**
|
65 |
+
- **Integration of Wan2.1**: Leverages the capabilities of the Wan2.1 video generation model for enhanced performance.
|
66 |
+
- **Joint Generation Support**: Facilitates the concurrent generation of RGB and paired modality videos.
|
67 |
+
- **Dataset Structure**: Expects each sample to include:
|
68 |
+
- A primary video file (`001.mp4`) representing the RGB content.
|
69 |
+
- A paired secondary video file (`001_seg.mp4`) with a fixed `_seg` suffix, representing the associated modality.
|
70 |
+
- A caption text file (`001.txt`) with the same base name as the primary video.
|
71 |
+
- **Periodic Evaluation**: Supports periodic video sampling during training by setting `eval_every_step` or `eval_every_epoch` in the configuration.
|
72 |
+
- **Customized Pipelines**: Offers tailored training and inference pipelines designed specifically for joint generation tasks.
|
73 |
+
|
74 |
+
👉 To utilize the joint generation features, please checkout the [`wan`](https://github.com/wileewang/TransPixar/tree/wan) branch.
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
## Contents
|
80 |
+
|
81 |
+
* [Installation](#installation)
|
82 |
+
* [TransPixar LoRA Weights](#transpixar-lora-hub)
|
83 |
+
* [Training](#training)
|
84 |
+
* [Inference](#inference)
|
85 |
+
* [Acknowledgement](#acknowledgement)
|
86 |
+
* [Citation](#citation)
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
## Installation
|
91 |
+
|
92 |
+
```bash
|
93 |
+
# For the main branch
|
94 |
+
conda create -n TransPixeler python=3.10
|
95 |
+
conda activate TransPixeler
|
96 |
+
pip install -r requirements.txt
|
97 |
+
```
|
98 |
+
|
99 |
+
**Note:**
|
100 |
+
If you want to use the **Wan2.1 model**, please first checkout the `wan` branch:
|
101 |
+
|
102 |
+
```bash
|
103 |
+
git checkout wan
|
104 |
+
```
|
105 |
+
|
106 |
+
## TransPixeler LoRA Weights
|
107 |
+
|
108 |
+
Our pipeline is designed to support various video tasks, including Text-to-RGBA Video, Image-to-RGBA Video.
|
109 |
+
|
110 |
+
We provide the following pre-trained LoRA weights:
|
111 |
+
|
112 |
+
| Task | Base Model | Frames | LoRA weights | Inference VRAM |
|
113 |
+
|---------------|---------------------------------------------------------------|--------|--------------------------------------------------------------------|----------------|
|
114 |
+
| T2V + RGBA | [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5b) | 49 | [link](https://huggingface.co/wileewang/TransPixar/blob/main/cogvideox_rgba_lora.safetensors) | ~24GB |
|
115 |
+
|
116 |
+
|
117 |
+
## Training - RGB + Alpha Joint Generation
|
118 |
+
We have open-sourced the training code for **Mochi** on RGBA joint generation. Please refer to the [Mochi README](Mochi/README.md) for details.
|
119 |
+
|
120 |
+
|
121 |
+
## Inference - Gradio Demo
|
122 |
+
In addition to the [Hugging Face online demo](https://huggingface.co/spaces/wileewang/TransPixar), users can also launch a local inference demo based on CogVideoX-5B by running the following command:
|
123 |
+
|
124 |
+
```bash
|
125 |
+
python app.py
|
126 |
+
```
|
127 |
+
|
128 |
+
## Inference - Command Line Interface (CLI)
|
129 |
+
To generate RGBA videos, navigate to the corresponding directory for the video model and execute the following command:
|
130 |
+
```bash
|
131 |
+
python cli.py \
|
132 |
+
--lora_path /path/to/lora \
|
133 |
+
--prompt "..."
|
134 |
+
```
|
135 |
+
|
136 |
+
---
|
137 |
+
|
138 |
+
## Acknowledgement
|
139 |
+
|
140 |
+
* [finetrainers](https://github.com/a-r-r-o-w/finetrainers): We followed their implementation of Mochi training and inference.
|
141 |
+
* [CogVideoX](https://github.com/THUDM/CogVideo): We followed their implementation of CogVideoX training and inference.
|
142 |
+
|
143 |
+
We are grateful for their exceptional work and generous contribution to the open-source community.
|
144 |
+
|
145 |
+
## Citation
|
146 |
+
|
147 |
+
```bibtex
|
148 |
+
@misc{wang2025transpixeler,
|
149 |
+
title={TransPixeler: Advancing Text-to-Video Generation with Transparency},
|
150 |
+
author={Luozhou Wang and Yijun Li and Zhifei Chen and Jui-Hsien Wang and Zhifei Zhang and He Zhang and Zhe Lin and Ying-Cong Chen},
|
151 |
+
year={2025},
|
152 |
+
eprint={2501.03006},
|
153 |
+
archivePrefix={arXiv},
|
154 |
+
primaryClass={cs.CV},
|
155 |
+
url={https://arxiv.org/abs/2501.03006},
|
156 |
+
}
|
157 |
+
```
|
158 |
+
|
159 |
+
## Star History
|
160 |
+
|
161 |
+
[](https://star-history.com/#wileewang/TransPixeler&Date)
|
app.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to generate videos gradio web demo.
|
3 |
+
set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
|
4 |
+
Usage:
|
5 |
+
OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import threading
|
12 |
+
import time
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
import tempfile
|
16 |
+
import imageio_ffmpeg
|
17 |
+
import gradio as gr
|
18 |
+
import torch
|
19 |
+
from PIL import Image
|
20 |
+
# from diffusers import (
|
21 |
+
# CogVideoXPipeline,
|
22 |
+
# CogVideoXDPMScheduler,
|
23 |
+
# CogVideoXVideoToVideoPipeline,
|
24 |
+
# CogVideoXImageToVideoPipeline,
|
25 |
+
# CogVideoXTransformer3DModel,
|
26 |
+
# )
|
27 |
+
from typing import Union, List
|
28 |
+
from CogVideoX.pipeline_rgba import CogVideoXPipeline
|
29 |
+
from CogVideoX.rgba_utils import *
|
30 |
+
from diffusers import CogVideoXDPMScheduler
|
31 |
+
|
32 |
+
from diffusers.utils import load_video, load_image, export_to_video
|
33 |
+
from datetime import datetime, timedelta
|
34 |
+
|
35 |
+
from diffusers.image_processor import VaeImageProcessor
|
36 |
+
import moviepy.editor as mp
|
37 |
+
import numpy as np
|
38 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
39 |
+
import gc
|
40 |
+
|
41 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
42 |
+
|
43 |
+
# hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
|
44 |
+
hf_hub_download(repo_id="wileewang/TransPixar", filename="cogvideox_rgba_lora.safetensors", local_dir="model_cogvideox_rgba_lora")
|
45 |
+
# snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
|
46 |
+
|
47 |
+
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16)
|
48 |
+
# pipe.enable_sequential_cpu_offload()
|
49 |
+
pipe.vae.enable_slicing()
|
50 |
+
pipe.vae.enable_tiling()
|
51 |
+
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
52 |
+
seq_length = 2 * (
|
53 |
+
(480 // pipe.vae_scale_factor_spatial // 2)
|
54 |
+
* (720 // pipe.vae_scale_factor_spatial // 2)
|
55 |
+
* ((13 - 1) // pipe.vae_scale_factor_temporal + 1)
|
56 |
+
)
|
57 |
+
prepare_for_rgba_inference(
|
58 |
+
pipe.transformer,
|
59 |
+
rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors",
|
60 |
+
device="cuda",
|
61 |
+
dtype=torch.bfloat16,
|
62 |
+
text_length=226,
|
63 |
+
seq_length=seq_length, # this is for the creation of attention mask.
|
64 |
+
)
|
65 |
+
|
66 |
+
# pipe.transformer.to(memory_format=torch.channels_last)
|
67 |
+
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
|
68 |
+
# pipe_image.transformer.to(memory_format=torch.channels_last)
|
69 |
+
# pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)
|
70 |
+
|
71 |
+
os.makedirs("./output", exist_ok=True)
|
72 |
+
os.makedirs("./gradio_tmp", exist_ok=True)
|
73 |
+
|
74 |
+
# upscale_model = utils.load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
|
75 |
+
# frame_interpolation_model = load_rife_model("model_rife")
|
76 |
+
|
77 |
+
|
78 |
+
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
79 |
+
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
80 |
+
There are a few rules to follow:
|
81 |
+
You will only ever output a single video description per user request.
|
82 |
+
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
83 |
+
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
84 |
+
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
85 |
+
"""
|
86 |
+
def save_video(tensor: Union[List[np.ndarray], List[Image.Image]], fps: int = 8, prefix='rgb'):
|
87 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
88 |
+
video_path = f"./output/{prefix}_{timestamp}.mp4"
|
89 |
+
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
90 |
+
export_to_video(tensor, video_path, fps=fps)
|
91 |
+
return video_path
|
92 |
+
|
93 |
+
def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
|
94 |
+
width, height = get_video_dimensions(input_video)
|
95 |
+
|
96 |
+
if width == 720 and height == 480:
|
97 |
+
processed_video = input_video
|
98 |
+
else:
|
99 |
+
processed_video = center_crop_resize(input_video)
|
100 |
+
return processed_video
|
101 |
+
|
102 |
+
|
103 |
+
def get_video_dimensions(input_video_path):
|
104 |
+
reader = imageio_ffmpeg.read_frames(input_video_path)
|
105 |
+
metadata = next(reader)
|
106 |
+
return metadata["size"]
|
107 |
+
|
108 |
+
|
109 |
+
def center_crop_resize(input_video_path, target_width=720, target_height=480):
|
110 |
+
cap = cv2.VideoCapture(input_video_path)
|
111 |
+
|
112 |
+
orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
113 |
+
orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
114 |
+
orig_fps = cap.get(cv2.CAP_PROP_FPS)
|
115 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
116 |
+
|
117 |
+
width_factor = target_width / orig_width
|
118 |
+
height_factor = target_height / orig_height
|
119 |
+
resize_factor = max(width_factor, height_factor)
|
120 |
+
|
121 |
+
inter_width = int(orig_width * resize_factor)
|
122 |
+
inter_height = int(orig_height * resize_factor)
|
123 |
+
|
124 |
+
target_fps = 8
|
125 |
+
ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
|
126 |
+
skip = min(5, ideal_skip) # Cap at 5
|
127 |
+
|
128 |
+
while (total_frames / (skip + 1)) < 49 and skip > 0:
|
129 |
+
skip -= 1
|
130 |
+
|
131 |
+
processed_frames = []
|
132 |
+
frame_count = 0
|
133 |
+
total_read = 0
|
134 |
+
|
135 |
+
while frame_count < 49 and total_read < total_frames:
|
136 |
+
ret, frame = cap.read()
|
137 |
+
if not ret:
|
138 |
+
break
|
139 |
+
|
140 |
+
if total_read % (skip + 1) == 0:
|
141 |
+
resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
|
142 |
+
|
143 |
+
start_x = (inter_width - target_width) // 2
|
144 |
+
start_y = (inter_height - target_height) // 2
|
145 |
+
cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]
|
146 |
+
|
147 |
+
processed_frames.append(cropped)
|
148 |
+
frame_count += 1
|
149 |
+
|
150 |
+
total_read += 1
|
151 |
+
|
152 |
+
cap.release()
|
153 |
+
|
154 |
+
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
|
155 |
+
temp_video_path = temp_file.name
|
156 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
157 |
+
out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
|
158 |
+
|
159 |
+
for frame in processed_frames:
|
160 |
+
out.write(frame)
|
161 |
+
|
162 |
+
out.release()
|
163 |
+
|
164 |
+
return temp_video_path
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
def infer(
|
169 |
+
prompt: str,
|
170 |
+
num_inference_steps: int,
|
171 |
+
guidance_scale: float,
|
172 |
+
seed: int = -1,
|
173 |
+
progress=gr.Progress(track_tqdm=True),
|
174 |
+
):
|
175 |
+
if seed == -1:
|
176 |
+
seed = random.randint(0, 2**8 - 1)
|
177 |
+
pipe.to(device)
|
178 |
+
video_pt = pipe(
|
179 |
+
prompt=prompt + ", isolated background",
|
180 |
+
num_videos_per_prompt=1,
|
181 |
+
num_inference_steps=num_inference_steps,
|
182 |
+
num_frames=13,
|
183 |
+
use_dynamic_cfg=True,
|
184 |
+
output_type="latent",
|
185 |
+
guidance_scale=guidance_scale,
|
186 |
+
generator=torch.Generator(device=device).manual_seed(int(seed)),
|
187 |
+
).frames
|
188 |
+
# pipe.to("cpu")
|
189 |
+
gc.collect()
|
190 |
+
return (video_pt, seed)
|
191 |
+
|
192 |
+
|
193 |
+
def convert_to_gif(video_path):
|
194 |
+
clip = mp.VideoFileClip(video_path)
|
195 |
+
clip = clip.set_fps(8)
|
196 |
+
clip = clip.resize(height=240)
|
197 |
+
gif_path = video_path.replace(".mp4", ".gif")
|
198 |
+
clip.write_gif(gif_path, fps=8)
|
199 |
+
return gif_path
|
200 |
+
|
201 |
+
|
202 |
+
def delete_old_files():
|
203 |
+
while True:
|
204 |
+
now = datetime.now()
|
205 |
+
cutoff = now - timedelta(minutes=10)
|
206 |
+
directories = ["./output", "./gradio_tmp"]
|
207 |
+
|
208 |
+
for directory in directories:
|
209 |
+
for filename in os.listdir(directory):
|
210 |
+
file_path = os.path.join(directory, filename)
|
211 |
+
if os.path.isfile(file_path):
|
212 |
+
file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
|
213 |
+
if file_mtime < cutoff:
|
214 |
+
os.remove(file_path)
|
215 |
+
time.sleep(600)
|
216 |
+
|
217 |
+
|
218 |
+
threading.Thread(target=delete_old_files, daemon=True).start()
|
219 |
+
|
220 |
+
with gr.Blocks() as demo:
|
221 |
+
gr.HTML("""
|
222 |
+
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
|
223 |
+
TransPixar + CogVideoX-5B Huggingface Space🤗
|
224 |
+
</div>
|
225 |
+
<div style="text-align: center;">
|
226 |
+
<a href="https://huggingface.co/wileewang/TransPixar">🤗 TransPixar LoRA Hub</a> |
|
227 |
+
<a href="https://github.com/wileewang/TransPixar">🌐 Github</a> |
|
228 |
+
<a href="https://arxiv.org/">📜 arxiv </a>
|
229 |
+
</div>
|
230 |
+
<div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
|
231 |
+
⚠️ This demo is for academic research and experiential use only.
|
232 |
+
</div>
|
233 |
+
""")
|
234 |
+
with gr.Row():
|
235 |
+
with gr.Column():
|
236 |
+
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
|
237 |
+
with gr.Group():
|
238 |
+
with gr.Column():
|
239 |
+
with gr.Row():
|
240 |
+
seed_param = gr.Number(
|
241 |
+
label="Inference Seed (Enter a positive number, -1 for random)", value=-1
|
242 |
+
)
|
243 |
+
|
244 |
+
generate_button = gr.Button("🎬 Generate Video")
|
245 |
+
with gr.Row():
|
246 |
+
gr.Markdown(
|
247 |
+
"""
|
248 |
+
**Note:** The output RGB is a premultiplied version to avoid the color decontamination problem.
|
249 |
+
It can directly composite with a background using:
|
250 |
+
```
|
251 |
+
composite = rgb + (1 - alpha) * background
|
252 |
+
```
|
253 |
+
"""
|
254 |
+
)
|
255 |
+
|
256 |
+
with gr.Column():
|
257 |
+
rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480)
|
258 |
+
alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480)
|
259 |
+
with gr.Row():
|
260 |
+
download_rgb_video_button = gr.File(label="📥 Download RGB Video", visible=False)
|
261 |
+
download_alpha_video_button = gr.File(label="📥 Download Alpha Video", visible=False)
|
262 |
+
seed_text = gr.Number(label="Seed Used for Video Generation", visible=False)
|
263 |
+
|
264 |
+
|
265 |
+
def generate(
|
266 |
+
prompt,
|
267 |
+
seed_value,
|
268 |
+
progress=gr.Progress(track_tqdm=True)
|
269 |
+
):
|
270 |
+
latents, seed = infer(
|
271 |
+
prompt,
|
272 |
+
num_inference_steps=25, # NOT Changed 25
|
273 |
+
guidance_scale=7.0, # NOT Changed
|
274 |
+
seed=seed_value,
|
275 |
+
progress=progress,
|
276 |
+
)
|
277 |
+
|
278 |
+
latents_rgb, latents_alpha = latents.chunk(2, dim=1)
|
279 |
+
|
280 |
+
frames_rgb = decode_latents(pipe, latents_rgb)
|
281 |
+
frames_alpha = decode_latents(pipe, latents_alpha)
|
282 |
+
|
283 |
+
pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
|
284 |
+
frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
|
285 |
+
premultiplied_rgb = frames_rgb * frames_alpha_pooled
|
286 |
+
|
287 |
+
rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix='rgb')
|
288 |
+
rgb_video_update = gr.update(visible=True, value=rgb_video_path)
|
289 |
+
|
290 |
+
alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix='alpha')
|
291 |
+
alpha_video_update = gr.update(visible=True, value=alpha_video_path)
|
292 |
+
seed_update = gr.update(visible=True, value=seed)
|
293 |
+
|
294 |
+
return rgb_video_path, alpha_video_path, rgb_video_update, alpha_video_update, seed_update
|
295 |
+
|
296 |
+
|
297 |
+
generate_button.click(
|
298 |
+
generate,
|
299 |
+
inputs=[prompt, seed_param],
|
300 |
+
outputs=[rgb_video_output, alpha_video_output, download_rgb_video_button, download_alpha_video_button, seed_text],
|
301 |
+
)
|
302 |
+
|
303 |
+
|
304 |
+
if __name__ == "__main__":
|
305 |
+
demo.queue(max_size=15)
|
306 |
+
demo.launch(share=True)
|
model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1865b9208cbfaca5ac8d9cec0c0e3788a32834dea11873d2dee653303e72006e
|
3 |
+
size 264292376
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.4.0
|
2 |
+
torchvision
|
3 |
+
torchaudio
|
4 |
+
wandb
|
5 |
+
gradio
|
6 |
+
sentencepiece
|
7 |
+
diffusers==0.32.0
|
8 |
+
huggingface_hub==0.27.0
|
9 |
+
transformers
|
10 |
+
imageio>=2.5.0
|
11 |
+
imageio-ffmpeg
|
12 |
+
moviepy==1.0.3
|
13 |
+
opencv-python>=4.5
|
14 |
+
accelerate
|
wechat_group.jpg
ADDED
![]() |
Git LFS Details
|