LoveHandles commited on
Commit
d4a8a38
·
verified ·
1 Parent(s): c8f4b95

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ wechat_group.jpg filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
CogVideoX/__pycache__/pipeline_rgba.cpython-310.pyc ADDED
Binary file (25.8 kB). View file
 
CogVideoX/__pycache__/rgba_utils.cpython-310.pyc ADDED
Binary file (9.28 kB). View file
 
CogVideoX/cli.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from diffusers import CogVideoXDPMScheduler
4
+ from pipeline_rgba import CogVideoXPipeline
5
+ from diffusers.utils import export_to_video
6
+ import argparse
7
+ import numpy as np
8
+ from rgba_utils import *
9
+
10
+ def main(args):
11
+ # 1. load pipeline
12
+ pipe = CogVideoXPipeline.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
13
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
14
+ pipe.enable_sequential_cpu_offload()
15
+ pipe.vae.enable_slicing()
16
+ pipe.vae.enable_tiling()
17
+
18
+
19
+ # 2. define prompt and arguments
20
+ pipeline_args = {
21
+ "prompt": args.prompt,
22
+ "guidance_scale": args.guidance_scale,
23
+ "num_inference_steps": args.num_inference_steps,
24
+ "height": args.height,
25
+ "width": args.width,
26
+ "num_frames": args.num_frames,
27
+ "output_type": "latent",
28
+ "use_dynamic_cfg":True,
29
+ }
30
+
31
+ # 3. prepare rgbx utils
32
+ # breakpoint()
33
+ seq_length = 2 * (
34
+ (args.height // pipe.vae_scale_factor_spatial // 2)
35
+ * (args.width // pipe.vae_scale_factor_spatial // 2)
36
+ * ((args.num_frames - 1) // pipe.vae_scale_factor_temporal + 1)
37
+ )
38
+ # seq_length = 35100
39
+
40
+ prepare_for_rgba_inference(
41
+ pipe.transformer,
42
+ rgba_weights_path=args.lora_path,
43
+ device="cuda",
44
+ dtype=torch.bfloat16,
45
+ text_length=226,
46
+ seq_length=seq_length, # this is for the creation of attention mask.
47
+ )
48
+
49
+ # 4. run inference
50
+ generator = torch.manual_seed(args.seed) if args.seed else None
51
+ frames_latents = pipe(**pipeline_args, generator=generator).frames
52
+
53
+ frames_latents_rgb, frames_latents_alpha = frames_latents.chunk(2, dim=1)
54
+
55
+ frames_rgb = decode_latents(pipe, frames_latents_rgb)
56
+ frames_alpha = decode_latents(pipe, frames_latents_alpha)
57
+
58
+
59
+ pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
60
+ frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
61
+ premultiplied_rgb = frames_rgb * frames_alpha_pooled
62
+
63
+ if os.path.exists(args.output_path) == False:
64
+ os.makedirs(args.output_path)
65
+
66
+ export_to_video(premultiplied_rgb[0], os.path.join(args.output_path, "rgb.mp4"), fps=args.fps)
67
+ export_to_video(frames_alpha_pooled[0], os.path.join(args.output_path, "alpha.mp4"), fps=args.fps)
68
+
69
+
70
+ if __name__ == "__main__":
71
+ parser = argparse.ArgumentParser(description="Generate a video from a text prompt")
72
+ parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
73
+ parser.add_argument("--lora_path", type=str, default="/hpc2hdd/home/lwang592/projects/CogVideo/sat/outputs/training/ckpts-5b-attn_rebias-partial_lora-8gpu-wo_t2a/lora-rgba-12-21-19-11/5000/rgba_lora.safetensors", help="The path of the LoRA weights to be used")
74
+
75
+ parser.add_argument(
76
+ "--model_path", type=str, default="THUDM/CogVideoX-5B", help="Path of the pre-trained model use"
77
+ )
78
+
79
+
80
+ parser.add_argument("--output_path", type=str, default="./output", help="The path save generated video")
81
+ parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
82
+ parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
83
+ parser.add_argument("--num_frames", type=int, default=49, help="Number of steps for the inference process")
84
+ parser.add_argument("--width", type=int, default=720, help="Number of steps for the inference process")
85
+ parser.add_argument("--height", type=int, default=480, help="Number of steps for the inference process")
86
+ parser.add_argument("--fps", type=int, default=8, help="Number of steps for the inference process")
87
+ parser.add_argument("--seed", type=int, default=None, help="The seed for reproducibility")
88
+ args = parser.parse_args()
89
+
90
+ main(args)
CogVideoX/pipeline_rgba.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from transformers import T5EncoderModel, T5Tokenizer
22
+
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from diffusers.loaders import CogVideoXLoraLoaderMixin
25
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
26
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from diffusers.utils import logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+ from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ EXAMPLE_DOC_STRING = """
39
+ Examples:
40
+ ```python
41
+ >>> import torch
42
+ >>> from diffusers import CogVideoXPipeline
43
+ >>> from diffusers.utils import export_to_video
44
+
45
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
46
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
47
+ >>> prompt = (
48
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
49
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
50
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
51
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
52
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
53
+ ... "atmosphere of this unique musical performance."
54
+ ... )
55
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
56
+ >>> export_to_video(video, "output.mp4", fps=8)
57
+ ```
58
+ """
59
+
60
+
61
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
62
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
63
+ tw = tgt_width
64
+ th = tgt_height
65
+ h, w = src
66
+ r = h / w
67
+ if r > (th / tw):
68
+ resize_height = th
69
+ resize_width = int(round(th / h * w))
70
+ else:
71
+ resize_width = tw
72
+ resize_height = int(round(tw / w * h))
73
+
74
+ crop_top = int(round((th - resize_height) / 2.0))
75
+ crop_left = int(round((tw - resize_width) / 2.0))
76
+
77
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
78
+
79
+
80
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81
+ def retrieve_timesteps(
82
+ scheduler,
83
+ num_inference_steps: Optional[int] = None,
84
+ device: Optional[Union[str, torch.device]] = None,
85
+ timesteps: Optional[List[int]] = None,
86
+ sigmas: Optional[List[float]] = None,
87
+ **kwargs,
88
+ ):
89
+ r"""
90
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92
+
93
+ Args:
94
+ scheduler (`SchedulerMixin`):
95
+ The scheduler to get timesteps from.
96
+ num_inference_steps (`int`):
97
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
98
+ must be `None`.
99
+ device (`str` or `torch.device`, *optional*):
100
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
101
+ timesteps (`List[int]`, *optional*):
102
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
103
+ `num_inference_steps` and `sigmas` must be `None`.
104
+ sigmas (`List[float]`, *optional*):
105
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
106
+ `num_inference_steps` and `timesteps` must be `None`.
107
+
108
+ Returns:
109
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110
+ second element is the number of inference steps.
111
+ """
112
+ if timesteps is not None and sigmas is not None:
113
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
114
+ if timesteps is not None:
115
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116
+ if not accepts_timesteps:
117
+ raise ValueError(
118
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119
+ f" timestep schedules. Please check whether you are using the correct scheduler."
120
+ )
121
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122
+ timesteps = scheduler.timesteps
123
+ num_inference_steps = len(timesteps)
124
+ elif sigmas is not None:
125
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126
+ if not accept_sigmas:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ else:
135
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
136
+ timesteps = scheduler.timesteps
137
+ return timesteps, num_inference_steps
138
+
139
+
140
+ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
141
+ r"""
142
+ Pipeline for text-to-video generation using CogVideoX.
143
+
144
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
145
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
146
+
147
+ Args:
148
+ vae ([`AutoencoderKL`]):
149
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
150
+ text_encoder ([`T5EncoderModel`]):
151
+ Frozen text-encoder. CogVideoX uses
152
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
153
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
154
+ tokenizer (`T5Tokenizer`):
155
+ Tokenizer of class
156
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
157
+ transformer ([`CogVideoXTransformer3DModel`]):
158
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
159
+ scheduler ([`SchedulerMixin`]):
160
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
161
+ """
162
+
163
+ _optional_components = []
164
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
165
+
166
+ _callback_tensor_inputs = [
167
+ "latents",
168
+ "prompt_embeds",
169
+ "negative_prompt_embeds",
170
+ ]
171
+
172
+ def __init__(
173
+ self,
174
+ tokenizer: T5Tokenizer,
175
+ text_encoder: T5EncoderModel,
176
+ vae: AutoencoderKLCogVideoX,
177
+ transformer: CogVideoXTransformer3DModel,
178
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
179
+ ):
180
+ super().__init__()
181
+
182
+ self.register_modules(
183
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
184
+ )
185
+ self.vae_scale_factor_spatial = (
186
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
187
+ )
188
+ self.vae_scale_factor_temporal = (
189
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
190
+ )
191
+ self.vae_scaling_factor_image = (
192
+ self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
193
+ )
194
+
195
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
196
+
197
+ def _get_t5_prompt_embeds(
198
+ self,
199
+ prompt: Union[str, List[str]] = None,
200
+ num_videos_per_prompt: int = 1,
201
+ max_sequence_length: int = 226,
202
+ device: Optional[torch.device] = None,
203
+ dtype: Optional[torch.dtype] = None,
204
+ ):
205
+ device = device or self._execution_device
206
+ dtype = dtype or self.text_encoder.dtype
207
+
208
+ prompt = [prompt] if isinstance(prompt, str) else prompt
209
+ batch_size = len(prompt)
210
+
211
+ text_inputs = self.tokenizer(
212
+ prompt,
213
+ padding="max_length",
214
+ max_length=max_sequence_length,
215
+ truncation=True,
216
+ add_special_tokens=True,
217
+ return_tensors="pt",
218
+ )
219
+ text_input_ids = text_inputs.input_ids
220
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
221
+
222
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
223
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
224
+ logger.warning(
225
+ "The following part of your input was truncated because `max_sequence_length` is set to "
226
+ f" {max_sequence_length} tokens: {removed_text}"
227
+ )
228
+
229
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
230
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
231
+
232
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
233
+ _, seq_len, _ = prompt_embeds.shape
234
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
235
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
236
+
237
+ return prompt_embeds
238
+
239
+ def encode_prompt(
240
+ self,
241
+ prompt: Union[str, List[str]],
242
+ negative_prompt: Optional[Union[str, List[str]]] = None,
243
+ do_classifier_free_guidance: bool = True,
244
+ num_videos_per_prompt: int = 1,
245
+ prompt_embeds: Optional[torch.Tensor] = None,
246
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
247
+ max_sequence_length: int = 226,
248
+ device: Optional[torch.device] = None,
249
+ dtype: Optional[torch.dtype] = None,
250
+ ):
251
+ r"""
252
+ Encodes the prompt into text encoder hidden states.
253
+
254
+ Args:
255
+ prompt (`str` or `List[str]`, *optional*):
256
+ prompt to be encoded
257
+ negative_prompt (`str` or `List[str]`, *optional*):
258
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
259
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
260
+ less than `1`).
261
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
262
+ Whether to use classifier free guidance or not.
263
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
264
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
265
+ prompt_embeds (`torch.Tensor`, *optional*):
266
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
267
+ provided, text embeddings will be generated from `prompt` input argument.
268
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
269
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
270
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
271
+ argument.
272
+ device: (`torch.device`, *optional*):
273
+ torch device
274
+ dtype: (`torch.dtype`, *optional*):
275
+ torch dtype
276
+ """
277
+ device = device or self._execution_device
278
+
279
+ prompt = [prompt] if isinstance(prompt, str) else prompt
280
+ if prompt is not None:
281
+ batch_size = len(prompt)
282
+ else:
283
+ batch_size = prompt_embeds.shape[0]
284
+
285
+ if prompt_embeds is None:
286
+ prompt_embeds = self._get_t5_prompt_embeds(
287
+ prompt=prompt,
288
+ num_videos_per_prompt=num_videos_per_prompt,
289
+ max_sequence_length=max_sequence_length,
290
+ device=device,
291
+ dtype=dtype,
292
+ )
293
+
294
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
295
+ negative_prompt = negative_prompt or ""
296
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
297
+
298
+ if prompt is not None and type(prompt) is not type(negative_prompt):
299
+ raise TypeError(
300
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
301
+ f" {type(prompt)}."
302
+ )
303
+ elif batch_size != len(negative_prompt):
304
+ raise ValueError(
305
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
306
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
307
+ " the batch size of `prompt`."
308
+ )
309
+
310
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
311
+ prompt=negative_prompt,
312
+ num_videos_per_prompt=num_videos_per_prompt,
313
+ max_sequence_length=max_sequence_length,
314
+ device=device,
315
+ dtype=dtype,
316
+ )
317
+
318
+ return prompt_embeds, negative_prompt_embeds
319
+
320
+ def prepare_latents(
321
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
322
+ ):
323
+ if isinstance(generator, list) and len(generator) != batch_size:
324
+ raise ValueError(
325
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
326
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
327
+ )
328
+
329
+ shape = (
330
+ batch_size,
331
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
332
+ num_channels_latents,
333
+ height // self.vae_scale_factor_spatial,
334
+ width // self.vae_scale_factor_spatial,
335
+ )
336
+
337
+ if latents is None:
338
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
339
+ else:
340
+ latents = latents.to(device)
341
+
342
+ # scale the initial noise by the standard deviation required by the scheduler
343
+ latents = latents * self.scheduler.init_noise_sigma
344
+ return latents
345
+
346
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
347
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
348
+ latents = 1 / self.vae_scaling_factor_image * latents
349
+
350
+ frames = self.vae.decode(latents).sample
351
+ return frames
352
+
353
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
354
+ def prepare_extra_step_kwargs(self, generator, eta):
355
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
356
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
357
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
358
+ # and should be between [0, 1]
359
+
360
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
361
+ extra_step_kwargs = {}
362
+ if accepts_eta:
363
+ extra_step_kwargs["eta"] = eta
364
+
365
+ # check if the scheduler accepts generator
366
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
367
+ if accepts_generator:
368
+ extra_step_kwargs["generator"] = generator
369
+ return extra_step_kwargs
370
+
371
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
372
+ def check_inputs(
373
+ self,
374
+ prompt,
375
+ height,
376
+ width,
377
+ negative_prompt,
378
+ callback_on_step_end_tensor_inputs,
379
+ prompt_embeds=None,
380
+ negative_prompt_embeds=None,
381
+ ):
382
+ if height % 8 != 0 or width % 8 != 0:
383
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
384
+
385
+ if callback_on_step_end_tensor_inputs is not None and not all(
386
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
387
+ ):
388
+ raise ValueError(
389
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
390
+ )
391
+ if prompt is not None and prompt_embeds is not None:
392
+ raise ValueError(
393
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
394
+ " only forward one of the two."
395
+ )
396
+ elif prompt is None and prompt_embeds is None:
397
+ raise ValueError(
398
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
399
+ )
400
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
401
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
402
+
403
+ if prompt is not None and negative_prompt_embeds is not None:
404
+ raise ValueError(
405
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
406
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
407
+ )
408
+
409
+ if negative_prompt is not None and negative_prompt_embeds is not None:
410
+ raise ValueError(
411
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
412
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
413
+ )
414
+
415
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
416
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
417
+ raise ValueError(
418
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
419
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
420
+ f" {negative_prompt_embeds.shape}."
421
+ )
422
+
423
+ def fuse_qkv_projections(self) -> None:
424
+ r"""Enables fused QKV projections."""
425
+ self.fusing_transformer = True
426
+ self.transformer.fuse_qkv_projections()
427
+
428
+ def unfuse_qkv_projections(self) -> None:
429
+ r"""Disable QKV projection fusion if enabled."""
430
+ if not self.fusing_transformer:
431
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
432
+ else:
433
+ self.transformer.unfuse_qkv_projections()
434
+ self.fusing_transformer = False
435
+
436
+ def _prepare_rotary_positional_embeddings(
437
+ self,
438
+ height: int,
439
+ width: int,
440
+ num_frames: int,
441
+ device: torch.device,
442
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
443
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
444
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
445
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
446
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
447
+
448
+ grid_crops_coords = get_resize_crop_region_for_grid(
449
+ (grid_height, grid_width), base_size_width, base_size_height
450
+ )
451
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
452
+ embed_dim=self.transformer.config.attention_head_dim,
453
+ crops_coords=grid_crops_coords,
454
+ grid_size=(grid_height, grid_width),
455
+ temporal_size=num_frames,
456
+ )
457
+
458
+ freqs_cos = freqs_cos.to(device=device)
459
+ freqs_sin = freqs_sin.to(device=device)
460
+ return freqs_cos, freqs_sin
461
+
462
+ @property
463
+ def guidance_scale(self):
464
+ return self._guidance_scale
465
+
466
+ @property
467
+ def num_timesteps(self):
468
+ return self._num_timesteps
469
+
470
+ @property
471
+ def attention_kwargs(self):
472
+ return self._attention_kwargs
473
+
474
+ @property
475
+ def interrupt(self):
476
+ return self._interrupt
477
+
478
+ @torch.no_grad()
479
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
480
+ def __call__(
481
+ self,
482
+ prompt: Optional[Union[str, List[str]]] = None,
483
+ negative_prompt: Optional[Union[str, List[str]]] = None,
484
+ height: int = 480,
485
+ width: int = 720,
486
+ num_frames: int = 10, #was 49
487
+ num_inference_steps: int = 5, #was 50
488
+ timesteps: Optional[List[int]] = None,
489
+ guidance_scale: float = 6,
490
+ use_dynamic_cfg: bool = False,
491
+ num_videos_per_prompt: int = 1,
492
+ eta: float = 0.0,
493
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
494
+ latents: Optional[torch.FloatTensor] = None,
495
+ prompt_embeds: Optional[torch.FloatTensor] = None,
496
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
497
+ output_type: str = "pil",
498
+ return_dict: bool = True,
499
+ attention_kwargs: Optional[Dict[str, Any]] = None,
500
+ callback_on_step_end: Optional[
501
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
502
+ ] = None,
503
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
504
+ max_sequence_length: int = 226,
505
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
506
+ """
507
+ Function invoked when calling the pipeline for generation.
508
+
509
+ Args:
510
+ prompt (`str` or `List[str]`, *optional*):
511
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
512
+ instead.
513
+ negative_prompt (`str` or `List[str]`, *optional*):
514
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
515
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
516
+ less than `1`).
517
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
518
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
519
+ width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
520
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
521
+ num_frames (`int`, defaults to `48`):
522
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
523
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
524
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
525
+ needs to be satisfied is that of divisibility mentioned above.
526
+ num_inference_steps (`int`, *optional*, defaults to 50):
527
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
528
+ expense of slower inference.
529
+ timesteps (`List[int]`, *optional*):
530
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
531
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
532
+ passed will be used. Must be in descending order.
533
+ guidance_scale (`float`, *optional*, defaults to 7.0):
534
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
535
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
536
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
537
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
538
+ usually at the expense of lower image quality.
539
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
540
+ The number of videos to generate per prompt.
541
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
542
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
543
+ to make generation deterministic.
544
+ latents (`torch.FloatTensor`, *optional*):
545
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
546
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
547
+ tensor will ge generated by sampling using the supplied random `generator`.
548
+ prompt_embeds (`torch.FloatTensor`, *optional*):
549
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
550
+ provided, text embeddings will be generated from `prompt` input argument.
551
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
552
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
553
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
554
+ argument.
555
+ output_type (`str`, *optional*, defaults to `"pil"`):
556
+ The output format of the generate image. Choose between
557
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
558
+ return_dict (`bool`, *optional*, defaults to `True`):
559
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
560
+ of a plain tuple.
561
+ attention_kwargs (`dict`, *optional*):
562
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
563
+ `self.processor` in
564
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
565
+ callback_on_step_end (`Callable`, *optional*):
566
+ A function that calls at the end of each denoising steps during the inference. The function is called
567
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
568
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
569
+ `callback_on_step_end_tensor_inputs`.
570
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
571
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
572
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
573
+ `._callback_tensor_inputs` attribute of your pipeline class.
574
+ max_sequence_length (`int`, defaults to `226`):
575
+ Maximum sequence length in encoded prompt. Must be consistent with
576
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
577
+
578
+ Examples:
579
+
580
+ Returns:
581
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
582
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
583
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
584
+ """
585
+
586
+ if num_frames > 49:
587
+ raise ValueError(
588
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
589
+ )
590
+
591
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
592
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
593
+
594
+ num_videos_per_prompt = 1
595
+
596
+ # 1. Check inputs. Raise error if not correct
597
+ self.check_inputs(
598
+ prompt,
599
+ height,
600
+ width,
601
+ negative_prompt,
602
+ callback_on_step_end_tensor_inputs,
603
+ prompt_embeds,
604
+ negative_prompt_embeds,
605
+ )
606
+ self._guidance_scale = guidance_scale
607
+ self._attention_kwargs = attention_kwargs
608
+ self._interrupt = False
609
+
610
+ # 2. Default call parameters
611
+ if prompt is not None and isinstance(prompt, str):
612
+ batch_size = 1
613
+ elif prompt is not None and isinstance(prompt, list):
614
+ batch_size = len(prompt)
615
+ else:
616
+ batch_size = prompt_embeds.shape[0]
617
+
618
+ device = self._execution_device
619
+
620
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
621
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
622
+ # corresponds to doing no classifier free guidance.
623
+ do_classifier_free_guidance = guidance_scale > 1.0
624
+
625
+ # 3. Encode input prompt
626
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
627
+ prompt,
628
+ negative_prompt,
629
+ do_classifier_free_guidance,
630
+ num_videos_per_prompt=num_videos_per_prompt,
631
+ prompt_embeds=prompt_embeds,
632
+ negative_prompt_embeds=negative_prompt_embeds,
633
+ max_sequence_length=max_sequence_length,
634
+ device=device,
635
+ )
636
+ if do_classifier_free_guidance:
637
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
638
+
639
+ # 4. Prepare timesteps
640
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
641
+ self._num_timesteps = len(timesteps)
642
+
643
+ # 5. Prepare latents.
644
+ latent_channels = self.transformer.config.in_channels
645
+ latents = self.prepare_latents(
646
+ batch_size * num_videos_per_prompt,
647
+ latent_channels,
648
+ num_frames,
649
+ height,
650
+ width,
651
+ prompt_embeds.dtype,
652
+ device,
653
+ generator,
654
+ latents,
655
+ ).repeat(1,2,1,1,1) # Luozhou
656
+
657
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
658
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
659
+
660
+ # 7. Create rotary embeds if required
661
+ image_rotary_emb = (
662
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1) // 2, device) # Luozhou
663
+ if self.transformer.config.use_rotary_positional_embeddings
664
+ else None
665
+ )
666
+
667
+ # 8. Denoising loop
668
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
669
+
670
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
671
+ # for DPM-solver++
672
+ old_pred_original_sample = None
673
+ for i, t in enumerate(timesteps):
674
+ if self.interrupt:
675
+ continue
676
+
677
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
678
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
679
+
680
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
681
+ timestep = t.expand(latent_model_input.shape[0])
682
+
683
+ # predict noise model_output
684
+ noise_pred = self.transformer(
685
+ hidden_states=latent_model_input,
686
+ encoder_hidden_states=prompt_embeds,
687
+ timestep=timestep,
688
+ image_rotary_emb=image_rotary_emb,
689
+ attention_kwargs=attention_kwargs,
690
+ return_dict=False,
691
+ )[0]
692
+ noise_pred = noise_pred.float()
693
+
694
+ # perform guidance
695
+ if use_dynamic_cfg:
696
+ self._guidance_scale = 1 + guidance_scale * (
697
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
698
+ )
699
+ if do_classifier_free_guidance:
700
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
701
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
702
+
703
+ # compute the previous noisy sample x_t -> x_t-1
704
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
705
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
706
+ else:
707
+ latents, old_pred_original_sample = self.scheduler.step(
708
+ noise_pred,
709
+ old_pred_original_sample,
710
+ t,
711
+ timesteps[i - 1] if i > 0 else None,
712
+ latents,
713
+ **extra_step_kwargs,
714
+ return_dict=False,
715
+ )
716
+ latents = latents.to(prompt_embeds.dtype)
717
+
718
+ # call the callback, if provided
719
+ if callback_on_step_end is not None:
720
+ callback_kwargs = {}
721
+ for k in callback_on_step_end_tensor_inputs:
722
+ callback_kwargs[k] = locals()[k]
723
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
724
+
725
+ latents = callback_outputs.pop("latents", latents)
726
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
727
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
728
+
729
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
730
+ progress_bar.update()
731
+
732
+ if not output_type == "latent":
733
+ video = self.decode_latents(latents)
734
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
735
+ else:
736
+ video = latents
737
+
738
+ # Offload all models
739
+ self.maybe_free_model_hooks()
740
+
741
+ if not return_dict:
742
+ return (video,)
743
+
744
+ return CogVideoXPipelineOutput(frames=video)
CogVideoX/rgba_utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Any, Dict, Optional, Tuple, Union
5
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
6
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
7
+ from safetensors.torch import load_file
8
+
9
+ logger = logging.get_logger(__name__)
10
+
11
+ @torch.no_grad()
12
+ def decode_latents(pipe, latents):
13
+ video = pipe.decode_latents(latents)
14
+ video = pipe.video_processor.postprocess_video(video=video, output_type="np")
15
+ return video
16
+
17
+ def create_attention_mask(text_length: int, seq_length: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
18
+ """
19
+ Create an attention mask to block text from attending to alpha.
20
+
21
+ Args:
22
+ text_length: Length of the text sequence.
23
+ seq_length: Length of the other sequence.
24
+ device: The device where the mask will be stored.
25
+ dtype: The data type of the mask tensor.
26
+
27
+ Returns:
28
+ An attention mask tensor.
29
+ """
30
+ total_length = text_length + seq_length
31
+ dense_mask = torch.ones((total_length, total_length), dtype=torch.bool)
32
+ dense_mask[:text_length, text_length + seq_length // 2:] = False
33
+ return dense_mask.to(device=device, dtype=dtype)
34
+
35
+ class RGBALoRACogVideoXAttnProcessor:
36
+ r"""
37
+ Processor for implementing scaled dot-product attention for the CogVideoX model.
38
+ It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
39
+ """
40
+
41
+ def __init__(self, device, dtype, attention_mask, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
42
+ if not hasattr(F, "scaled_dot_product_attention"):
43
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0 or later.")
44
+
45
+ # Initialize LoRA layers
46
+ self.lora_alpha = lora_alpha
47
+ self.lora_rank = lora_rank
48
+
49
+ # Helper function to create LoRA layers
50
+ def create_lora_layer(in_dim, mid_dim, out_dim):
51
+ return nn.Sequential(
52
+ nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype),
53
+ nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype)
54
+ )
55
+
56
+ self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
57
+ self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
58
+ self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
59
+ self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
60
+
61
+ # Store attention mask
62
+ self.attention_mask = attention_mask
63
+
64
+ def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
65
+ """Applies LoRA updates to query, key, and value tensors."""
66
+ query_delta = self.to_q_lora(hidden_states).to(query.device)
67
+ query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling
68
+
69
+ key_delta = self.to_k_lora(hidden_states).to(key.device)
70
+ key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling
71
+
72
+ value_delta = self.to_v_lora(hidden_states).to(value.device)
73
+ value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling
74
+
75
+ return query, key, value
76
+
77
+ def _apply_rotary_embedding(self, query, key, image_rotary_emb, seq_len, text_seq_length, attn):
78
+ """Applies rotary embeddings to query and key tensors."""
79
+ from diffusers.models.embeddings import apply_rotary_emb
80
+
81
+ # Apply rotary embedding to RGB and alpha sections
82
+ query[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb(
83
+ query[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb)
84
+ query[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb(
85
+ query[:, :, text_seq_length + seq_len // 2:], image_rotary_emb)
86
+
87
+ if not attn.is_cross_attention:
88
+ key[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb(
89
+ key[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb)
90
+ key[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb(
91
+ key[:, :, text_seq_length + seq_len // 2:], image_rotary_emb)
92
+
93
+ return query, key
94
+
95
+ def __call__(
96
+ self,
97
+ attn,
98
+ hidden_states: torch.Tensor,
99
+ encoder_hidden_states: torch.Tensor,
100
+ attention_mask: Optional[torch.Tensor] = None,
101
+ image_rotary_emb: Optional[torch.Tensor] = None,
102
+ ) -> torch.Tensor:
103
+ # Concatenate encoder and decoder hidden states
104
+ text_seq_length = encoder_hidden_states.size(1)
105
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
106
+
107
+ batch_size, sequence_length, _ = hidden_states.shape
108
+ seq_len = hidden_states.shape[1] - text_seq_length
109
+ scaling = self.lora_alpha / self.lora_rank
110
+
111
+ # Apply LoRA to query, key, value
112
+ query = attn.to_q(hidden_states)
113
+ key = attn.to_k(hidden_states)
114
+ value = attn.to_v(hidden_states)
115
+
116
+ query, key, value = self._apply_lora(hidden_states, seq_len, query, key, value, scaling)
117
+
118
+ # Reshape query, key, value for multi-head attention
119
+ inner_dim = key.shape[-1]
120
+ head_dim = inner_dim // attn.heads
121
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
122
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
123
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
124
+
125
+ # Normalize query and key if required
126
+ if attn.norm_q is not None:
127
+ query = attn.norm_q(query)
128
+ if attn.norm_k is not None:
129
+ key = attn.norm_k(key)
130
+
131
+ # Apply rotary embeddings if provided
132
+ if image_rotary_emb is not None:
133
+ query, key = self._apply_rotary_embedding(query, key, image_rotary_emb, seq_len, text_seq_length, attn)
134
+
135
+ # Compute scaled dot-product attention
136
+ hidden_states = F.scaled_dot_product_attention(
137
+ query, key, value, attn_mask=self.attention_mask, dropout_p=0.0, is_causal=False
138
+ )
139
+
140
+ # Reshape the output tensor back to the original shape
141
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
142
+
143
+ # Apply linear projection and LoRA to the output
144
+ original_hidden_states = attn.to_out[0](hidden_states)
145
+ hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device)
146
+ original_hidden_states[:, -seq_len // 2:, :] += hidden_states_delta[:, -seq_len // 2:, :] * scaling
147
+
148
+ # Apply dropout
149
+ hidden_states = attn.to_out[1](original_hidden_states)
150
+
151
+ # Split back into encoder and decoder hidden states
152
+ encoder_hidden_states, hidden_states = hidden_states.split(
153
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
154
+ )
155
+
156
+ return hidden_states, encoder_hidden_states
157
+
158
+ def prepare_for_rgba_inference(
159
+ model, rgba_weights_path: str, device: torch.device, dtype: torch.dtype,
160
+ lora_rank: int = 128, lora_alpha: float = 1.0, text_length: int = 226, seq_length: int = 35100
161
+ ):
162
+ def load_lora_sequential_weights(lora_layer, lora_layers, prefix):
163
+ lora_layer[0].load_state_dict({'weight': lora_layers[f"{prefix}.lora_A.weight"]})
164
+ lora_layer[1].load_state_dict({'weight': lora_layers[f"{prefix}.lora_B.weight"]})
165
+
166
+
167
+ rgba_weights = load_file(rgba_weights_path)
168
+ aux_emb = rgba_weights['domain_emb']
169
+
170
+ attention_mask = create_attention_mask(text_length, seq_length, device, dtype)
171
+ attn_procs = {}
172
+
173
+ for name in model.attn_processors.keys():
174
+ attn_processor = RGBALoRACogVideoXAttnProcessor(
175
+ device=device, dtype=dtype, attention_mask=attention_mask,
176
+ lora_rank=lora_rank, lora_alpha=lora_alpha
177
+ )
178
+
179
+ index = name.split('.')[1]
180
+ base_prefix = f'transformer.transformer_blocks.{index}.attn1'
181
+
182
+ for lora_layer, prefix in [
183
+ (attn_processor.to_q_lora, f'{base_prefix}.to_q'),
184
+ (attn_processor.to_k_lora, f'{base_prefix}.to_k'),
185
+ (attn_processor.to_v_lora, f'{base_prefix}.to_v'),
186
+ (attn_processor.to_out_lora, f'{base_prefix}.to_out.0'),
187
+ ]:
188
+ load_lora_sequential_weights(lora_layer, rgba_weights, prefix)
189
+
190
+ attn_procs[name] = attn_processor
191
+
192
+ model.set_attn_processor(attn_procs)
193
+
194
+ def custom_forward(self):
195
+ def forward(
196
+ hidden_states: torch.Tensor,
197
+ encoder_hidden_states: torch.Tensor,
198
+ timestep: Union[int, float, torch.LongTensor],
199
+ timestep_cond: Optional[torch.Tensor] = None,
200
+ ofs: Optional[Union[int, float, torch.LongTensor]] = None,
201
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
202
+ attention_kwargs: Optional[Dict[str, Any]] = None,
203
+ return_dict: bool = True,
204
+ ):
205
+ if attention_kwargs is not None:
206
+ attention_kwargs = attention_kwargs.copy()
207
+ lora_scale = attention_kwargs.pop("scale", 1.0)
208
+ else:
209
+ lora_scale = 1.0
210
+
211
+ if USE_PEFT_BACKEND:
212
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
213
+ scale_lora_layers(self, lora_scale)
214
+ else:
215
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
216
+ logger.warning(
217
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
218
+ )
219
+
220
+ batch_size, num_frames, channels, height, width = hidden_states.shape
221
+
222
+ # 1. Time embedding
223
+ timesteps = timestep
224
+ t_emb = self.time_proj(timesteps)
225
+
226
+ # timesteps does not contain any weights and will always return f32 tensors
227
+ # but time_embedding might actually be running in fp16. so we need to cast here.
228
+ # there might be better ways to encapsulate this.
229
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
230
+ emb = self.time_embedding(t_emb, timestep_cond)
231
+
232
+ if self.ofs_embedding is not None:
233
+ ofs_emb = self.ofs_proj(ofs)
234
+ ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
235
+ ofs_emb = self.ofs_embedding(ofs_emb)
236
+ emb = emb + ofs_emb
237
+
238
+ # 2. Patch embedding
239
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
240
+ hidden_states = self.embedding_dropout(hidden_states)
241
+
242
+ text_seq_length = encoder_hidden_states.shape[1]
243
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
244
+ hidden_states = hidden_states[:, text_seq_length:]
245
+
246
+ hidden_states[:, hidden_states.size(1) // 2:, :] += aux_emb.expand(batch_size, -1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
247
+
248
+ # 3. Transformer blocks
249
+ for i, block in enumerate(self.transformer_blocks):
250
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
251
+
252
+ def create_custom_forward(module):
253
+ def custom_forward(*inputs):
254
+ return module(*inputs)
255
+
256
+ return custom_forward
257
+
258
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
259
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
260
+ create_custom_forward(block),
261
+ hidden_states,
262
+ encoder_hidden_states,
263
+ emb,
264
+ image_rotary_emb,
265
+ **ckpt_kwargs,
266
+ )
267
+ else:
268
+ hidden_states, encoder_hidden_states = block(
269
+ hidden_states=hidden_states,
270
+ encoder_hidden_states=encoder_hidden_states,
271
+ temb=emb,
272
+ image_rotary_emb=image_rotary_emb,
273
+ )
274
+
275
+ if not self.config.use_rotary_positional_embeddings:
276
+ # CogVideoX-2B
277
+ hidden_states = self.norm_final(hidden_states)
278
+ else:
279
+ # CogVideoX-5B
280
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
281
+ hidden_states = self.norm_final(hidden_states)
282
+ hidden_states = hidden_states[:, text_seq_length:]
283
+
284
+ # 4. Final block
285
+ hidden_states = self.norm_out(hidden_states, temb=emb)
286
+ hidden_states = self.proj_out(hidden_states)
287
+
288
+ # 5. Unpatchify
289
+ p = self.config.patch_size
290
+ p_t = self.config.patch_size_t
291
+
292
+ if p_t is None:
293
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
294
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
295
+ else:
296
+ output = hidden_states.reshape(
297
+ batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
298
+ )
299
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
300
+
301
+ if USE_PEFT_BACKEND:
302
+ # remove `lora_scale` from each PEFT layer
303
+ unscale_lora_layers(self, lora_scale)
304
+
305
+ if not return_dict:
306
+ return (output,)
307
+ return Transformer2DModelOutput(sample=output)
308
+
309
+
310
+ return forward
311
+
312
+ model.forward = custom_forward(model)
313
+
LICENSE.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **ADOBE RESEARCH LICENSE**
2
+
3
+ This license agreement (the “License”) between Adobe Inc., having a place of business at 345 Park Avenue, San Jose, California 95110-2704 (“Adobe”), and you, the individual or entity exercising rights under this License (“you” or “your”), sets forth the terms for your use of certain research materials that are owned by Adobe (the “Licensed Materials”). By exercising rights under this License, you accept and agree to be bound by its terms. If you are exercising rights under this License on behalf of an entity, then “you” means you and such entity, and you (personally) represent and warrant that you (personally) have all necessary authority to bind that entity to the terms of this License.
4
+
5
+ 1. **GRANT OF LICENSE.**<br/>
6
+ 1.1 Adobe grants you a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials for noncommercial research purposes only; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, for noncommercial research purposes only, provided that you give recipients a copy of this License upon redistribution.<br/>
7
+ 1.2 You may add your own copyright statement to your modifications and/or provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only.<br/>
8
+ 1.3 For purposes of this License, noncommercial research purposes include academic research and teaching only. Noncommercial research purposes do not include commercial licensing or distribution, development of commercial products, or any other activity that results in commercial gain.<br/>
9
+ 2. **OWNERSHIP AND ATTRIBUTION.** Adobe and its licensors own all right, title, and interest in the Licensed Materials. You must retain all copyright notices and/or disclaimers in the Licensed Materials.
10
+ 3. **DISCLAIMER OF WARRANTIES.** THE LICENSED MATERIALS ARE PROVIDED “AS IS” WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE USE, RESULTS, AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO YOUR USE OF THE LICENSED MATERIALS, INCLUDING, BUT NOT LIMITED TO, NONINFRINGEMENT OF THIRD-PARTY RIGHTS.
11
+ 4. **LIMITATION OF LIABILITY.** IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE LICENSED MATERIALS, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
12
+ 5. **TERM AND TERMINATION.**<br/>
13
+ 5.1 The License is effective upon acceptance by you and will remain in effect unless terminated earlier in accordance with Section 5.2.<br/>
14
+ 5.2 Any breach of any material provision of this License will automatically terminate the rights granted herein.<br/>
15
+ 5.3 Sections 2 (Ownership and Attribution), 3 (Disclaimer of Warranties), 4 (Limitation of Liability) will survive termination of this License.<br/>
Mochi/README.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RGBA LoRA Training Instructions
2
+
3
+ <!-- <table align=center>
4
+ <tr>
5
+ <th align=center> Dataset Sample </th>
6
+ <th align=center> Test Sample </th>
7
+ </tr>
8
+ <tr>
9
+ <td align=center><video src="https://github.com/user-attachments/assets/6f906a32-b169-493f-a713-07679e87cd91"> Your browser does not support the video tag. </video></td>
10
+ <td align=center><video src="https://github.com/user-attachments/assets/d356e70f-ccf4-47f7-be1d-8d21108d8a84"> Your browser does not support the video tag. </video></td>
11
+ </tr>
12
+ </table> -->
13
+ <!--
14
+ Now you can make Mochi-1 your own with `diffusers`, too 🤗 🧨
15
+
16
+ We provide a minimal and faithful reimplementation of the [Mochi-1 original fine-tuner](https://github.com/genmoai/mochi/tree/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner). As usual, we leverage `peft` for things LoRA in our implementation.
17
+
18
+ **Updates**
19
+
20
+ December 1 2024: Support for checkpoint saving and loading. -->
21
+
22
+ We follow the same steps as the original [finetrainers](https://github.com/a-r-r-o-w/finetrainers/blob/main/training/mochi-1/README.md) to prepare the [RGBA dataset](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets).
23
+ For RGBA dataset, you can follow the instructions above to preprocess the dataset yourself.
24
+
25
+ Here are some detailed steps to prepare the dataset for Mochi-1 fine-tuning:
26
+
27
+ 1. Download our preprocessed [Video RGBA dataset](https://hkustgz-my.sharepoint.com/:u:/g/personal/lwang592_connect_hkust-gz_edu_cn/EezKQoum3IVJiJ9c8GebNfYBe-xN0OS5mVUvAwyL_rQLuw?e=1obdbA), which has undergone preprocessing operations such as color decontamination and background blur.
28
+ 2. Use `trim_and_crop_videos.py` to crop and trim the RGB and Alpha videos as needed.
29
+ 3. Use `embed.py` to encode the RGB videos into latent representations and embed the video captions into embeddings.
30
+ 4. Use `embed.py` to encode the Alpha videos into latent representations.
31
+ 5. Concatenate the RGB and Alpha latent representations along the frames dimension.
32
+
33
+ Finally, the dataset should be in the following format:
34
+ ```
35
+ <video_1_concatenated>.latent.pt
36
+ <video_1_captions>.embed.pt
37
+ <video_2_concatenated>.latent.pt
38
+ <video_2_captions>.embed.pt
39
+ ```
40
+
41
+
42
+ Now, we're ready to fine-tune. To launch, run:
43
+
44
+ ```bash
45
+ bash train.sh
46
+ ```
47
+ **Note:**
48
+
49
+ The arg `--num_frames` is used to specify the number of frames of generated **RGB** video. During generation, we will actually double the number of frames to generate the **RGB** video and **Alpha** video jointly. This double operation is automatically handled by our implementation.
50
+
51
+ For an 80GB GPU, we support processing RGB videos with dimensions of 480 × 848 × 79 (Height × Width × Frames) at a batch size of 1 using bfloat16 precision for training. However, the training is relatively slow (over one minute per iteration) because the model processes a total of 79 × 2 frames as input.
52
+
53
+
54
+
55
+
56
+
57
+ ~~We haven't rigorously tested but without validation enabled, this script should run under 40GBs of GPU VRAM.~~
58
+
59
+ ## Inference
60
+
61
+ To generate the RGBA video, run:
62
+
63
+ ```bash
64
+ python cli.py \
65
+ --lora_path /path/to/lora \
66
+ --prompt "..." \
67
+ ```
68
+
69
+ This command generates the RGB and Alpha videos simultaneously and saves them. Specifically, the RGB video is saved in its premultiplied form. To blend this video with any background image, you can simply use the following formula:
70
+
71
+ ```python
72
+ com = rgb + (1 - alpha) * bgr
73
+ ```
74
+
75
+ ## Known limitations
76
+
77
+ (Contributions are welcome 🤗)
78
+
79
+ Our script currently doesn't leverage `accelerate` and some of its consequences are detailed below:
80
+
81
+ * No support for distributed training.
82
+ * `train_batch_size > 1` are supported but can potentially lead to OOMs because we currently don't have gradient accumulation support.
83
+ * No support for 8bit optimizers (but should be relatively easy to add).
84
+
85
+ **Misc**:
86
+
87
+ * We're aware of the quality issues in the `diffusers` implementation of Mochi-1. This is being fixed in [this PR](https://github.com/huggingface/diffusers/pull/10033).
88
+ * `embed.py` script is non-batched.
Mochi/args.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Default values taken from
3
+ https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/configs/lora.yaml
4
+ when applicable.
5
+ """
6
+
7
+ import argparse
8
+
9
+
10
+ def _get_model_args(parser: argparse.ArgumentParser) -> None:
11
+ parser.add_argument(
12
+ "--pretrained_model_name_or_path",
13
+ type=str,
14
+ default=None,
15
+ required=True,
16
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
17
+ )
18
+ parser.add_argument(
19
+ "--revision",
20
+ type=str,
21
+ default=None,
22
+ required=False,
23
+ help="Revision of pretrained model identifier from huggingface.co/models.",
24
+ )
25
+ parser.add_argument(
26
+ "--variant",
27
+ type=str,
28
+ default=None,
29
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
30
+ )
31
+ parser.add_argument(
32
+ "--cache_dir",
33
+ type=str,
34
+ default=None,
35
+ help="The directory where the downloaded models and datasets will be stored.",
36
+ )
37
+ parser.add_argument(
38
+ "--cast_dit",
39
+ action="store_true",
40
+ help="If we should cast DiT params to a lower precision.",
41
+ )
42
+ parser.add_argument(
43
+ "--compile_dit",
44
+ action="store_true",
45
+ help="If we should compile the DiT.",
46
+ )
47
+
48
+
49
+ def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
50
+ parser.add_argument(
51
+ "--data_root",
52
+ type=str,
53
+ default=None,
54
+ help=("A folder containing the training data."),
55
+ )
56
+ parser.add_argument(
57
+ "--caption_dropout",
58
+ type=float,
59
+ default=None,
60
+ help=("Probability to drop out captions randomly."),
61
+ )
62
+
63
+ parser.add_argument(
64
+ "--dataloader_num_workers",
65
+ type=int,
66
+ default=0,
67
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
68
+ )
69
+ parser.add_argument(
70
+ "--pin_memory",
71
+ action="store_true",
72
+ help="Whether or not to use the pinned memory setting in pytorch dataloader.",
73
+ )
74
+
75
+
76
+ def _get_validation_args(parser: argparse.ArgumentParser) -> None:
77
+ parser.add_argument(
78
+ "--validation_prompt",
79
+ type=str,
80
+ default=None,
81
+ help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
82
+ )
83
+ parser.add_argument(
84
+ "--validation_images",
85
+ type=str,
86
+ default=None,
87
+ help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
88
+ )
89
+ parser.add_argument(
90
+ "--validation_prompt_separator",
91
+ type=str,
92
+ default=":::",
93
+ help="String that separates multiple validation prompts",
94
+ )
95
+ parser.add_argument(
96
+ "--num_validation_videos",
97
+ type=int,
98
+ default=1,
99
+ help="Number of videos that should be generated during validation per `validation_prompt`.",
100
+ )
101
+ parser.add_argument(
102
+ "--validation_epochs",
103
+ type=int,
104
+ default=50,
105
+ help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
106
+ )
107
+ parser.add_argument(
108
+ "--enable_slicing",
109
+ action="store_true",
110
+ default=False,
111
+ help="Whether or not to use VAE slicing for saving memory.",
112
+ )
113
+ parser.add_argument(
114
+ "--enable_tiling",
115
+ action="store_true",
116
+ default=False,
117
+ help="Whether or not to use VAE tiling for saving memory.",
118
+ )
119
+ parser.add_argument(
120
+ "--enable_model_cpu_offload",
121
+ action="store_true",
122
+ default=False,
123
+ help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
124
+ )
125
+ parser.add_argument(
126
+ "--fps",
127
+ type=int,
128
+ default=30,
129
+ help="FPS to use when serializing the output videos.",
130
+ )
131
+ parser.add_argument(
132
+ "--height",
133
+ type=int,
134
+ default=480,
135
+ )
136
+ parser.add_argument(
137
+ "--width",
138
+ type=int,
139
+ default=848,
140
+ )
141
+
142
+
143
+ def _get_training_args(parser: argparse.ArgumentParser) -> None:
144
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
145
+ parser.add_argument("--rank", type=int, default=16, help="The rank for LoRA matrices.")
146
+ parser.add_argument(
147
+ "--lora_alpha",
148
+ type=int,
149
+ default=16,
150
+ help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
151
+ )
152
+ parser.add_argument(
153
+ "--target_modules",
154
+ nargs="+",
155
+ type=str,
156
+ default=["to_k", "to_q", "to_v", "to_out.0"],
157
+ help="Target modules to train LoRA for.",
158
+ )
159
+ parser.add_argument(
160
+ "--output_dir",
161
+ type=str,
162
+ default="mochi-lora",
163
+ help="The output directory where the model predictions and checkpoints will be written.",
164
+ )
165
+ parser.add_argument(
166
+ "--train_batch_size",
167
+ type=int,
168
+ default=4,
169
+ help="Batch size (per device) for the training dataloader.",
170
+ )
171
+ parser.add_argument("--num_train_epochs", type=int, default=1)
172
+ parser.add_argument(
173
+ "--max_train_steps",
174
+ type=int,
175
+ default=None,
176
+ help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
177
+ )
178
+ parser.add_argument(
179
+ "--gradient_checkpointing",
180
+ action="store_true",
181
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
182
+ )
183
+ parser.add_argument(
184
+ "--learning_rate",
185
+ type=float,
186
+ default=2e-4,
187
+ help="Initial learning rate (after the potential warmup period) to use.",
188
+ )
189
+ parser.add_argument(
190
+ "--scale_lr",
191
+ action="store_true",
192
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
193
+ )
194
+ parser.add_argument(
195
+ "--lr_warmup_steps",
196
+ type=int,
197
+ default=200,
198
+ help="Number of steps for the warmup in the lr scheduler.",
199
+ )
200
+ parser.add_argument(
201
+ "--checkpointing_steps",
202
+ type=int,
203
+ default=1000,
204
+ )
205
+ parser.add_argument(
206
+ "--resume_from_checkpoint",
207
+ type=str,
208
+ default=None,
209
+ )
210
+
211
+
212
+ def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
213
+ parser.add_argument(
214
+ "--optimizer",
215
+ type=lambda s: s.lower(),
216
+ default="adam",
217
+ choices=["adam", "adamw"],
218
+ help=("The optimizer type to use."),
219
+ )
220
+ parser.add_argument(
221
+ "--weight_decay",
222
+ type=float,
223
+ default=0.01,
224
+ help="Weight decay to use for optimizer.",
225
+ )
226
+
227
+
228
+ def _get_configuration_args(parser: argparse.ArgumentParser) -> None:
229
+ parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
230
+ parser.add_argument(
231
+ "--push_to_hub",
232
+ action="store_true",
233
+ help="Whether or not to push the model to the Hub.",
234
+ )
235
+ parser.add_argument(
236
+ "--hub_token",
237
+ type=str,
238
+ default=None,
239
+ help="The token to use to push to the Model Hub.",
240
+ )
241
+ parser.add_argument(
242
+ "--hub_model_id",
243
+ type=str,
244
+ default=None,
245
+ help="The name of the repository to keep in sync with the local `output_dir`.",
246
+ )
247
+ parser.add_argument(
248
+ "--allow_tf32",
249
+ action="store_true",
250
+ help=(
251
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
252
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
253
+ ),
254
+ )
255
+ parser.add_argument("--report_to", type=str, default=None, help="If logging to wandb.")
256
+
257
+
258
+ def get_args():
259
+ parser = argparse.ArgumentParser(description="Simple example of a training script for Mochi-1.")
260
+
261
+ _get_model_args(parser)
262
+ _get_dataset_args(parser)
263
+ _get_training_args(parser)
264
+ _get_validation_args(parser)
265
+ _get_optimizer_args(parser)
266
+ _get_configuration_args(parser)
267
+
268
+ return parser.parse_args()
Mochi/cli.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ # from diffusers import MochiPipeline
4
+ from pipeline_mochi_rgba import MochiPipeline
5
+ from diffusers.utils import export_to_video
6
+ import argparse
7
+ from rgba_utils import *
8
+ import numpy as np
9
+
10
+
11
+ def main(args):
12
+ # 1. load pipeline
13
+ pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16).to("cuda")
14
+ pipe.enable_vae_tiling()
15
+
16
+ # 2. define prompt and arguments
17
+ pipeline_args = {
18
+ "prompt": args.prompt,
19
+ "guidance_scale": args.guidance_scale,
20
+ "num_inference_steps": args.num_inference_steps,
21
+ "height": args.height,
22
+ "width": args.width,
23
+ "num_frames": args.num_frames,
24
+ "max_sequence_length": 256,
25
+ "output_type": "latent",
26
+ }
27
+
28
+ # 3. prepare rgbx utils
29
+ prepare_for_rgba_inference(
30
+ pipe.transformer,
31
+ device="cuda",
32
+ dtype=torch.bfloat16,
33
+ )
34
+
35
+ if args.lora_path is not None:
36
+ checkpoint = torch.load(args.lora_path, map_location="cpu")
37
+ processor_state_dict = checkpoint["state_dict"]
38
+ load_processor_state_dict(pipe.transformer, processor_state_dict)
39
+
40
+
41
+ # 4. inference
42
+ generator = torch.manual_seed(args.seed) if args.seed else None
43
+ frames_latents = pipe(**pipeline_args, generator=generator).frames
44
+
45
+ frames_latents_rgb, frames_latents_alpha = frames_latents.chunk(2, dim=2)
46
+
47
+ frames_rgb = decode_latents(pipe, frames_latents_rgb)
48
+ frames_alpha = decode_latents(pipe, frames_latents_alpha)
49
+
50
+ pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
51
+ frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
52
+ premultiplied_rgb = frames_rgb * frames_alpha_pooled
53
+
54
+ if os.path.exists(args.output_path) == False:
55
+ os.makedirs(args.output_path)
56
+
57
+ export_to_video(premultiplied_rgb[0], os.path.join(args.output_path, "rgb.mp4"), fps=args.fps)
58
+ export_to_video(frames_alpha_pooled[0], os.path.join(args.output_path, "alpha.mp4"), fps=args.fps)
59
+ export_to_video(frames_rgb[0], os.path.join(args.output_path, "original_rgb.mp4"), fps=args.fps)
60
+
61
+ if __name__ == "__main__":
62
+ parser = argparse.ArgumentParser(description="Generate a video from a text prompt")
63
+ parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
64
+ parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
65
+
66
+ parser.add_argument(
67
+ "--model_path", type=str, default="genmo/mochi-1-preview", help="Path of the pre-trained model use"
68
+ )
69
+ parser.add_argument("--output_path", type=str, default="./output", help="The path save generated video")
70
+ parser.add_argument("--guidance_scale", type=float, default=6, help="The scale for classifier-free guidance")
71
+ parser.add_argument("--num_inference_steps", type=int, default=64, help="Inference steps")
72
+ parser.add_argument("--num_frames", type=int, default=79, help="Number of steps for the inference process")
73
+ parser.add_argument("--width", type=int, default=848, help="Number of steps for the inference process")
74
+ parser.add_argument("--height", type=int, default=480, help="Number of steps for the inference process")
75
+ parser.add_argument("--fps", type=int, default=30, help="Number of steps for the inference process")
76
+ parser.add_argument("--seed", type=int, default=None, help="The seed for reproducibility")
77
+ args = parser.parse_args()
78
+
79
+ main(args)
Mochi/dataset_simple.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from
3
+ https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/dataset.py
4
+ """
5
+
6
+ from pathlib import Path
7
+
8
+ import click
9
+ import torch
10
+ from torch.utils.data import DataLoader, Dataset
11
+
12
+
13
+ def load_to_cpu(x):
14
+ return torch.load(x, map_location=torch.device("cpu"), weights_only=True)
15
+
16
+
17
+ class LatentEmbedDataset(Dataset):
18
+ def __init__(self, file_paths, repeat=1):
19
+ self.items = [
20
+ (Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
21
+ for p in file_paths
22
+ if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file()
23
+ ]
24
+ self.items = self.items * repeat
25
+ print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.")
26
+
27
+ def __len__(self):
28
+ return len(self.items)
29
+
30
+ def __getitem__(self, idx):
31
+ latent_path, embed_path = self.items[idx]
32
+ return load_to_cpu(latent_path), load_to_cpu(embed_path)
33
+
34
+
35
+ @click.command()
36
+ @click.argument("directory", type=click.Path(exists=True, file_okay=False))
37
+ def process_videos(directory):
38
+ dir_path = Path(directory)
39
+ mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
40
+ assert mp4_files, f"No mp4 files found"
41
+
42
+ dataset = LatentEmbedDataset(mp4_files)
43
+ dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
44
+
45
+ for latents, embeds in dataloader:
46
+ print([(k, v.shape) for k, v in latents.items()])
47
+
48
+
49
+ if __name__ == "__main__":
50
+ process_videos()
Mochi/embed.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from:
3
+ https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py
4
+ https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py
5
+ """
6
+
7
+ import click
8
+ import torch
9
+ import torchvision
10
+ from pathlib import Path
11
+ from diffusers import AutoencoderKLMochi, MochiPipeline
12
+ from transformers import T5EncoderModel, T5Tokenizer
13
+ from tqdm.auto import tqdm
14
+
15
+
16
+ def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
17
+ T, H, W = [int(s) for s in shape.split("x")]
18
+ assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"
19
+ video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs")
20
+ fps = metadata["video_fps"]
21
+ video = video.permute(3, 0, 1, 2)
22
+ og_shape = video.shape
23
+ assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
24
+ assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
25
+ assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
26
+ if video.shape[1] > T:
27
+ video = video[:, :T]
28
+ print(f"Trimmed video from {og_shape[1]} to first {T} frames")
29
+ video = video.unsqueeze(0)
30
+ video = video.float() / 127.5 - 1.0
31
+ video = video.to(model.device)
32
+
33
+ assert video.ndim == 5
34
+
35
+ with torch.inference_mode():
36
+ with torch.autocast("cuda", dtype=torch.bfloat16):
37
+ ldist = model._encode(video)
38
+
39
+ torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt"))
40
+
41
+
42
+ @click.command()
43
+ @click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
44
+ @click.option(
45
+ "--model_id",
46
+ type=str,
47
+ help="Repo id. Should be genmo/mochi-1-preview",
48
+ default="genmo/mochi-1-preview",
49
+ )
50
+ @click.option("--shape", default="163x480x848", help="Shape of the video to encode")
51
+ @click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.")
52
+ def batch_process(output_dir: Path, model_id: Path, shape: str, overwrite: bool) -> None:
53
+ """Process all videos and captions in a directory using a single GPU."""
54
+ # comment out when running on unsupported hardware
55
+ torch.backends.cuda.matmul.allow_tf32 = True
56
+ torch.backends.cudnn.allow_tf32 = True
57
+
58
+ # Get all video paths
59
+ video_paths = list(output_dir.glob("**/*.mp4"))
60
+ if not video_paths:
61
+ print(f"No MP4 files found in {output_dir}")
62
+ return
63
+
64
+ text_paths = list(output_dir.glob("**/*.txt"))
65
+ if not text_paths:
66
+ print(f"No text files found in {output_dir}")
67
+ return
68
+
69
+ # load the models
70
+ vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
71
+ text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
72
+ tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
73
+ pipeline = MochiPipeline.from_pretrained(
74
+ model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None
75
+ ).to("cuda")
76
+
77
+ for idx, video_path in tqdm(enumerate(sorted(video_paths))):
78
+ print(f"Processing {video_path}")
79
+ try:
80
+ if video_path.with_suffix(".latent.pt").exists() and not overwrite:
81
+ print(f"Skipping {video_path}")
82
+ continue
83
+
84
+ # encode videos.
85
+ encode_videos(vae, vid_path=video_path, shape=shape)
86
+
87
+ # embed captions.
88
+ prompt_path = Path("/".join(str(video_path).split(".")[:-1]) + ".txt")
89
+ embed_path = prompt_path.with_suffix(".embed.pt")
90
+
91
+ if embed_path.exists() and not overwrite:
92
+ print(f"Skipping {prompt_path} - embeddings already exist")
93
+ continue
94
+
95
+ with open(prompt_path) as f:
96
+ text = f.read().strip()
97
+ with torch.inference_mode():
98
+ conditioning = pipeline.encode_prompt(prompt=[text])
99
+
100
+ conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]}
101
+ torch.save(conditioning, embed_path)
102
+
103
+ except Exception as e:
104
+ import traceback
105
+
106
+ traceback.print_exc()
107
+ print(f"Error processing {video_path}: {str(e)}")
108
+
109
+
110
+ if __name__ == "__main__":
111
+ batch_process()
Mochi/pipeline_mochi_rgba.py ADDED
@@ -0,0 +1,782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
23
+ from diffusers.loaders import Mochi1LoraLoaderMixin
24
+ from diffusers.models.autoencoders import AutoencoderKL
25
+ from diffusers.models.transformers import MochiTransformer3DModel
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (
28
+ is_torch_xla_available,
29
+ logging,
30
+ replace_example_docstring,
31
+ )
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.video_processor import VideoProcessor
34
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
35
+ from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
36
+
37
+
38
+ if is_torch_xla_available():
39
+ import torch_xla.core.xla_model as xm
40
+
41
+ XLA_AVAILABLE = True
42
+ else:
43
+ XLA_AVAILABLE = False
44
+
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+ EXAMPLE_DOC_STRING = """
49
+ Examples:
50
+ ```py
51
+ >>> import torch
52
+ >>> from diffusers import MochiPipeline
53
+ >>> from diffusers.utils import export_to_video
54
+
55
+ >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
56
+ >>> pipe.enable_model_cpu_offload()
57
+ >>> pipe.enable_vae_tiling()
58
+ >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
59
+ >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
60
+ >>> export_to_video(frames, "mochi.mp4")
61
+ ```
62
+ """
63
+
64
+
65
+ def calculate_shift(
66
+ image_seq_len,
67
+ base_seq_len: int = 256,
68
+ max_seq_len: int = 4096,
69
+ base_shift: float = 0.5,
70
+ max_shift: float = 1.16,
71
+ ):
72
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
73
+ b = base_shift - m * base_seq_len
74
+ mu = image_seq_len * m + b
75
+ return mu
76
+
77
+
78
+ # from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
79
+ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
80
+ if linear_steps is None:
81
+ linear_steps = num_steps // 2
82
+ linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
83
+ threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
84
+ quadratic_steps = num_steps - linear_steps
85
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
86
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
87
+ const = quadratic_coef * (linear_steps**2)
88
+ quadratic_sigma_schedule = [
89
+ quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
90
+ ]
91
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
92
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
93
+ return sigma_schedule
94
+
95
+
96
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
97
+ def retrieve_timesteps(
98
+ scheduler,
99
+ num_inference_steps: Optional[int] = None,
100
+ device: Optional[Union[str, torch.device]] = None,
101
+ timesteps: Optional[List[int]] = None,
102
+ sigmas: Optional[List[float]] = None,
103
+ **kwargs,
104
+ ):
105
+ r"""
106
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
107
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
108
+
109
+ Args:
110
+ scheduler (`SchedulerMixin`):
111
+ The scheduler to get timesteps from.
112
+ num_inference_steps (`int`):
113
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
114
+ must be `None`.
115
+ device (`str` or `torch.device`, *optional*):
116
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
117
+ timesteps (`List[int]`, *optional*):
118
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
119
+ `num_inference_steps` and `sigmas` must be `None`.
120
+ sigmas (`List[float]`, *optional*):
121
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
122
+ `num_inference_steps` and `timesteps` must be `None`.
123
+
124
+ Returns:
125
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
126
+ second element is the number of inference steps.
127
+ """
128
+ if timesteps is not None and sigmas is not None:
129
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
130
+ if timesteps is not None:
131
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
132
+ if not accepts_timesteps:
133
+ raise ValueError(
134
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
135
+ f" timestep schedules. Please check whether you are using the correct scheduler."
136
+ )
137
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
138
+ timesteps = scheduler.timesteps
139
+ num_inference_steps = len(timesteps)
140
+ elif sigmas is not None:
141
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
142
+ if not accept_sigmas:
143
+ raise ValueError(
144
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
145
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
146
+ )
147
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
148
+ timesteps = scheduler.timesteps
149
+ num_inference_steps = len(timesteps)
150
+ else:
151
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
152
+ timesteps = scheduler.timesteps
153
+ return timesteps, num_inference_steps
154
+
155
+
156
+
157
+
158
+
159
+
160
+ def prepare_attention_mask(prompt_attention_mask, latents):
161
+
162
+ device = prompt_attention_mask.device
163
+
164
+ (_, _, num_frames, height, width) = latents.shape # shape of two modalities
165
+ seq_length = (height // 2) * (width // 2) * num_frames
166
+
167
+ rect_attention_mask = []
168
+ for prompt_attention_mask_i in prompt_attention_mask:
169
+ text_length = torch.sum(prompt_attention_mask_i).item()
170
+ total_length = text_length + seq_length
171
+
172
+ if text_length == 0:
173
+ rect_attention_mask.append(None)
174
+ else:
175
+ dense_mask = torch.ones((total_length, total_length), dtype=torch.bool)
176
+ dense_mask[seq_length:, seq_length // 2: seq_length] = False
177
+ rect_attention_mask.append(dense_mask.to(device))
178
+
179
+ return {
180
+ "prompt_attention_mask": prompt_attention_mask,
181
+ "rect_attention_mask": rect_attention_mask,
182
+ }
183
+
184
+
185
+ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
186
+ r"""
187
+ The mochi pipeline for text-to-video generation.
188
+
189
+ Reference: https://github.com/genmoai/models
190
+
191
+ Args:
192
+ transformer ([`MochiTransformer3DModel`]):
193
+ Conditional Transformer architecture to denoise the encoded video latents.
194
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
195
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
196
+ vae ([`AutoencoderKL`]):
197
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
198
+ text_encoder ([`T5EncoderModel`]):
199
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
200
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
201
+ tokenizer (`CLIPTokenizer`):
202
+ Tokenizer of class
203
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
204
+ tokenizer (`T5TokenizerFast`):
205
+ Second Tokenizer of class
206
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
207
+ """
208
+
209
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
210
+ _optional_components = []
211
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
212
+
213
+ def __init__(
214
+ self,
215
+ scheduler: FlowMatchEulerDiscreteScheduler,
216
+ vae: AutoencoderKL,
217
+ text_encoder: T5EncoderModel,
218
+ tokenizer: T5TokenizerFast,
219
+ transformer: MochiTransformer3DModel,
220
+ force_zeros_for_empty_prompt: bool = False,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.register_modules(
225
+ vae=vae,
226
+ text_encoder=text_encoder,
227
+ tokenizer=tokenizer,
228
+ transformer=transformer,
229
+ scheduler=scheduler,
230
+ )
231
+ # TODO: determine these scaling factors from model parameters
232
+ self.vae_spatial_scale_factor = 8
233
+ self.vae_temporal_scale_factor = 6
234
+ self.patch_size = 2
235
+
236
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
237
+ self.tokenizer_max_length = (
238
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
239
+ )
240
+ self.default_height = 480
241
+ self.default_width = 848
242
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
243
+
244
+ def _get_t5_prompt_embeds(
245
+ self,
246
+ prompt: Union[str, List[str]] = None,
247
+ num_videos_per_prompt: int = 1,
248
+ max_sequence_length: int = 256,
249
+ device: Optional[torch.device] = None,
250
+ dtype: Optional[torch.dtype] = None,
251
+ ):
252
+ device = device or self._execution_device
253
+ dtype = dtype or self.text_encoder.dtype
254
+
255
+ prompt = [prompt] if isinstance(prompt, str) else prompt
256
+ batch_size = len(prompt)
257
+
258
+ text_inputs = self.tokenizer(
259
+ prompt,
260
+ padding="max_length",
261
+ max_length=max_sequence_length,
262
+ truncation=True,
263
+ add_special_tokens=True,
264
+ return_tensors="pt",
265
+ )
266
+
267
+ text_input_ids = text_inputs.input_ids
268
+ prompt_attention_mask = text_inputs.attention_mask
269
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
270
+
271
+ # The original Mochi implementation zeros out empty negative prompts
272
+ # but this can lead to overflow when placing the entire pipeline under the autocast context
273
+ # adding this here so that we can enable zeroing prompts if necessary
274
+ if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
275
+ text_input_ids = torch.zeros_like(text_input_ids, device=device)
276
+ prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
277
+
278
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
279
+
280
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
281
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
282
+ logger.warning(
283
+ "The following part of your input was truncated because `max_sequence_length` is set to "
284
+ f" {max_sequence_length} tokens: {removed_text}"
285
+ )
286
+
287
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
288
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
289
+
290
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
291
+ _, seq_len, _ = prompt_embeds.shape
292
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
293
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
294
+
295
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
296
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
297
+
298
+ return prompt_embeds, prompt_attention_mask
299
+
300
+ # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
301
+ def encode_prompt(
302
+ self,
303
+ prompt: Union[str, List[str]],
304
+ negative_prompt: Optional[Union[str, List[str]]] = None,
305
+ do_classifier_free_guidance: bool = True,
306
+ num_videos_per_prompt: int = 1,
307
+ prompt_embeds: Optional[torch.Tensor] = None,
308
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
309
+ prompt_attention_mask: Optional[torch.Tensor] = None,
310
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
311
+ max_sequence_length: int = 256,
312
+ device: Optional[torch.device] = None,
313
+ dtype: Optional[torch.dtype] = None,
314
+ ):
315
+ r"""
316
+ Encodes the prompt into text encoder hidden states.
317
+
318
+ Args:
319
+ prompt (`str` or `List[str]`, *optional*):
320
+ prompt to be encoded
321
+ negative_prompt (`str` or `List[str]`, *optional*):
322
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
323
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
324
+ less than `1`).
325
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
326
+ Whether to use classifier free guidance or not.
327
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
328
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
329
+ prompt_embeds (`torch.Tensor`, *optional*):
330
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
331
+ provided, text embeddings will be generated from `prompt` input argument.
332
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
333
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
334
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
335
+ argument.
336
+ device: (`torch.device`, *optional*):
337
+ torch device
338
+ dtype: (`torch.dtype`, *optional*):
339
+ torch dtype
340
+ """
341
+ device = device or self._execution_device
342
+
343
+ prompt = [prompt] if isinstance(prompt, str) else prompt
344
+ if prompt is not None:
345
+ batch_size = len(prompt)
346
+ else:
347
+ batch_size = prompt_embeds.shape[0]
348
+
349
+ if prompt_embeds is None:
350
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
351
+ prompt=prompt,
352
+ num_videos_per_prompt=num_videos_per_prompt,
353
+ max_sequence_length=max_sequence_length,
354
+ device=device,
355
+ dtype=dtype,
356
+ )
357
+
358
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
359
+ negative_prompt = negative_prompt or ""
360
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
361
+
362
+ if prompt is not None and type(prompt) is not type(negative_prompt):
363
+ raise TypeError(
364
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
365
+ f" {type(prompt)}."
366
+ )
367
+ elif batch_size != len(negative_prompt):
368
+ raise ValueError(
369
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
370
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
371
+ " the batch size of `prompt`."
372
+ )
373
+
374
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
375
+ prompt=negative_prompt,
376
+ num_videos_per_prompt=num_videos_per_prompt,
377
+ max_sequence_length=max_sequence_length,
378
+ device=device,
379
+ dtype=dtype,
380
+ )
381
+
382
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
383
+
384
+ def check_inputs(
385
+ self,
386
+ prompt,
387
+ height,
388
+ width,
389
+ callback_on_step_end_tensor_inputs=None,
390
+ prompt_embeds=None,
391
+ negative_prompt_embeds=None,
392
+ prompt_attention_mask=None,
393
+ negative_prompt_attention_mask=None,
394
+ ):
395
+ if height % 8 != 0 or width % 8 != 0:
396
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
397
+
398
+ if callback_on_step_end_tensor_inputs is not None and not all(
399
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
400
+ ):
401
+ raise ValueError(
402
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
403
+ )
404
+
405
+ if prompt is not None and prompt_embeds is not None:
406
+ raise ValueError(
407
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
408
+ " only forward one of the two."
409
+ )
410
+ elif prompt is None and prompt_embeds is None:
411
+ raise ValueError(
412
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
413
+ )
414
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
415
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
416
+
417
+ if prompt_embeds is not None and prompt_attention_mask is None:
418
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
419
+
420
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
421
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
422
+
423
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
424
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
425
+ raise ValueError(
426
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
427
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
428
+ f" {negative_prompt_embeds.shape}."
429
+ )
430
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
431
+ raise ValueError(
432
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
433
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
434
+ f" {negative_prompt_attention_mask.shape}."
435
+ )
436
+
437
+ def enable_vae_slicing(self):
438
+ r"""
439
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
440
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
441
+ """
442
+ self.vae.enable_slicing()
443
+
444
+ def disable_vae_slicing(self):
445
+ r"""
446
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
447
+ computing decoding in one step.
448
+ """
449
+ self.vae.disable_slicing()
450
+
451
+ def enable_vae_tiling(self):
452
+ r"""
453
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
454
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
455
+ processing larger images.
456
+ """
457
+ self.vae.enable_tiling()
458
+
459
+ def disable_vae_tiling(self):
460
+ r"""
461
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
462
+ computing decoding in one step.
463
+ """
464
+ self.vae.disable_tiling()
465
+
466
+ def prepare_latents(
467
+ self,
468
+ batch_size,
469
+ num_channels_latents,
470
+ height,
471
+ width,
472
+ num_frames,
473
+ dtype,
474
+ device,
475
+ generator,
476
+ latents=None,
477
+ ):
478
+ height = height // self.vae_spatial_scale_factor
479
+ width = width // self.vae_spatial_scale_factor
480
+ num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
481
+
482
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
483
+
484
+ if latents is not None:
485
+ return latents.to(device=device, dtype=dtype)
486
+ if isinstance(generator, list) and len(generator) != batch_size:
487
+ raise ValueError(
488
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
489
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
490
+ )
491
+
492
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
493
+ latents = latents.to(dtype)
494
+ return latents
495
+
496
+ @property
497
+ def guidance_scale(self):
498
+ return self._guidance_scale
499
+
500
+ @property
501
+ def do_classifier_free_guidance(self):
502
+ return self._guidance_scale > 1.0
503
+
504
+ @property
505
+ def num_timesteps(self):
506
+ return self._num_timesteps
507
+
508
+ @property
509
+ def attention_kwargs(self):
510
+ return self._attention_kwargs
511
+
512
+ @property
513
+ def interrupt(self):
514
+ return self._interrupt
515
+
516
+
517
+ @torch.no_grad()
518
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
519
+ def __call__(
520
+ self,
521
+ prompt: Union[str, List[str]] = None,
522
+ negative_prompt: Optional[Union[str, List[str]]] = None,
523
+ height: Optional[int] = None,
524
+ width: Optional[int] = None,
525
+ num_frames: int = 19,
526
+ num_inference_steps: int = 64,
527
+ timesteps: List[int] = None,
528
+ guidance_scale: float = 4.5,
529
+ num_videos_per_prompt: Optional[int] = 1,
530
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
531
+ latents: Optional[torch.Tensor] = None,
532
+ prompt_embeds: Optional[torch.Tensor] = None,
533
+ prompt_attention_mask: Optional[torch.Tensor] = None,
534
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
535
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
536
+ output_type: Optional[str] = "pil",
537
+ return_dict: bool = True,
538
+ attention_kwargs: Optional[Dict[str, Any]] = None,
539
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
540
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
541
+ max_sequence_length: int = 256,
542
+ ):
543
+ r"""
544
+ Function invoked when calling the pipeline for generation.
545
+
546
+ Args:
547
+ prompt (`str` or `List[str]`, *optional*):
548
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
549
+ instead.
550
+ height (`int`, *optional*, defaults to `self.default_height`):
551
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
552
+ width (`int`, *optional*, defaults to `self.default_width`):
553
+ The width in pixels of the generated image. This is set to 848 by default for the best results.
554
+ num_frames (`int`, defaults to `19`):
555
+ The number of video frames to generate
556
+ num_inference_steps (`int`, *optional*, defaults to 50):
557
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
558
+ expense of slower inference.
559
+ timesteps (`List[int]`, *optional*):
560
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
561
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
562
+ passed will be used. Must be in descending order.
563
+ guidance_scale (`float`, defaults to `4.5`):
564
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
565
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
566
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
567
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
568
+ usually at the expense of lower image quality.
569
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
570
+ The number of videos to generate per prompt.
571
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
572
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
573
+ to make generation deterministic.
574
+ latents (`torch.Tensor`, *optional*):
575
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
576
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
577
+ tensor will ge generated by sampling using the supplied random `generator`.
578
+ prompt_embeds (`torch.Tensor`, *optional*):
579
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
580
+ provided, text embeddings will be generated from `prompt` input argument.
581
+ prompt_attention_mask (`torch.Tensor`, *optional*):
582
+ Pre-generated attention mask for text embeddings.
583
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
584
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
585
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
586
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
587
+ Pre-generated attention mask for negative text embeddings.
588
+ output_type (`str`, *optional*, defaults to `"pil"`):
589
+ The output format of the generate image. Choose between
590
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
591
+ return_dict (`bool`, *optional*, defaults to `True`):
592
+ Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
593
+ attention_kwargs (`dict`, *optional*):
594
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
595
+ `self.processor` in
596
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
597
+ callback_on_step_end (`Callable`, *optional*):
598
+ A function that calls at the end of each denoising steps during the inference. The function is called
599
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
600
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
601
+ `callback_on_step_end_tensor_inputs`.
602
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
603
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
604
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
605
+ `._callback_tensor_inputs` attribute of your pipeline class.
606
+ max_sequence_length (`int` defaults to `256`):
607
+ Maximum sequence length to use with the `prompt`.
608
+
609
+ Examples:
610
+
611
+ Returns:
612
+ [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
613
+ If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
614
+ is returned where the first element is a list with the generated images.
615
+ """
616
+
617
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
618
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
619
+
620
+ height = height or self.default_height
621
+ width = width or self.default_width
622
+
623
+ # 1. Check inputs. Raise error if not correct
624
+ self.check_inputs(
625
+ prompt=prompt,
626
+ height=height,
627
+ width=width,
628
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
629
+ prompt_embeds=prompt_embeds,
630
+ negative_prompt_embeds=negative_prompt_embeds,
631
+ prompt_attention_mask=prompt_attention_mask,
632
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
633
+ )
634
+
635
+ self._guidance_scale = guidance_scale
636
+ self._attention_kwargs = attention_kwargs
637
+ self._interrupt = False
638
+
639
+ # 2. Define call parameters
640
+ if prompt is not None and isinstance(prompt, str):
641
+ batch_size = 1
642
+ elif prompt is not None and isinstance(prompt, list):
643
+ batch_size = len(prompt)
644
+ else:
645
+ batch_size = prompt_embeds.shape[0]
646
+
647
+ device = self._execution_device
648
+ # 3. Prepare text embeddings
649
+ (
650
+ prompt_embeds,
651
+ prompt_attention_mask,
652
+ negative_prompt_embeds,
653
+ negative_prompt_attention_mask,
654
+ ) = self.encode_prompt(
655
+ prompt=prompt,
656
+ negative_prompt=negative_prompt,
657
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
658
+ num_videos_per_prompt=num_videos_per_prompt,
659
+ prompt_embeds=prompt_embeds,
660
+ negative_prompt_embeds=negative_prompt_embeds,
661
+ prompt_attention_mask=prompt_attention_mask,
662
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
663
+ max_sequence_length=max_sequence_length,
664
+ device=device,
665
+ )
666
+ # 4. Prepare latent variables
667
+ num_channels_latents = self.transformer.config.in_channels
668
+ latents = self.prepare_latents(
669
+ batch_size * num_videos_per_prompt,
670
+ num_channels_latents,
671
+ height,
672
+ width,
673
+ num_frames,
674
+ prompt_embeds.dtype,
675
+ device,
676
+ generator,
677
+ latents,
678
+ ).repeat(1,1,2,1,1)
679
+
680
+ if self.do_classifier_free_guidance:
681
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
682
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
683
+
684
+
685
+ # 5.5 Prepare attention rectification masks
686
+ all_attention_mask = prepare_attention_mask(prompt_attention_mask, latents)
687
+
688
+ # 5. Prepare timestep
689
+ # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
690
+ threshold_noise = 0.025
691
+ sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
692
+ sigmas = np.array(sigmas)
693
+
694
+ timesteps, num_inference_steps = retrieve_timesteps(
695
+ self.scheduler,
696
+ num_inference_steps,
697
+ device,
698
+ timesteps,
699
+ sigmas,
700
+ )
701
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
702
+ self._num_timesteps = len(timesteps)
703
+
704
+ # 6. Denoising loop
705
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
706
+ for i, t in enumerate(timesteps):
707
+ if self.interrupt:
708
+ continue
709
+
710
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
711
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
712
+ timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
713
+
714
+ noise_pred = self.transformer(
715
+ hidden_states=latent_model_input,
716
+ encoder_hidden_states=prompt_embeds,
717
+ timestep=timestep,
718
+ encoder_attention_mask=all_attention_mask,
719
+ attention_kwargs=attention_kwargs,
720
+ return_dict=False,
721
+ )[0]
722
+ # Mochi CFG + Sampling runs in FP32
723
+ noise_pred = noise_pred.to(torch.float32)
724
+
725
+ if self.do_classifier_free_guidance:
726
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
727
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
728
+
729
+ # compute the previous noisy sample x_t -> x_t-1
730
+ latents_dtype = latents.dtype
731
+ latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
732
+ latents = latents.to(latents_dtype)
733
+
734
+ if latents.dtype != latents_dtype:
735
+ if torch.backends.mps.is_available():
736
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
737
+ latents = latents.to(latents_dtype)
738
+
739
+ if callback_on_step_end is not None:
740
+ callback_kwargs = {}
741
+ for k in callback_on_step_end_tensor_inputs:
742
+ callback_kwargs[k] = locals()[k]
743
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
744
+
745
+ latents = callback_outputs.pop("latents", latents)
746
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
747
+
748
+ # call the callback, if provided
749
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
750
+ progress_bar.update()
751
+
752
+ if XLA_AVAILABLE:
753
+ xm.mark_step()
754
+
755
+ if output_type == "latent":
756
+ video = latents
757
+ else:
758
+ # unscale/denormalize the latents
759
+ # denormalize with the mean and std if available and not None
760
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
761
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
762
+ if has_latents_mean and has_latents_std:
763
+ latents_mean = (
764
+ torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
765
+ )
766
+ latents_std = (
767
+ torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
768
+ )
769
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
770
+ else:
771
+ latents = latents / self.vae.config.scaling_factor
772
+
773
+ video = self.vae.decode(latents, return_dict=False)[0]
774
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
775
+
776
+ # Offload all models
777
+ self.maybe_free_model_hooks()
778
+
779
+ if not return_dict:
780
+ return (video,)
781
+
782
+ return MochiPipelineOutput(frames=video)
Mochi/prepare_dataset.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ GPU_ID=0
4
+ VIDEO_DIR=video-dataset-disney-organized
5
+ OUTPUT_DIR=videos_prepared
6
+ NUM_FRAMES=37
7
+ RESOLUTION=480x848
8
+
9
+ # Extract width and height from RESOLUTION
10
+ WIDTH=$(echo $RESOLUTION | cut -dx -f1)
11
+ HEIGHT=$(echo $RESOLUTION | cut -dx -f2)
12
+
13
+ python trim_and_crop_videos.py $VIDEO_DIR $OUTPUT_DIR --num_frames=$NUM_FRAMES --resolution=$RESOLUTION --force_upsample
14
+
15
+ CUDA_VISIBLE_DEVICES=$GPU_ID python embed.py $OUTPUT_DIR --shape=${NUM_FRAMES}x${WIDTH}x${HEIGHT}
Mochi/rgba_utils.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Any, Dict, Optional, Tuple, Union
6
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
7
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
8
+
9
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
10
+
11
+ @torch.no_grad()
12
+ def decode_latents(pipe, latents):
13
+ has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
14
+ has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
15
+ if has_latents_mean and has_latents_std:
16
+ latents_mean = (
17
+ torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
18
+ )
19
+ latents_std = (
20
+ torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
21
+ )
22
+ latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
23
+ else:
24
+ latents = latents / pipe.vae.config.scaling_factor
25
+
26
+ video = pipe.vae.decode(latents, return_dict=False)[0]
27
+ video = pipe.video_processor.postprocess_video(video, output_type='np')
28
+
29
+ return video
30
+
31
+
32
+ class RGBALoRAMochiAttnProcessor:
33
+ """Attention processor used in Mochi."""
34
+ def __init__(self, device, dtype, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
35
+ if not hasattr(F, "scaled_dot_product_attention"):
36
+ raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
37
+
38
+
39
+ # Initialize LoRA layers
40
+ self.lora_alpha = lora_alpha
41
+ self.lora_rank = lora_rank
42
+
43
+ # Helper function to create LoRA layers
44
+ def create_lora_layer(in_dim, mid_dim, out_dim, device=device, dtype=dtype):
45
+ # Define the LoRA layers
46
+ lora_a = nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype)
47
+ lora_b = nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype)
48
+
49
+ # Initialize lora_a with random parameters (default initialization)
50
+ nn.init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5)) # or another suitable initialization
51
+
52
+ # Initialize lora_b with zero values
53
+ nn.init.zeros_(lora_b.weight)
54
+
55
+ lora_a.weight.requires_grad = True
56
+ lora_b.weight.requires_grad = True
57
+
58
+ # Combine the layers into a sequential module
59
+ return nn.Sequential(lora_a, lora_b)
60
+
61
+ self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
62
+ self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
63
+ self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
64
+ self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
65
+
66
+ def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
67
+ """Applies LoRA updates to query, key, and value tensors."""
68
+ query_delta = self.to_q_lora(hidden_states).to(query.device)
69
+ query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling
70
+
71
+ key_delta = self.to_k_lora(hidden_states).to(key.device)
72
+ key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling
73
+
74
+ value_delta = self.to_v_lora(hidden_states).to(value.device)
75
+ value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling
76
+
77
+ return query, key, value
78
+
79
+ def __call__(
80
+ self,
81
+ attn,
82
+ hidden_states: torch.Tensor,
83
+ encoder_hidden_states: torch.Tensor,
84
+ attention_mask: Optional[torch.Tensor] = None,
85
+ image_rotary_emb: Optional[torch.Tensor] = None,
86
+ ) -> torch.Tensor:
87
+ query = attn.to_q(hidden_states)
88
+ key = attn.to_k(hidden_states)
89
+ value = attn.to_v(hidden_states)
90
+
91
+ scaling = self.lora_alpha / self.lora_rank
92
+ sequence_length = query.size(1)
93
+ query, key, value = self._apply_lora(hidden_states, sequence_length, query, key, value, scaling)
94
+
95
+ query = query.unflatten(2, (attn.heads, -1))
96
+ key = key.unflatten(2, (attn.heads, -1))
97
+ value = value.unflatten(2, (attn.heads, -1))
98
+
99
+ if attn.norm_q is not None:
100
+ query = attn.norm_q(query)
101
+ if attn.norm_k is not None:
102
+ key = attn.norm_k(key)
103
+
104
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
105
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
106
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
107
+
108
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
109
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
110
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
111
+
112
+ if attn.norm_added_q is not None:
113
+ encoder_query = attn.norm_added_q(encoder_query)
114
+ if attn.norm_added_k is not None:
115
+ encoder_key = attn.norm_added_k(encoder_key)
116
+
117
+ if image_rotary_emb is not None:
118
+
119
+ def apply_rotary_emb(x, freqs_cos, freqs_sin):
120
+ x_even = x[..., 0::2].float()
121
+ x_odd = x[..., 1::2].float()
122
+
123
+ cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
124
+ sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
125
+
126
+ return torch.stack([cos, sin], dim=-1).flatten(-2)
127
+
128
+ query[:,sequence_length//2:] = apply_rotary_emb(query[:,sequence_length//2:], *image_rotary_emb)
129
+ query[:,:sequence_length//2] = apply_rotary_emb(query[:,:sequence_length//2], *image_rotary_emb)
130
+
131
+ key[:,sequence_length//2:] = apply_rotary_emb(key[:,sequence_length//2:], *image_rotary_emb)
132
+ key[:,:sequence_length//2] = apply_rotary_emb(key[:,:sequence_length//2], *image_rotary_emb)
133
+ # query = apply_rotary_emb(query, *image_rotary_emb)
134
+ # key = apply_rotary_emb(key, *image_rotary_emb)
135
+
136
+
137
+ query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
138
+ encoder_query, encoder_key, encoder_value = (
139
+ encoder_query.transpose(1, 2),
140
+ encoder_key.transpose(1, 2),
141
+ encoder_value.transpose(1, 2),
142
+ )
143
+
144
+ sequence_length = query.size(2)
145
+ encoder_sequence_length = encoder_query.size(2)
146
+ total_length = sequence_length + encoder_sequence_length
147
+
148
+ batch_size, heads, _, dim = query.shape
149
+
150
+ attn_outputs = []
151
+ prompt_attention_mask = attention_mask["prompt_attention_mask"]
152
+ rect_attention_mask = attention_mask["rect_attention_mask"]
153
+ for idx in range(batch_size):
154
+ mask = prompt_attention_mask[idx][None, :] # two components: attention mask and prompt mask
155
+ valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
156
+
157
+ valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
158
+ valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
159
+ valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
160
+
161
+ valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
162
+ valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
163
+ valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
164
+
165
+ attn_output = F.scaled_dot_product_attention(
166
+ valid_query,
167
+ valid_key,
168
+ valid_value,
169
+ dropout_p=0.0,
170
+ attn_mask=rect_attention_mask[idx],
171
+ is_causal=False
172
+ )
173
+ valid_sequence_length = attn_output.size(2)
174
+ attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
175
+ attn_outputs.append(attn_output)
176
+
177
+ hidden_states = torch.cat(attn_outputs, dim=0)
178
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
179
+
180
+ hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
181
+ (sequence_length, encoder_sequence_length), dim=1
182
+ )
183
+
184
+ # linear proj
185
+ original_hidden_states = attn.to_out[0](hidden_states)
186
+ hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device)
187
+ original_hidden_states[:, -sequence_length // 2:, :] += hidden_states_delta[:, -sequence_length // 2:, :] * scaling
188
+ # dropout
189
+ hidden_states = attn.to_out[1](original_hidden_states)
190
+
191
+ if hasattr(attn, "to_add_out"):
192
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
193
+
194
+ return hidden_states, encoder_hidden_states
195
+
196
+ def prepare_for_rgba_inference(
197
+ model, device: torch.device, dtype: torch.dtype,
198
+ lora_rank: int = 128, lora_alpha: float = 1.0
199
+ ):
200
+
201
+ def custom_forward(self):
202
+ def forward(
203
+ hidden_states: torch.Tensor,
204
+ encoder_hidden_states: torch.Tensor,
205
+ timestep: torch.LongTensor,
206
+ encoder_attention_mask: torch.Tensor,
207
+ attention_kwargs: Optional[Dict[str, Any]] = None,
208
+ return_dict: bool = True,
209
+ ) -> torch.Tensor:
210
+ if attention_kwargs is not None:
211
+ attention_kwargs = attention_kwargs.copy()
212
+ lora_scale = attention_kwargs.pop("scale", 1.0)
213
+ else:
214
+ lora_scale = 1.0
215
+
216
+ if USE_PEFT_BACKEND:
217
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
218
+ scale_lora_layers(self, lora_scale)
219
+ else:
220
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
221
+ logger.warning(
222
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
223
+ )
224
+
225
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
226
+ p = self.config.patch_size
227
+
228
+ post_patch_height = height // p
229
+ post_patch_width = width // p
230
+
231
+ temb, encoder_hidden_states = self.time_embed(
232
+ timestep,
233
+ encoder_hidden_states,
234
+ encoder_attention_mask["prompt_attention_mask"],
235
+ hidden_dtype=hidden_states.dtype,
236
+ )
237
+
238
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
239
+ hidden_states = self.patch_embed(hidden_states)
240
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
241
+
242
+ image_rotary_emb = self.rope(
243
+ self.pos_frequencies,
244
+ num_frames // 2, # Identitical PE for RGB and Alpha
245
+ post_patch_height,
246
+ post_patch_width,
247
+ device=hidden_states.device,
248
+ dtype=torch.float32,
249
+ )
250
+
251
+ for i, block in enumerate(self.transformer_blocks):
252
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
253
+
254
+ def create_custom_forward(module):
255
+ def custom_forward(*inputs):
256
+ return module(*inputs)
257
+
258
+ return custom_forward
259
+
260
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
261
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
262
+ create_custom_forward(block),
263
+ hidden_states,
264
+ encoder_hidden_states,
265
+ temb,
266
+ encoder_attention_mask,
267
+ image_rotary_emb,
268
+ **ckpt_kwargs,
269
+ )
270
+ else:
271
+ hidden_states, encoder_hidden_states = block(
272
+ hidden_states=hidden_states,
273
+ encoder_hidden_states=encoder_hidden_states,
274
+ temb=temb,
275
+ encoder_attention_mask=encoder_attention_mask,
276
+ image_rotary_emb=image_rotary_emb,
277
+ )
278
+ hidden_states = self.norm_out(hidden_states, temb)
279
+ hidden_states = self.proj_out(hidden_states)
280
+
281
+ hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
282
+ hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
283
+ output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
284
+
285
+ if USE_PEFT_BACKEND:
286
+ # remove `lora_scale` from each PEFT layer
287
+ unscale_lora_layers(self, lora_scale)
288
+
289
+ if not return_dict:
290
+ return (output,)
291
+ return Transformer2DModelOutput(sample=output)
292
+ return forward
293
+
294
+ for _, block in enumerate(model.transformer_blocks):
295
+ attn_processor = RGBALoRAMochiAttnProcessor(
296
+ device=device,
297
+ dtype=dtype,
298
+ lora_rank=lora_rank,
299
+ lora_alpha=lora_alpha
300
+ )
301
+ # block.attn1.set_processor(attn_processor)
302
+ block.attn1.processor = attn_processor
303
+
304
+ model.forward = custom_forward(model)
305
+
306
+ def get_processor_state_dict(model):
307
+ """Save trainable parameters of processors to a checkpoint."""
308
+ processor_state_dict = {}
309
+
310
+ for index, block in enumerate(model.transformer_blocks):
311
+ if hasattr(block.attn1, "processor"):
312
+ processor = block.attn1.processor
313
+ for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
314
+ if hasattr(processor, attr_name):
315
+ lora_layer = getattr(processor, attr_name)
316
+ for param_name, param in lora_layer.named_parameters():
317
+ key = f"block_{index}.{attr_name}.{param_name}"
318
+ processor_state_dict[key] = param.data.clone()
319
+
320
+ # torch.save({"processor_state_dict": processor_state_dict}, checkpoint_path)
321
+ # print(f"Processor state_dict saved to {checkpoint_path}")
322
+ return processor_state_dict
323
+
324
+ def load_processor_state_dict(model, processor_state_dict):
325
+ """Load trainable parameters of processors from a checkpoint."""
326
+ for index, block in enumerate(model.transformer_blocks):
327
+ if hasattr(block.attn1, "processor"):
328
+ processor = block.attn1.processor
329
+ for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
330
+ if hasattr(processor, attr_name):
331
+ lora_layer = getattr(processor, attr_name)
332
+ for param_name, param in lora_layer.named_parameters():
333
+ key = f"block_{index}.{attr_name}.{param_name}"
334
+ if key in processor_state_dict:
335
+ param.data.copy_(processor_state_dict[key])
336
+ else:
337
+ raise KeyError(f"Missing key {key} in checkpoint.")
338
+
339
+ # Prepare training parameters
340
+ def get_processor_params(processor):
341
+ params = []
342
+ for attr_name in ["to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"]:
343
+ if hasattr(processor, attr_name):
344
+ lora_layer = getattr(processor, attr_name)
345
+ params.extend(p for p in lora_layer.parameters() if p.requires_grad)
346
+ return params
347
+
348
+ def get_all_processor_params(transformer):
349
+ all_params = []
350
+ for block in transformer.transformer_blocks:
351
+ if hasattr(block.attn1, "processor"):
352
+ processor = block.attn1.processor
353
+ all_params.extend(get_processor_params(processor))
354
+ return all_params
Mochi/train.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import random
18
+ from glob import glob
19
+ import math
20
+ import os
21
+ import torch.nn.functional as F
22
+ import numpy as np
23
+ from pathlib import Path
24
+ from typing import Any, Dict, Tuple, List
25
+
26
+ import torch
27
+ import wandb
28
+ from pipeline_mochi_rgba import *
29
+ from diffusers import FlowMatchEulerDiscreteScheduler, MochiTransformer3DModel
30
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
31
+ from diffusers.training_utils import cast_training_params
32
+ from diffusers.utils import export_to_video
33
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
34
+ from huggingface_hub import create_repo, upload_folder
35
+ from torch.utils.data import DataLoader
36
+ from tqdm.auto import tqdm
37
+
38
+
39
+ from args import get_args # isort:skip
40
+ from dataset_simple import LatentEmbedDataset
41
+
42
+ from utils import print_memory, reset_memory # isort:skip
43
+ from rgba_utils import *
44
+
45
+
46
+ # Taken from
47
+ # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139
48
+ def get_cosine_annealing_lr_scheduler(
49
+ optimizer: torch.optim.Optimizer,
50
+ warmup_steps: int,
51
+ total_steps: int,
52
+ ):
53
+ def lr_lambda(step):
54
+ if step < warmup_steps:
55
+ return float(step) / float(max(1, warmup_steps))
56
+ else:
57
+ return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
58
+
59
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
60
+
61
+
62
+ def save_model_card(
63
+ repo_id: str,
64
+ videos=None,
65
+ base_model: str = None,
66
+ validation_prompt=None,
67
+ repo_folder=None,
68
+ fps=30,
69
+ ):
70
+ widget_dict = []
71
+ if videos is not None and len(videos) > 0:
72
+ for i, video in enumerate(videos):
73
+ export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
74
+ widget_dict.append(
75
+ {
76
+ "text": validation_prompt if validation_prompt else " ",
77
+ "output": {"url": f"final_video_{i}.mp4"},
78
+ }
79
+ )
80
+
81
+ model_description = f"""
82
+ # Mochi-1 Preview LoRA Finetune
83
+
84
+ <Gallery />
85
+
86
+ ## Model description
87
+
88
+ This is a lora finetune of the Mochi-1 preview model `{base_model}`.
89
+
90
+ The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adopted from [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
91
+
92
+ ## Download model
93
+
94
+ [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
95
+
96
+ ## Usage
97
+
98
+ Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
99
+
100
+ ```py
101
+ from diffusers import MochiPipeline
102
+ from diffusers.utils import export_to_video
103
+ import torch
104
+
105
+ pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
106
+ pipe.load_lora_weights("CHANGE_ME")
107
+ pipe.enable_model_cpu_offload()
108
+
109
+ with torch.autocast("cuda", torch.bfloat16):
110
+ video = pipe(
111
+ prompt="CHANGE_ME",
112
+ guidance_scale=6.0,
113
+ num_inference_steps=64,
114
+ height=480,
115
+ width=848,
116
+ max_sequence_length=256,
117
+ output_type="np"
118
+ ).frames[0]
119
+ export_to_video(video)
120
+ ```
121
+
122
+ For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
123
+
124
+ """
125
+ model_card = load_or_create_model_card(
126
+ repo_id_or_path=repo_id,
127
+ from_training=True,
128
+ license="apache-2.0",
129
+ base_model=base_model,
130
+ prompt=validation_prompt,
131
+ model_description=model_description,
132
+ widget=widget_dict,
133
+ )
134
+ tags = [
135
+ "text-to-video",
136
+ "diffusers-training",
137
+ "diffusers",
138
+ "lora",
139
+ "mochi-1-preview",
140
+ "mochi-1-preview-diffusers",
141
+ "template:sd-lora",
142
+ ]
143
+
144
+ model_card = populate_model_card(model_card, tags=tags)
145
+ model_card.save(os.path.join(repo_folder, "README.md"))
146
+
147
+
148
+ def log_validation(
149
+ pipe: MochiPipeline,
150
+ args: Dict[str, Any],
151
+ pipeline_args: Dict[str, Any],
152
+ step: int,
153
+ wandb_run: str = None,
154
+ is_final_validation: bool = False,
155
+ ):
156
+ print(
157
+ f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
158
+ )
159
+ phase_name = "test" if is_final_validation else "validation"
160
+
161
+ if not args.enable_model_cpu_offload:
162
+ pipe = pipe.to("cuda")
163
+
164
+ # run inference
165
+ generator = torch.manual_seed(args.seed) if args.seed else None
166
+
167
+ videos = []
168
+ with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
169
+ for _ in range(args.num_validation_videos):
170
+ video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
171
+ videos.append(video)
172
+
173
+ video_filenames = []
174
+ for i, video in enumerate(videos):
175
+ prompt = (
176
+ pipeline_args["prompt"][:25]
177
+ .replace(" ", "_")
178
+ .replace(" ", "_")
179
+ .replace("'", "_")
180
+ .replace('"', "_")
181
+ .replace("/", "_")
182
+ )
183
+ filename = os.path.join(args.output_dir, f"{phase_name}_{str(step)}_video_{i}_{prompt}.mp4")
184
+ export_to_video(video, filename, fps=30)
185
+ video_filenames.append(filename)
186
+
187
+ if wandb_run:
188
+ wandb.log(
189
+ {
190
+ phase_name: [
191
+ wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30)
192
+ for i, filename in enumerate(video_filenames)
193
+ ]
194
+ }
195
+ )
196
+
197
+ return videos
198
+
199
+
200
+ # Adapted from the original code:
201
+ # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578
202
+ def cast_dit(model, dtype):
203
+ for name, module in model.named_modules():
204
+ if isinstance(module, torch.nn.Linear):
205
+ assert any(
206
+ n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"]
207
+ ), f"Unexpected linear layer: {name}"
208
+ module.to(dtype=dtype)
209
+ elif isinstance(module, torch.nn.Conv2d):
210
+ module.to(dtype=dtype)
211
+ return model
212
+
213
+
214
+ def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path):
215
+ # lora_state_dict = get_peft_model_state_dict(model)
216
+ processor_state_dict = get_processor_state_dict(model)
217
+ torch.save(
218
+ {
219
+ "state_dict": processor_state_dict,
220
+ "optimizer": optimizer.state_dict(),
221
+ "lr_scheduler": lr_scheduler.state_dict(),
222
+ "global_step": global_step,
223
+ },
224
+ checkpoint_path,
225
+ )
226
+
227
+
228
+ class CollateFunction:
229
+ def __init__(self, caption_dropout: float = None) -> None:
230
+ self.caption_dropout = caption_dropout
231
+
232
+ def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]:
233
+ ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0)
234
+ z = DiagonalGaussianDistribution(ldists).sample()
235
+ assert torch.isfinite(z).all()
236
+
237
+ # Sample noise which we will add to the samples.
238
+ eps = torch.randn_like(z)
239
+ sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)
240
+
241
+ prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0)
242
+ prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0)
243
+ if self.caption_dropout and random.random() < self.caption_dropout:
244
+ prompt_embeds.zero_()
245
+ prompt_attention_mask = prompt_attention_mask.long()
246
+ prompt_attention_mask.zero_()
247
+ prompt_attention_mask = prompt_attention_mask.bool()
248
+
249
+ return dict(
250
+ z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask
251
+ )
252
+
253
+
254
+ def main(args):
255
+ if not torch.cuda.is_available():
256
+ raise ValueError("Not supported without CUDA.")
257
+
258
+ if args.report_to == "wandb" and args.hub_token is not None:
259
+ raise ValueError(
260
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
261
+ " Please use `huggingface-cli login` to authenticate with the Hub."
262
+ )
263
+
264
+ # Handle the repository creation
265
+ if args.output_dir is not None:
266
+ os.makedirs(args.output_dir, exist_ok=True)
267
+
268
+ # Prepare models and scheduler
269
+ transformer = MochiTransformer3DModel.from_pretrained(
270
+ args.pretrained_model_name_or_path,
271
+ subfolder="transformer",
272
+ revision=args.revision,
273
+ variant=args.variant,
274
+ )
275
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
276
+ args.pretrained_model_name_or_path, subfolder="scheduler"
277
+ )
278
+
279
+ transformer.requires_grad_(False)
280
+ transformer.to("cuda")
281
+ if args.gradient_checkpointing:
282
+ transformer.enable_gradient_checkpointing()
283
+ if args.cast_dit:
284
+ transformer = cast_dit(transformer, torch.bfloat16)
285
+ if args.compile_dit:
286
+ transformer.compile()
287
+
288
+ prepare_for_rgba_inference(
289
+ model=transformer,
290
+ device=torch.device("cuda"),
291
+ dtype=torch.bfloat16,
292
+ # seq_length=seq_length,
293
+ )
294
+ processor_params = get_all_processor_params(transformer)
295
+
296
+ # Enable TF32 for faster training on Ampere GPUs,
297
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
298
+ if args.allow_tf32 and torch.cuda.is_available():
299
+ torch.backends.cuda.matmul.allow_tf32 = True
300
+
301
+ if args.scale_lr:
302
+ args.learning_rate = args.learning_rate * args.train_batch_size
303
+ # only upcast trainable parameters (LoRA) into fp32
304
+
305
+ if not isinstance(processor_params, list):
306
+ processor_params = [processor_params]
307
+ for m in processor_params:
308
+ for param in m:
309
+ # only upcast trainable parameters into fp32
310
+ if param.requires_grad:
311
+ param.data = param.to(torch.float32)
312
+
313
+ # Prepare optimizer
314
+ transformer_lora_parameters = processor_params # list(filter(lambda p: p.requires_grad, transformer.parameters()))
315
+ num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters)
316
+ optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
317
+
318
+ # Dataset and DataLoader
319
+ train_vids = list(sorted(glob(f"{args.data_root}/*.mp4")))
320
+ train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
321
+ print(f"Found {len(train_vids)} training videos in {args.data_root}")
322
+ assert len(train_vids) > 0, f"No training data found in {args.data_root}"
323
+
324
+ collate_fn = CollateFunction(caption_dropout=args.caption_dropout)
325
+ train_dataset = LatentEmbedDataset(train_vids, repeat=1)
326
+ train_dataloader = DataLoader(
327
+ train_dataset,
328
+ collate_fn=collate_fn,
329
+ batch_size=args.train_batch_size,
330
+ num_workers=args.dataloader_num_workers,
331
+ pin_memory=args.pin_memory,
332
+ )
333
+
334
+ # LR scheduler and math around the number of training steps.
335
+ overrode_max_train_steps = False
336
+ num_update_steps_per_epoch = len(train_dataloader)
337
+ if args.max_train_steps is None:
338
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
339
+ overrode_max_train_steps = True
340
+
341
+ lr_scheduler = get_cosine_annealing_lr_scheduler(
342
+ optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps
343
+ )
344
+
345
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
346
+ num_update_steps_per_epoch = len(train_dataloader)
347
+ if overrode_max_train_steps:
348
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
349
+ # Afterwards we recalculate our number of training epochs
350
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
351
+
352
+ # We need to initialize the trackers we use, and also store our configuration.
353
+ # The trackers initializes automatically on the main process.
354
+ wandb_run = None
355
+ if args.report_to == "wandb":
356
+ tracker_name = args.tracker_name or "mochi-1-rgba-lora"
357
+ wandb_run = wandb.init(project=tracker_name, config=vars(args))
358
+
359
+ # Resume from checkpoint if specified
360
+ if args.resume_from_checkpoint:
361
+ checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
362
+ if "global_step" in checkpoint:
363
+ global_step = checkpoint["global_step"]
364
+ if "optimizer" in checkpoint:
365
+ optimizer.load_state_dict(checkpoint["optimizer"])
366
+ if "lr_scheduler" in checkpoint:
367
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
368
+
369
+ # set_peft_model_state_dict(transformer, checkpoint["state_dict"]) # Luozhou: modify this line
370
+
371
+ processor_state_dict = checkpoint["state_dict"]
372
+ load_processor_state_dict(transformer, processor_state_dict)
373
+
374
+ print(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
375
+ print(f"Resuming from global step: {global_step}")
376
+ else:
377
+ global_step = 0
378
+
379
+ print("===== Memory before training =====")
380
+ reset_memory("cuda")
381
+ print_memory("cuda")
382
+
383
+ # Train!
384
+ total_batch_size = args.train_batch_size
385
+ print("***** Running training *****")
386
+ print(f" Num trainable parameters = {num_trainable_parameters}")
387
+ print(f" Num examples = {len(train_dataset)}")
388
+ print(f" Num batches each epoch = {len(train_dataloader)}")
389
+ print(f" Num epochs = {args.num_train_epochs}")
390
+ print(f" Instantaneous batch size per device = {args.train_batch_size}")
391
+ print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
392
+ print(f" Total optimization steps = {args.max_train_steps}")
393
+
394
+ first_epoch = 0
395
+ progress_bar = tqdm(
396
+ range(0, args.max_train_steps),
397
+ initial=global_step,
398
+ desc="Steps",
399
+ )
400
+ for epoch in range(first_epoch, args.num_train_epochs):
401
+ transformer.train()
402
+
403
+ for step, batch in enumerate(train_dataloader):
404
+ with torch.no_grad():
405
+ z = batch["z"].to("cuda")
406
+ eps = batch["eps"].to("cuda")
407
+ sigma = batch["sigma"].to("cuda")
408
+ prompt_embeds = batch["prompt_embeds"].to("cuda")
409
+ prompt_attention_mask = batch["prompt_attention_mask"].to("cuda")
410
+
411
+ all_attention_mask = prepare_attention_mask(
412
+ prompt_attention_mask=prompt_attention_mask,
413
+ latents=z
414
+ )
415
+
416
+ sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1]
417
+ # Add noise according to flow matching.
418
+ # zt = (1 - texp) * x + texp * z1
419
+ z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
420
+ ut = z - eps
421
+
422
+ # (1 - sigma) because of
423
+ # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
424
+ # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation.
425
+ timesteps = (1 - sigma) * scheduler.config.num_train_timesteps
426
+
427
+ with torch.autocast("cuda", torch.bfloat16):
428
+ model_pred = transformer(
429
+ hidden_states=z_sigma,
430
+ encoder_hidden_states=prompt_embeds,
431
+ encoder_attention_mask=all_attention_mask,
432
+ timestep=timesteps,
433
+ return_dict=False,
434
+ )[0]
435
+ assert model_pred.shape == z.shape
436
+ loss = F.mse_loss(model_pred.float(), ut.float())
437
+ loss.backward()
438
+
439
+ optimizer.step()
440
+ optimizer.zero_grad()
441
+ lr_scheduler.step()
442
+
443
+ progress_bar.update(1)
444
+ global_step += 1
445
+
446
+ last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
447
+ logs = {"loss": loss.detach().item(), "lr": last_lr}
448
+ progress_bar.set_postfix(**logs)
449
+ if wandb_run:
450
+ wandb_run.log(logs, step=global_step)
451
+
452
+ if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0:
453
+ print(f"Saving checkpoint at step {global_step}")
454
+ checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
455
+ save_checkpoint(
456
+ transformer,
457
+ optimizer,
458
+ lr_scheduler,
459
+ global_step,
460
+ checkpoint_path,
461
+ )
462
+
463
+ # if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
464
+ print("===== Memory before validation =====")
465
+ print_memory("cuda")
466
+
467
+ transformer.eval()
468
+ pipe = MochiPipeline.from_pretrained(
469
+ args.pretrained_model_name_or_path,
470
+ transformer=transformer,
471
+ scheduler=scheduler,
472
+ revision=args.revision,
473
+ variant=args.variant,
474
+ )
475
+
476
+ if args.enable_slicing:
477
+ pipe.vae.enable_slicing()
478
+ if args.enable_tiling:
479
+ pipe.vae.enable_tiling()
480
+ if args.enable_model_cpu_offload:
481
+ pipe.enable_model_cpu_offload()
482
+
483
+ # validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
484
+ validation_prompts = [
485
+ "A boy in a white shirt and shorts is seen bouncing a ball, isolated background",
486
+ ]
487
+ for validation_prompt in validation_prompts:
488
+ pipeline_args = {
489
+ "prompt": validation_prompt,
490
+ "guidance_scale": 6.0,
491
+ "num_frames": 37,
492
+ "num_inference_steps": 64,
493
+ "height": args.height,
494
+ "width": args.width,
495
+ "max_sequence_length": 256,
496
+ }
497
+ log_validation(
498
+ pipe=pipe,
499
+ args=args,
500
+ pipeline_args=pipeline_args,
501
+ step=global_step,
502
+ wandb_run=wandb_run,
503
+ )
504
+
505
+ print("===== Memory after validation =====")
506
+ print_memory("cuda")
507
+ reset_memory("cuda")
508
+
509
+ del pipe.text_encoder
510
+ del pipe.vae
511
+ del pipe
512
+ gc.collect()
513
+ torch.cuda.empty_cache()
514
+
515
+ transformer.train()
516
+
517
+ if global_step >= args.max_train_steps:
518
+ break
519
+
520
+ if global_step >= args.max_train_steps:
521
+ break
522
+
523
+ transformer.eval()
524
+
525
+ # saving lora weights
526
+ # transformer_lora_layers = get_peft_model_state_dict(transformer)
527
+ # MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers)
528
+
529
+ # Cleanup trained models to save memory
530
+ del transformer
531
+
532
+ gc.collect()
533
+ torch.cuda.empty_cache()
534
+
535
+ # Final test inference
536
+ # validation_outputs = []
537
+ # if args.validation_prompt and args.num_validation_videos > 0:
538
+ # print("===== Memory before testing =====")
539
+ # print_memory("cuda")
540
+ # reset_memory("cuda")
541
+
542
+ # pipe = MochiPipeline.from_pretrained(
543
+ # args.pretrained_model_name_or_path,
544
+ # revision=args.revision,
545
+ # variant=args.variant,
546
+ # )
547
+
548
+
549
+
550
+ # if args.enable_slicing:
551
+ # pipe.vae.enable_slicing()
552
+ # if args.enable_tiling:
553
+ # pipe.vae.enable_tiling()
554
+ # if args.enable_model_cpu_offload:
555
+ # pipe.enable_model_cpu_offload()
556
+
557
+ # # Load LoRA weights
558
+ # # lora_scaling = args.lora_alpha / args.rank
559
+ # # pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora")
560
+ # # pipe.set_adapters(["mochi-lora"], [lora_scaling])
561
+
562
+ # # Run inference
563
+ # validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
564
+ # for validation_prompt in validation_prompts:
565
+ # pipeline_args = {
566
+ # "prompt": validation_prompt,
567
+ # "guidance_scale": 6.0,
568
+ # "num_inference_steps": 64,
569
+ # "height": args.height,
570
+ # "width": args.width,
571
+ # "max_sequence_length": 256,
572
+ # }
573
+
574
+ # video = log_validation(
575
+ # pipe=pipe,
576
+ # args=args,
577
+ # pipeline_args=pipeline_args,
578
+ # epoch=epoch,
579
+ # wandb_run=wandb_run,
580
+ # is_final_validation=True,
581
+ # )
582
+ # validation_outputs.extend(video)
583
+
584
+ # print("===== Memory after testing =====")
585
+ # print_memory("cuda")
586
+ # reset_memory("cuda")
587
+ # torch.cuda.synchronize("cuda")
588
+
589
+
590
+
591
+ if __name__ == "__main__":
592
+ args = get_args()
593
+ main(args)
Mochi/train.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ export NCCL_P2P_DISABLE=1
3
+ export TORCH_NCCL_ENABLE_MONITORING=0
4
+
5
+ GPU_IDS="3"
6
+
7
+ DATA_ROOT="/hpc2hdd/home/lwang592/projects/finetrainers/training/data/video-matte-240k-rgb-prepared-f37"
8
+ MODEL="genmo/mochi-1-preview"
9
+ OUTPUT_PATH="mochi-rgba-lora-f37"
10
+
11
+ cmd="CUDA_VISIBLE_DEVICES=$GPU_IDS python train.py \
12
+ --pretrained_model_name_or_path $MODEL \
13
+ --cast_dit \
14
+ --data_root $DATA_ROOT \
15
+ --seed 42 \
16
+ --output_dir $OUTPUT_PATH \
17
+ --train_batch_size 2 \
18
+ --dataloader_num_workers 4 \
19
+ --pin_memory \
20
+ --caption_dropout 0.0 \
21
+ --max_train_steps 5000 \
22
+ --gradient_checkpointing \
23
+ --enable_slicing \
24
+ --enable_tiling \
25
+ --enable_model_cpu_offload \
26
+ --optimizer adamw \
27
+ --allow_tf32"
28
+
29
+ echo "Running command: $cmd"
30
+ eval $cmd
31
+ echo -ne "-------------------- Finished executing script --------------------\n\n"
Mochi/trim_and_crop_videos.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from:
3
+ https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/trim_and_crop_videos.py
4
+ """
5
+
6
+ from pathlib import Path
7
+ import shutil
8
+
9
+ import click
10
+ from moviepy.editor import VideoFileClip
11
+ from tqdm import tqdm
12
+
13
+
14
+ @click.command()
15
+ @click.argument("folder", type=click.Path(exists=True, dir_okay=True))
16
+ @click.argument("output_folder", type=click.Path(dir_okay=True))
17
+ @click.option("--num_frames", "-f", type=float, default=30, help="Number of frames")
18
+ @click.option("--resolution", "-r", type=str, default="480x848", help="Video resolution")
19
+ @click.option("--force_upsample", is_flag=True, help="Force upsample.")
20
+ def truncate_videos(folder, output_folder, num_frames, resolution, force_upsample):
21
+ """Truncate all MP4 and MOV files in FOLDER to specified number of frames and resolution"""
22
+ input_path = Path(folder)
23
+ output_path = Path(output_folder)
24
+ output_path.mkdir(parents=True, exist_ok=True)
25
+
26
+ # Parse target resolution
27
+ target_height, target_width = map(int, resolution.split("x"))
28
+
29
+ # Calculate duration
30
+ duration = (num_frames / 30) + 0.09
31
+
32
+ # Find all MP4 and MOV files
33
+ video_files = (
34
+ list(input_path.rglob("*.mp4"))
35
+ + list(input_path.rglob("*.MOV"))
36
+ + list(input_path.rglob("*.mov"))
37
+ + list(input_path.rglob("*.MP4"))
38
+ )
39
+
40
+ for file_path in tqdm(video_files):
41
+ try:
42
+ relative_path = file_path.relative_to(input_path)
43
+ output_file = output_path / relative_path.with_suffix(".mp4")
44
+ output_file.parent.mkdir(parents=True, exist_ok=True)
45
+
46
+ click.echo(f"Processing: {file_path}")
47
+ video = VideoFileClip(str(file_path))
48
+
49
+ # Skip if video is too short
50
+ if video.duration < duration:
51
+ click.echo(f"Skipping {file_path} as it is too short")
52
+ continue
53
+
54
+ # Skip if target resolution is larger than input
55
+ if target_width > video.w or target_height > video.h:
56
+ if force_upsample:
57
+ click.echo(
58
+ f"{file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}. So, upsampling the video."
59
+ )
60
+ video = video.resize(width=target_width, height=target_height)
61
+ else:
62
+ click.echo(
63
+ f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}"
64
+ )
65
+ continue
66
+
67
+ # First truncate duration
68
+ truncated = video.subclip(0, duration)
69
+
70
+ # Calculate crop dimensions to maintain aspect ratio
71
+ target_ratio = target_width / target_height
72
+ current_ratio = truncated.w / truncated.h
73
+
74
+ if current_ratio > target_ratio:
75
+ # Video is wider than target ratio - crop width
76
+ new_width = int(truncated.h * target_ratio)
77
+ x1 = (truncated.w - new_width) // 2
78
+ final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height))
79
+ else:
80
+ # Video is taller than target ratio - crop height
81
+ new_height = int(truncated.w / target_ratio)
82
+ y1 = (truncated.h - new_height) // 2
83
+ final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height))
84
+
85
+ # Set output parameters for consistent MP4 encoding
86
+ output_params = {
87
+ "codec": "libx264",
88
+ "audio": False, # Disable audio
89
+ "preset": "medium", # Balance between speed and quality
90
+ "bitrate": "5000k", # Adjust as needed
91
+ }
92
+
93
+ # Set FPS to 30
94
+ final = final.set_fps(30)
95
+
96
+ # Check for a corresponding .txt file
97
+ txt_file_path = file_path.with_suffix(".txt")
98
+ if txt_file_path.exists():
99
+ output_txt_file = output_path / relative_path.with_suffix(".txt")
100
+ output_txt_file.parent.mkdir(parents=True, exist_ok=True)
101
+ shutil.copy(txt_file_path, output_txt_file)
102
+ click.echo(f"Copied {txt_file_path} to {output_txt_file}")
103
+ else:
104
+ # Print warning in bold yellow with a warning emoji
105
+ click.echo(
106
+ f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m"
107
+ )
108
+ output_txt_file = output_path / relative_path.with_suffix(".txt")
109
+ output_txt_file.parent.mkdir(parents=True, exist_ok=True)
110
+ output_txt_file.touch()
111
+
112
+ # Write the output file
113
+ final.write_videofile(str(output_file), **output_params)
114
+
115
+ # Clean up
116
+ video.close()
117
+ truncated.close()
118
+ final.close()
119
+
120
+ except Exception as e:
121
+ click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True)
122
+ raise
123
+
124
+
125
+ if __name__ == "__main__":
126
+ truncate_videos()
Mochi/utils.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import inspect
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ from accelerate import Accelerator
7
+ from accelerate.logging import get_logger
8
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
9
+ from diffusers.utils.torch_utils import is_compiled_module
10
+
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
+ def get_optimizer(
16
+ params_to_optimize,
17
+ optimizer_name: str = "adam",
18
+ learning_rate: float = 1e-3,
19
+ beta1: float = 0.9,
20
+ beta2: float = 0.95,
21
+ beta3: float = 0.98,
22
+ epsilon: float = 1e-8,
23
+ weight_decay: float = 1e-4,
24
+ prodigy_decouple: bool = False,
25
+ prodigy_use_bias_correction: bool = False,
26
+ prodigy_safeguard_warmup: bool = False,
27
+ use_8bit: bool = False,
28
+ use_4bit: bool = False,
29
+ use_torchao: bool = False,
30
+ use_deepspeed: bool = False,
31
+ use_cpu_offload_optimizer: bool = False,
32
+ offload_gradients: bool = False,
33
+ ) -> torch.optim.Optimizer:
34
+ optimizer_name = optimizer_name.lower()
35
+
36
+ # Use DeepSpeed optimzer
37
+ if use_deepspeed:
38
+ from accelerate.utils import DummyOptim
39
+
40
+ return DummyOptim(
41
+ params_to_optimize,
42
+ lr=learning_rate,
43
+ betas=(beta1, beta2),
44
+ eps=epsilon,
45
+ weight_decay=weight_decay,
46
+ )
47
+
48
+ if use_8bit and use_4bit:
49
+ raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
50
+
51
+ if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
52
+ try:
53
+ import torchao
54
+
55
+ torchao.__version__
56
+ except ImportError:
57
+ raise ImportError(
58
+ "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
59
+ )
60
+
61
+ if not use_torchao and use_4bit:
62
+ raise ValueError("4-bit Optimizers are only supported with torchao.")
63
+
64
+ # Optimizer creation
65
+ supported_optimizers = ["adam", "adamw", "prodigy", "came"]
66
+ if optimizer_name not in supported_optimizers:
67
+ logger.warning(
68
+ f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
69
+ )
70
+ optimizer_name = "adamw"
71
+
72
+ if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
73
+ raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
74
+
75
+ if use_8bit:
76
+ try:
77
+ import bitsandbytes as bnb
78
+ except ImportError:
79
+ raise ImportError(
80
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
81
+ )
82
+
83
+ if optimizer_name == "adamw":
84
+ if use_torchao:
85
+ from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
86
+
87
+ optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
88
+ else:
89
+ optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
90
+
91
+ init_kwargs = {
92
+ "betas": (beta1, beta2),
93
+ "eps": epsilon,
94
+ "weight_decay": weight_decay,
95
+ }
96
+
97
+ elif optimizer_name == "adam":
98
+ if use_torchao:
99
+ from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
100
+
101
+ optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
102
+ else:
103
+ optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
104
+
105
+ init_kwargs = {
106
+ "betas": (beta1, beta2),
107
+ "eps": epsilon,
108
+ "weight_decay": weight_decay,
109
+ }
110
+
111
+ elif optimizer_name == "prodigy":
112
+ try:
113
+ import prodigyopt
114
+ except ImportError:
115
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
116
+
117
+ optimizer_class = prodigyopt.Prodigy
118
+
119
+ if learning_rate <= 0.1:
120
+ logger.warning(
121
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
122
+ )
123
+
124
+ init_kwargs = {
125
+ "lr": learning_rate,
126
+ "betas": (beta1, beta2),
127
+ "beta3": beta3,
128
+ "eps": epsilon,
129
+ "weight_decay": weight_decay,
130
+ "decouple": prodigy_decouple,
131
+ "use_bias_correction": prodigy_use_bias_correction,
132
+ "safeguard_warmup": prodigy_safeguard_warmup,
133
+ }
134
+
135
+ elif optimizer_name == "came":
136
+ try:
137
+ import came_pytorch
138
+ except ImportError:
139
+ raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
140
+
141
+ optimizer_class = came_pytorch.CAME
142
+
143
+ init_kwargs = {
144
+ "lr": learning_rate,
145
+ "eps": (1e-30, 1e-16),
146
+ "betas": (beta1, beta2, beta3),
147
+ "weight_decay": weight_decay,
148
+ }
149
+
150
+ if use_cpu_offload_optimizer:
151
+ from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
152
+
153
+ if "fused" in inspect.signature(optimizer_class.__init__).parameters:
154
+ init_kwargs.update({"fused": True})
155
+
156
+ optimizer = CPUOffloadOptimizer(
157
+ params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
158
+ )
159
+ else:
160
+ optimizer = optimizer_class(params_to_optimize, **init_kwargs)
161
+
162
+ return optimizer
163
+
164
+
165
+ def get_gradient_norm(parameters):
166
+ norm = 0
167
+ for param in parameters:
168
+ if param.grad is None:
169
+ continue
170
+ local_norm = param.grad.detach().data.norm(2)
171
+ norm += local_norm.item() ** 2
172
+ norm = norm**0.5
173
+ return norm
174
+
175
+
176
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
177
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
178
+ tw = tgt_width
179
+ th = tgt_height
180
+ h, w = src
181
+ r = h / w
182
+ if r > (th / tw):
183
+ resize_height = th
184
+ resize_width = int(round(th / h * w))
185
+ else:
186
+ resize_width = tw
187
+ resize_height = int(round(tw / w * h))
188
+
189
+ crop_top = int(round((th - resize_height) / 2.0))
190
+ crop_left = int(round((tw - resize_width) / 2.0))
191
+
192
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
193
+
194
+
195
+ def prepare_rotary_positional_embeddings(
196
+ height: int,
197
+ width: int,
198
+ num_frames: int,
199
+ vae_scale_factor_spatial: int = 8,
200
+ patch_size: int = 2,
201
+ patch_size_t: int = None,
202
+ attention_head_dim: int = 64,
203
+ device: Optional[torch.device] = None,
204
+ base_height: int = 480,
205
+ base_width: int = 720,
206
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
207
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
208
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
209
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
210
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
211
+
212
+ if patch_size_t is None:
213
+ # CogVideoX 1.0
214
+ grid_crops_coords = get_resize_crop_region_for_grid(
215
+ (grid_height, grid_width), base_size_width, base_size_height
216
+ )
217
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
218
+ embed_dim=attention_head_dim,
219
+ crops_coords=grid_crops_coords,
220
+ grid_size=(grid_height, grid_width),
221
+ temporal_size=num_frames,
222
+ )
223
+ else:
224
+ # CogVideoX 1.5
225
+ base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
226
+
227
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
228
+ embed_dim=attention_head_dim,
229
+ crops_coords=None,
230
+ grid_size=(grid_height, grid_width),
231
+ temporal_size=base_num_frames,
232
+ grid_type="slice",
233
+ max_size=(base_size_height, base_size_width),
234
+ )
235
+
236
+ freqs_cos = freqs_cos.to(device=device)
237
+ freqs_sin = freqs_sin.to(device=device)
238
+ return freqs_cos, freqs_sin
239
+
240
+
241
+ def reset_memory(device: Union[str, torch.device]) -> None:
242
+ gc.collect()
243
+ torch.cuda.empty_cache()
244
+ torch.cuda.reset_peak_memory_stats(device)
245
+ torch.cuda.reset_accumulated_memory_stats(device)
246
+
247
+
248
+ def print_memory(device: Union[str, torch.device]) -> None:
249
+ memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
250
+ max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
251
+ max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
252
+ print(f"{memory_allocated=:.3f} GB")
253
+ print(f"{max_memory_allocated=:.3f} GB")
254
+ print(f"{max_memory_reserved=:.3f} GB")
255
+
256
+
257
+ def unwrap_model(accelerator: Accelerator, model):
258
+ model = accelerator.unwrap_model(model)
259
+ model = model._orig_mod if is_compiled_module(model) else model
260
+ return model
README.md CHANGED
@@ -1,12 +1,161 @@
1
- ---
2
- title: TransPixelerTest
3
- emoji: 😻
4
- colorFrom: gray
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.35.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: TransPixelerTest
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.35.0
6
+ ---
7
+ ## TransPixeler: Advancing Text-to-Video Generation with Transparency (CVPR2025)
8
+ <br>
9
+ <a href="https://arxiv.org/abs/2501.03006"><img src='https://img.shields.io/badge/arXiv-2501.03006-b31b1b.svg'></a>
10
+ <a href='https://wileewang.github.io/TransPixeler'><img src='https://img.shields.io/badge/Project_Page-TransPixeler-blue'></a>
11
+ <a href='https://huggingface.co/spaces/wileewang/TransPixar'><img src='https://img.shields.io/badge/HuggingFace-TransPixeler-yellow'></a>
12
+ <a href="https://discord.gg/7Xds3Qjr"><img src="https://img.shields.io/badge/Discord-join-blueviolet?logo=discord&amp"></a>
13
+ <a href="https://github.com/wileewang/TransPixar/blob/main/wechat_group.jpg"><img src="https://img.shields.io/badge/Wechat-Join-green?logo=wechat&amp"></a>
14
+ <a href='https://openbayes.com/console/public/tutorials/tKhPalKrDb9'><img src='https://img.shields.io/badge/Demo-OpenBayes贝式计算-blue'></a>
15
+ <br>
16
+
17
+ [Luozhou Wang*](https://wileewang.github.io/),
18
+ [Yijun Li**](https://yijunmaverick.github.io/),
19
+ [Zhifei Chen](),
20
+ [Jui-Hsien Wang](http://juiwang.com/),
21
+ [Zhifei Zhang](https://zzutk.github.io/),
22
+ [He Zhang](https://sites.google.com/site/hezhangsprinter),
23
+ [Zhe Lin](https://sites.google.com/site/zhelin625/home),
24
+ [Ying-Cong Chen†](https://www.yingcong.me)
25
+
26
+ HKUST(GZ), HKUST, Adobe Research.
27
+
28
+ \* Internship Project
29
+ \** Project Lead
30
+ † Corresponding Author
31
+
32
+ Text-to-video generative models have made significant strides, enabling diverse applications in entertainment, advertising, and education. However, generating RGBA video, which includes alpha channels for transparency, remains a challenge due to limited datasets and the difficulty of adapting existing models. Alpha channels are crucial for visual effects (VFX), allowing transparent elements like smoke and reflections to blend seamlessly into scenes.
33
+ We introduce TransPixar, a method to extend pretrained video models for RGBA generation while retaining the original RGB capabilities. TransPixar leverages a diffusion transformer (DiT) architecture, incorporating alpha-specific tokens and using LoRA-based fine-tuning to jointly generate RGB and alpha channels with high consistency. By optimizing attention mechanisms, TransPixeler preserves the strengths of the original RGB model and achieves strong alignment between RGB and alpha channels despite limited training data.
34
+ Our approach effectively generates diverse and consistent RGBA videos, advancing the possibilities for VFX and interactive content creation.
35
+
36
+ <!-- insert a teaser gif -->
37
+ <!-- <img src="assets/mi.gif" width="640" /> -->
38
+
39
+
40
+
41
+ ## 📰 News
42
+
43
+ - **[2025.04.28]** We have introduced a new development branch [`wan`](https://github.com/wileewang/TransPixar/tree/wan) that integrates the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model to support **joint generation** tasks. This branch includes training code tailored for generating both RGB and associated modalities (e.g., segmentation maps, alpha masks) from a shared text prompt.
44
+
45
+ - **[2025.02.26]** **TransPixeler** is accepted by CVPR 2025! See you in Nashville!
46
+
47
+ - **[2025.01.19]** We've renamed our project from **TransPixar** to **TransPixeler**!!
48
+
49
+ - **[2025.01.17]** We’ve created a [Discord group](https://discord.gg/7Xds3Qjr) and a [WeChat group](https://github.com/wileewang/TransPixar/blob/main/wechat_group.jpg)! Everyone is welcome to join for discussions and collaborations.
50
+
51
+ - **[2025.01.14]** Added new tasks to the repository's roadmap, including support for Hunyuan and LTX video models, and ComfyUI integration.
52
+
53
+ - **[2025.01.07]** Released project page, arXiv paper, inference code, and Hugging Face demo.
54
+
55
+
56
+
57
+
58
+ ## 🔥 New Branch for Joint Generation with Wan2.1
59
+
60
+ We have introduced a new development branch [`wan`](https://github.com/wileewang/TransPixar/tree/wan) that integrates the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model to support **joint generation** tasks.
61
+
62
+ In the `wan` branch, we have developed and released training code tailored for joint generation scenarios, enabling the simultaneous generation of RGB videos and associated modalities (e.g., segmentation maps, alpha masks) from a shared text prompt.
63
+
64
+ **Key features of the `wan` branch:**
65
+ - **Integration of Wan2.1**: Leverages the capabilities of the Wan2.1 video generation model for enhanced performance.
66
+ - **Joint Generation Support**: Facilitates the concurrent generation of RGB and paired modality videos.
67
+ - **Dataset Structure**: Expects each sample to include:
68
+ - A primary video file (`001.mp4`) representing the RGB content.
69
+ - A paired secondary video file (`001_seg.mp4`) with a fixed `_seg` suffix, representing the associated modality.
70
+ - A caption text file (`001.txt`) with the same base name as the primary video.
71
+ - **Periodic Evaluation**: Supports periodic video sampling during training by setting `eval_every_step` or `eval_every_epoch` in the configuration.
72
+ - **Customized Pipelines**: Offers tailored training and inference pipelines designed specifically for joint generation tasks.
73
+
74
+ 👉 To utilize the joint generation features, please checkout the [`wan`](https://github.com/wileewang/TransPixar/tree/wan) branch.
75
+
76
+
77
+
78
+
79
+ ## Contents
80
+
81
+ * [Installation](#installation)
82
+ * [TransPixar LoRA Weights](#transpixar-lora-hub)
83
+ * [Training](#training)
84
+ * [Inference](#inference)
85
+ * [Acknowledgement](#acknowledgement)
86
+ * [Citation](#citation)
87
+
88
+
89
+
90
+ ## Installation
91
+
92
+ ```bash
93
+ # For the main branch
94
+ conda create -n TransPixeler python=3.10
95
+ conda activate TransPixeler
96
+ pip install -r requirements.txt
97
+ ```
98
+
99
+ **Note:**
100
+ If you want to use the **Wan2.1 model**, please first checkout the `wan` branch:
101
+
102
+ ```bash
103
+ git checkout wan
104
+ ```
105
+
106
+ ## TransPixeler LoRA Weights
107
+
108
+ Our pipeline is designed to support various video tasks, including Text-to-RGBA Video, Image-to-RGBA Video.
109
+
110
+ We provide the following pre-trained LoRA weights:
111
+
112
+ | Task | Base Model | Frames | LoRA weights | Inference VRAM |
113
+ |---------------|---------------------------------------------------------------|--------|--------------------------------------------------------------------|----------------|
114
+ | T2V + RGBA | [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5b) | 49 | [link](https://huggingface.co/wileewang/TransPixar/blob/main/cogvideox_rgba_lora.safetensors) | ~24GB |
115
+
116
+
117
+ ## Training - RGB + Alpha Joint Generation
118
+ We have open-sourced the training code for **Mochi** on RGBA joint generation. Please refer to the [Mochi README](Mochi/README.md) for details.
119
+
120
+
121
+ ## Inference - Gradio Demo
122
+ In addition to the [Hugging Face online demo](https://huggingface.co/spaces/wileewang/TransPixar), users can also launch a local inference demo based on CogVideoX-5B by running the following command:
123
+
124
+ ```bash
125
+ python app.py
126
+ ```
127
+
128
+ ## Inference - Command Line Interface (CLI)
129
+ To generate RGBA videos, navigate to the corresponding directory for the video model and execute the following command:
130
+ ```bash
131
+ python cli.py \
132
+ --lora_path /path/to/lora \
133
+ --prompt "..."
134
+ ```
135
+
136
+ ---
137
+
138
+ ## Acknowledgement
139
+
140
+ * [finetrainers](https://github.com/a-r-r-o-w/finetrainers): We followed their implementation of Mochi training and inference.
141
+ * [CogVideoX](https://github.com/THUDM/CogVideo): We followed their implementation of CogVideoX training and inference.
142
+
143
+ We are grateful for their exceptional work and generous contribution to the open-source community.
144
+
145
+ ## Citation
146
+
147
+ ```bibtex
148
+ @misc{wang2025transpixeler,
149
+ title={TransPixeler: Advancing Text-to-Video Generation with Transparency},
150
+ author={Luozhou Wang and Yijun Li and Zhifei Chen and Jui-Hsien Wang and Zhifei Zhang and He Zhang and Zhe Lin and Ying-Cong Chen},
151
+ year={2025},
152
+ eprint={2501.03006},
153
+ archivePrefix={arXiv},
154
+ primaryClass={cs.CV},
155
+ url={https://arxiv.org/abs/2501.03006},
156
+ }
157
+ ```
158
+
159
+ ## Star History
160
+
161
+ [![Star History Chart](https://api.star-history.com/svg?repos=wileewang/TransPixeler&type=Date)](https://star-history.com/#wileewang/TransPixeler&Date)
app.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ THis is the main file for the gradio web demo. It uses the CogVideoX-5B model to generate videos gradio web demo.
3
+ set environment variable OPENAI_API_KEY to use the OpenAI API to enhance the prompt.
4
+ Usage:
5
+ OpenAI_API_KEY=your_openai_api_key OPENAI_BASE_URL=https://api.openai.com/v1 python inference/gradio_web_demo.py
6
+ """
7
+
8
+ import math
9
+ import os
10
+ import random
11
+ import threading
12
+ import time
13
+
14
+ import cv2
15
+ import tempfile
16
+ import imageio_ffmpeg
17
+ import gradio as gr
18
+ import torch
19
+ from PIL import Image
20
+ # from diffusers import (
21
+ # CogVideoXPipeline,
22
+ # CogVideoXDPMScheduler,
23
+ # CogVideoXVideoToVideoPipeline,
24
+ # CogVideoXImageToVideoPipeline,
25
+ # CogVideoXTransformer3DModel,
26
+ # )
27
+ from typing import Union, List
28
+ from CogVideoX.pipeline_rgba import CogVideoXPipeline
29
+ from CogVideoX.rgba_utils import *
30
+ from diffusers import CogVideoXDPMScheduler
31
+
32
+ from diffusers.utils import load_video, load_image, export_to_video
33
+ from datetime import datetime, timedelta
34
+
35
+ from diffusers.image_processor import VaeImageProcessor
36
+ import moviepy.editor as mp
37
+ import numpy as np
38
+ from huggingface_hub import hf_hub_download, snapshot_download
39
+ import gc
40
+
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+
43
+ # hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
44
+ hf_hub_download(repo_id="wileewang/TransPixar", filename="cogvideox_rgba_lora.safetensors", local_dir="model_cogvideox_rgba_lora")
45
+ # snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
46
+
47
+ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16)
48
+ # pipe.enable_sequential_cpu_offload()
49
+ pipe.vae.enable_slicing()
50
+ pipe.vae.enable_tiling()
51
+ pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
52
+ seq_length = 2 * (
53
+ (480 // pipe.vae_scale_factor_spatial // 2)
54
+ * (720 // pipe.vae_scale_factor_spatial // 2)
55
+ * ((13 - 1) // pipe.vae_scale_factor_temporal + 1)
56
+ )
57
+ prepare_for_rgba_inference(
58
+ pipe.transformer,
59
+ rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors",
60
+ device="cuda",
61
+ dtype=torch.bfloat16,
62
+ text_length=226,
63
+ seq_length=seq_length, # this is for the creation of attention mask.
64
+ )
65
+
66
+ # pipe.transformer.to(memory_format=torch.channels_last)
67
+ # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
68
+ # pipe_image.transformer.to(memory_format=torch.channels_last)
69
+ # pipe_image.transformer = torch.compile(pipe_image.transformer, mode="max-autotune", fullgraph=True)
70
+
71
+ os.makedirs("./output", exist_ok=True)
72
+ os.makedirs("./gradio_tmp", exist_ok=True)
73
+
74
+ # upscale_model = utils.load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
75
+ # frame_interpolation_model = load_rife_model("model_rife")
76
+
77
+
78
+ sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
79
+ For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
80
+ There are a few rules to follow:
81
+ You will only ever output a single video description per user request.
82
+ When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
83
+ Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
84
+ Video descriptions must have the same num of words as examples below. Extra words will be ignored.
85
+ """
86
+ def save_video(tensor: Union[List[np.ndarray], List[Image.Image]], fps: int = 8, prefix='rgb'):
87
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
88
+ video_path = f"./output/{prefix}_{timestamp}.mp4"
89
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
90
+ export_to_video(tensor, video_path, fps=fps)
91
+ return video_path
92
+
93
+ def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
94
+ width, height = get_video_dimensions(input_video)
95
+
96
+ if width == 720 and height == 480:
97
+ processed_video = input_video
98
+ else:
99
+ processed_video = center_crop_resize(input_video)
100
+ return processed_video
101
+
102
+
103
+ def get_video_dimensions(input_video_path):
104
+ reader = imageio_ffmpeg.read_frames(input_video_path)
105
+ metadata = next(reader)
106
+ return metadata["size"]
107
+
108
+
109
+ def center_crop_resize(input_video_path, target_width=720, target_height=480):
110
+ cap = cv2.VideoCapture(input_video_path)
111
+
112
+ orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
113
+ orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
114
+ orig_fps = cap.get(cv2.CAP_PROP_FPS)
115
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
116
+
117
+ width_factor = target_width / orig_width
118
+ height_factor = target_height / orig_height
119
+ resize_factor = max(width_factor, height_factor)
120
+
121
+ inter_width = int(orig_width * resize_factor)
122
+ inter_height = int(orig_height * resize_factor)
123
+
124
+ target_fps = 8
125
+ ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
126
+ skip = min(5, ideal_skip) # Cap at 5
127
+
128
+ while (total_frames / (skip + 1)) < 49 and skip > 0:
129
+ skip -= 1
130
+
131
+ processed_frames = []
132
+ frame_count = 0
133
+ total_read = 0
134
+
135
+ while frame_count < 49 and total_read < total_frames:
136
+ ret, frame = cap.read()
137
+ if not ret:
138
+ break
139
+
140
+ if total_read % (skip + 1) == 0:
141
+ resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
142
+
143
+ start_x = (inter_width - target_width) // 2
144
+ start_y = (inter_height - target_height) // 2
145
+ cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]
146
+
147
+ processed_frames.append(cropped)
148
+ frame_count += 1
149
+
150
+ total_read += 1
151
+
152
+ cap.release()
153
+
154
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
155
+ temp_video_path = temp_file.name
156
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
157
+ out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
158
+
159
+ for frame in processed_frames:
160
+ out.write(frame)
161
+
162
+ out.release()
163
+
164
+ return temp_video_path
165
+
166
+
167
+
168
+ def infer(
169
+ prompt: str,
170
+ num_inference_steps: int,
171
+ guidance_scale: float,
172
+ seed: int = -1,
173
+ progress=gr.Progress(track_tqdm=True),
174
+ ):
175
+ if seed == -1:
176
+ seed = random.randint(0, 2**8 - 1)
177
+ pipe.to(device)
178
+ video_pt = pipe(
179
+ prompt=prompt + ", isolated background",
180
+ num_videos_per_prompt=1,
181
+ num_inference_steps=num_inference_steps,
182
+ num_frames=13,
183
+ use_dynamic_cfg=True,
184
+ output_type="latent",
185
+ guidance_scale=guidance_scale,
186
+ generator=torch.Generator(device=device).manual_seed(int(seed)),
187
+ ).frames
188
+ # pipe.to("cpu")
189
+ gc.collect()
190
+ return (video_pt, seed)
191
+
192
+
193
+ def convert_to_gif(video_path):
194
+ clip = mp.VideoFileClip(video_path)
195
+ clip = clip.set_fps(8)
196
+ clip = clip.resize(height=240)
197
+ gif_path = video_path.replace(".mp4", ".gif")
198
+ clip.write_gif(gif_path, fps=8)
199
+ return gif_path
200
+
201
+
202
+ def delete_old_files():
203
+ while True:
204
+ now = datetime.now()
205
+ cutoff = now - timedelta(minutes=10)
206
+ directories = ["./output", "./gradio_tmp"]
207
+
208
+ for directory in directories:
209
+ for filename in os.listdir(directory):
210
+ file_path = os.path.join(directory, filename)
211
+ if os.path.isfile(file_path):
212
+ file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
213
+ if file_mtime < cutoff:
214
+ os.remove(file_path)
215
+ time.sleep(600)
216
+
217
+
218
+ threading.Thread(target=delete_old_files, daemon=True).start()
219
+
220
+ with gr.Blocks() as demo:
221
+ gr.HTML("""
222
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
223
+ TransPixar + CogVideoX-5B Huggingface Space🤗
224
+ </div>
225
+ <div style="text-align: center;">
226
+ <a href="https://huggingface.co/wileewang/TransPixar">🤗 TransPixar LoRA Hub</a> |
227
+ <a href="https://github.com/wileewang/TransPixar">🌐 Github</a> |
228
+ <a href="https://arxiv.org/">📜 arxiv </a>
229
+ </div>
230
+ <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
231
+ ⚠️ This demo is for academic research and experiential use only.
232
+ </div>
233
+ """)
234
+ with gr.Row():
235
+ with gr.Column():
236
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
237
+ with gr.Group():
238
+ with gr.Column():
239
+ with gr.Row():
240
+ seed_param = gr.Number(
241
+ label="Inference Seed (Enter a positive number, -1 for random)", value=-1
242
+ )
243
+
244
+ generate_button = gr.Button("🎬 Generate Video")
245
+ with gr.Row():
246
+ gr.Markdown(
247
+ """
248
+ **Note:** The output RGB is a premultiplied version to avoid the color decontamination problem.
249
+ It can directly composite with a background using:
250
+ ```
251
+ composite = rgb + (1 - alpha) * background
252
+ ```
253
+ """
254
+ )
255
+
256
+ with gr.Column():
257
+ rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480)
258
+ alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480)
259
+ with gr.Row():
260
+ download_rgb_video_button = gr.File(label="📥 Download RGB Video", visible=False)
261
+ download_alpha_video_button = gr.File(label="📥 Download Alpha Video", visible=False)
262
+ seed_text = gr.Number(label="Seed Used for Video Generation", visible=False)
263
+
264
+
265
+ def generate(
266
+ prompt,
267
+ seed_value,
268
+ progress=gr.Progress(track_tqdm=True)
269
+ ):
270
+ latents, seed = infer(
271
+ prompt,
272
+ num_inference_steps=25, # NOT Changed 25
273
+ guidance_scale=7.0, # NOT Changed
274
+ seed=seed_value,
275
+ progress=progress,
276
+ )
277
+
278
+ latents_rgb, latents_alpha = latents.chunk(2, dim=1)
279
+
280
+ frames_rgb = decode_latents(pipe, latents_rgb)
281
+ frames_alpha = decode_latents(pipe, latents_alpha)
282
+
283
+ pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
284
+ frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
285
+ premultiplied_rgb = frames_rgb * frames_alpha_pooled
286
+
287
+ rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix='rgb')
288
+ rgb_video_update = gr.update(visible=True, value=rgb_video_path)
289
+
290
+ alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix='alpha')
291
+ alpha_video_update = gr.update(visible=True, value=alpha_video_path)
292
+ seed_update = gr.update(visible=True, value=seed)
293
+
294
+ return rgb_video_path, alpha_video_path, rgb_video_update, alpha_video_update, seed_update
295
+
296
+
297
+ generate_button.click(
298
+ generate,
299
+ inputs=[prompt, seed_param],
300
+ outputs=[rgb_video_output, alpha_video_output, download_rgb_video_button, download_alpha_video_button, seed_text],
301
+ )
302
+
303
+
304
+ if __name__ == "__main__":
305
+ demo.queue(max_size=15)
306
+ demo.launch(share=True)
model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1865b9208cbfaca5ac8d9cec0c0e3788a32834dea11873d2dee653303e72006e
3
+ size 264292376
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.4.0
2
+ torchvision
3
+ torchaudio
4
+ wandb
5
+ gradio
6
+ sentencepiece
7
+ diffusers==0.32.0
8
+ huggingface_hub==0.27.0
9
+ transformers
10
+ imageio>=2.5.0
11
+ imageio-ffmpeg
12
+ moviepy==1.0.3
13
+ opencv-python>=4.5
14
+ accelerate
wechat_group.jpg ADDED

Git LFS Details

  • SHA256: 4f293f623244cd0746ffd745624cacaca0cc8dbc33da1200edb2d3f4c74baf6e
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB