` to show
+ memory requirements per model, per training type with sensible training settings.
+
+ PARALLEL ARGUMENTS
+ ------------------
+ parallel_backend (`str`, defaults to `accelerate`):
+ The parallel backend to use for training. Choose between ['accelerate', 'ptd'].
+ pp_degree (`int`, defaults to `1`):
+ The degree of pipeline parallelism.
+ dp_degree (`int`, defaults to `1`):
+ The degree of data parallelism (number of model replicas).
+ dp_shards (`int`, defaults to `1`):
+ The number of data parallel shards (number of model partitions).
+ cp_degree (`int`, defaults to `1`):
+ The degree of context parallelism.
+ tp_degree (`int`, defaults to `1`):
+ The degree of tensor parallelism.
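+
+ For example (illustrative), on 8 GPUs, `dp_degree=2` with `dp_shards=4` corresponds to two model
+ replicas, each partitioned across four devices.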
+
+ MODEL ARGUMENTS
+ ---------------
+ model_name (`str`):
+ Name of model to train. To get a list of models, run `python train.py --list_models`.
+ pretrained_model_name_or_path (`str`):
+ Path to pretrained model or model identifier from https://huggingface.co/models. The model should be
+ loadable based on specified `model_name`.
+ revision (`str`, defaults to `None`):
+ If provided, the model will be loaded from a specific branch of the model repository.
+ variant (`str`, defaults to `None`):
+ Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk
+ storage requirements.
+ cache_dir (`str`, defaults to `None`):
+ The directory where the downloaded models and datasets will be stored, or loaded from.
+ tokenizer_id (`str`, defaults to `None`):
+ Identifier for the tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+ tokenizer_2_id (`str`, defaults to `None`):
+ Identifier for the second tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+ tokenizer_3_id (`str`, defaults to `None`):
+ Identifier for the third tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+ text_encoder_id (`str`, defaults to `None`):
+ Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+ text_encoder_2_id (`str`, defaults to `None`):
+ Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+ text_encoder_3_id (`str`, defaults to `None`):
+ Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+ transformer_id (`str`, defaults to `None`):
+ Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`.
+ vae_id (`str`, defaults to `None`):
+ Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`.
+ text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+ Data type for the text encoder when generating text embeddings.
+ text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+ Data type for the text encoder 2 when generating text embeddings.
+ text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+ Data type for the text encoder 3 when generating text embeddings.
+ transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+ Data type for the transformer model.
+ vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+ Data type for the VAE model.
+ layerwise_upcasting_modules (`List[str]`, defaults to `[]`):
+ Modules whose weights should be stored in fp8 while computation runs in a higher precision. Choose between ['transformer'].
+ layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`):
+ Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2'].
+ layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`):
+ Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when naively cast to fp8
+ precision (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
+ by default, and recommend adding more layers to the default list based on the model architecture.
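+
+ For example (illustrative), `--layerwise_upcasting_modules transformer` stores the transformer weights in
+ `--layerwise_upcasting_storage_dtype` (fp8) while computation still runs in `transformer_dtype`.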
+
+ DATASET ARGUMENTS
+ -----------------
+ dataset_config (`str`):
+ Path to a dataset config file containing information about the training data. The file can describe one or
+ more datasets in JSON format. It must have a key called "datasets", which is a list of dictionaries. Each
+ dictionary must contain the following keys (see the example at the end of this section):
+ - "data_root": (`str`)
+ The root directory containing the dataset. This parameter must be provided if `dataset_file` is not provided.
+ - "dataset_file": (`str`)
+ Path to a CSV/JSON/JSONL/PARQUET/ARROW/HF_HUB_DATASET file containing metadata for training. This parameter
+ must be provided if `data_root` is not provided.
+ - "dataset_type": (`str`)
+ Type of dataset. Choose between ['image', 'video'].
+ - "id_token": (`str`)
+ Identifier token prepended to each prompt, if provided. This is useful for LoRA-type training
+ of a single subject/concept/style, but is not necessary.
+ - "image_resolution_buckets": (`List[Tuple[int, int]]`)
+ Resolution buckets for images. This should be a list of tuples containing 2 values, where each tuple
+ represents the resolution (height, width). All images will be resized to the nearest bucket resolution.
+ This parameter must be provided if `dataset_type` is 'image'.
+ - "video_resolution_buckets": (`List[Tuple[int, int, int]]`)
+ Resolution buckets for videos. This should be a list of tuples containing 3 values, where each tuple
+ represents the resolution (num_frames, height, width). All videos will be resized to the nearest bucket
+ resolution. This parameter must be provided if `dataset_type` is 'video'.
+ - "reshape_mode": (`str`)
+ All input images/videos are reshaped using this mode. Choose between the following:
+ ["center_crop", "random_crop", "bicubic"].
+ - "remove_common_llm_caption_prefixes": (`boolean`)
+ Whether or not to remove common LLM caption prefixes. See `~constants.py` for the list of common prefixes.
+ dataset_shuffle_buffer_size (`int`, defaults to `1`):
+ The buffer size used for shuffling the dataset before training. The default value of `1` means that the
+ dataset will not be shuffled.
+ enable_precomputation (`bool`, defaults to `False`):
+ Whether or not to precompute the embeddings for the dataset. This is useful for faster training. If set to `True`,
+ the embeddings will be precomputed and saved to disk and loaded as required.
+ precomputation_items (`int`, defaults to `512`):
+ Number of data samples to precompute at once for memory-efficient training. The higher this value,
+ the more disk space will be used to save the precomputed samples (conditions and latents).
+ precomputation_dir (`str`, defaults to `None`):
+ The directory where the precomputed samples will be stored. If not provided, the precomputed samples
+ will be stored in a temporary directory of the output directory.
+ precomputation_once (`bool`, defaults to `False`):
+ Precompute embeddings from all datasets at once before training. This is useful to save time during training
+ with smaller datasets. If set to `False`, disk space is saved by precomputing embeddings on the fly during
+ training as required (that is, embeddings for more data samples are computed once `precomputation_items` of them
+ have been exhausted across all distributed ranks). Make sure to set `precomputation_items` to a reasonable value
+ in line with the size of your dataset(s).
+ precomputation_reuse (`bool`, defaults to `False`):
+ Reuse precomputed embeddings from previous training runs. This is useful to save time during training
+ with medium/large datasets. By default, old precomputed embeddings that exist in the specified precomputation
+ directory, or in the default `{output_dir}/precomputed` directory, will be deleted if this is not set to `True`.
+ This flag is ignored if `enable_precomputation` is `False`. The topology of the distributed training run must be
+ the same as the one used to precompute the embeddings for this to work correctly (this limitation will be
+ addressed in the future).
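+
+ A minimal example of a `dataset_config` file (illustrative paths and values), shown as the Python dict that
+ would be written to the JSON file with `json.dump`:
+
+ {
+     "datasets": [
+         {
+             "data_root": "/path/to/my-dataset",
+             "dataset_type": "video",
+             "id_token": "SKS",
+             "video_resolution_buckets": [[49, 480, 720]],
+             "reshape_mode": "center_crop",
+             "remove_common_llm_caption_prefixes": True,
+         }
+     ]
+ }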
+
+ DATALOADER ARGUMENTS
+ --------------------
+ See https://pytorch.org/docs/stable/data.html for more information.
+
+ dataloader_num_workers (`int`, defaults to `0`):
+ Number of subprocesses to use for data loading. `0` means that the data will be loaded in a blocking manner
+ on the main process.
+ pin_memory (`bool`, defaults to `False`):
+ Whether or not to use the pinned memory setting in PyTorch dataloader. This is useful for faster data loading.
+
+ DIFFUSION ARGUMENTS
+ -------------------
+ flow_resolution_shifting (`bool`, defaults to `False`):
+ Resolution-dependent shifting of timestep schedules.
+ [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206).
+ TODO(aryan): We don't support this yet.
+ flow_base_seq_len (`int`, defaults to `256`):
+ Base number of tokens for images/video when applying resolution-dependent shifting.
+ flow_max_seq_len (`int`, defaults to `4096`):
+ Maximum number of tokens for images/video when applying resolution-dependent shifting.
+ flow_base_shift (`float`, defaults to `0.5`):
+ Base shift for timestep schedules when applying resolution-dependent shifting.
+ flow_max_shift (`float`, defaults to `1.15`):
+ Maximum shift for timestep schedules when applying resolution-dependent shifting.
+ flow_shift (`float`, defaults to `1.0`):
+ Instead of training with uniform/logit-normal sigmas, shift them as (shift * sigma) / (1 + (shift - 1) * sigma).
+ Setting it higher is helpful when training models for high-resolution generation or for producing better
+ samples with a lower number of inference steps (see the worked example at the end of this section).
+ flow_weighting_scheme (`str`, defaults to `none`):
+ We default to the "none" weighting scheme for uniform sampling and uniform loss.
+ Choose between ['sigma_sqrt', 'logit_normal', 'mode', 'cosmap', 'none'].
+ flow_logit_mean (`float`, defaults to `0.0`):
+ Mean to use when using the `'logit_normal'` weighting scheme.
+ flow_logit_std (`float`, defaults to `1.0`):
+ Standard deviation to use when using the `'logit_normal'` weighting scheme.
+ flow_mode_scale (`float`, defaults to `1.29`):
+ Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.
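+
+ Worked example for `flow_shift`: with `flow_shift=3.0`, a sampled sigma of 0.5 becomes
+ (3.0 * 0.5) / (1 + (3.0 - 1) * 0.5) = 0.75, i.e. the sigma distribution is shifted towards higher noise levels.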
+
+ TRAINING ARGUMENTS
+ ------------------
+ training_type (`str`, defaults to `None`):
+ Type of training to perform. Choose between ['lora', 'full-finetune', 'control-lora', 'control-full-finetune'].
+ seed (`int`, defaults to `42`):
+ A seed for reproducible training.
+ batch_size (`int`, defaults to `1`):
+ Per-device batch size.
+ train_steps (`int`, defaults to `1000`):
+ Total number of training steps to perform.
+ max_data_samples (`int`, defaults to `2**64`):
+ Maximum number of data samples observed during training. If this is less than the number required by
+ `train_steps`, training will stop early.
+ gradient_accumulation_steps (`int`, defaults to `1`):
+ Number of gradient steps to accumulate before performing an optimizer step.
+ gradient_checkpointing (`bool`, defaults to `False`):
+ Whether or not to use gradient/activation checkpointing to save memory at the expense of slower
+ backward pass.
+ checkpointing_steps (`int`, defaults to `500`):
+ Save a checkpoint of the training state every X training steps. These checkpoints can serve as final
+ checkpoints (in case they are better than the last one) and are also suitable for resuming training
+ via `resume_from_checkpoint`.
+ checkpointing_limit (`int`, defaults to `None`):
+ Max number of checkpoints to store.
+ resume_from_checkpoint (`str`, defaults to `None`):
+ Can be an integer or the string `"latest"`. If an integer is provided, training will resume from that step if a
+ checkpoint corresponding to it exists. If `"latest"` is provided, training will resume from the latest checkpoint
+ in the `--output_dir`.
+
+ OPTIMIZER ARGUMENTS
+ -------------------
+ optimizer (`str`, defaults to `adamw`):
+ The optimizer type to use. Choose between the following:
+ - Torch optimizers: ["adam", "adamw"]
+ - Bitsandbytes optimizers: ["adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"]
+ lr (`float`, defaults to `1e-4`):
+ Initial learning rate (after the potential warmup period) to use.
+ lr_scheduler (`str`, defaults to `cosine_with_restarts`):
+ The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
+ 'constant', 'constant_with_warmup'].
+ lr_warmup_steps (`int`, defaults to `500`):
+ Number of steps for the warmup in the lr scheduler.
+ lr_num_cycles (`int`, defaults to `1`):
+ Number of hard resets of the lr in cosine_with_restarts scheduler.
+ lr_power (`float`, defaults to `1.0`):
+ Power factor of the polynomial scheduler.
+ beta1 (`float`, defaults to `0.9`):
+ Exponential decay rate for the first moment estimate (running average of the gradients) in Adam-style optimizers.
+ beta2 (`float`, defaults to `0.95`):
+ Exponential decay rate for the second moment estimate (running average of the squared gradients) in Adam-style optimizers.
+ beta3 (`float`, defaults to `0.999`):
+ Additional decay rate used only by optimizers that accept a third beta coefficient; ignored otherwise.
+ weight_decay (`float`, defaults to `0.0001`):
+ Weight decay (penalty on large weights) applied by the optimizer.
+ epsilon (`float`, defaults to `1e-8`):
+ Small value to avoid division by zero in the optimizer.
+ max_grad_norm (`float`, defaults to `1.0`):
+ Maximum gradient norm to clip the gradients.
+
+ VALIDATION ARGUMENTS
+ --------------------
+ validation_dataset_file (`str`, defaults to `None`):
+ Path to a CSV/JSON/PARQUET/ARROW file containing information for validation. The file must contain at least the
+ "caption" column. Other columns such as "image_path" and "video_path" can be provided too. If provided, "image_path"
+ will be used to load a PIL.Image.Image and set the "image" key in the sample dictionary. Similarly, "video_path"
+ will be used to load a List[PIL.Image.Image] and set the "video" key in the sample dictionary (see the example
+ row at the end of this section).
+ The validation dataset file may contain other attributes specific to inference/validation such as:
+ - "height" and "width" and "num_frames": Resolution
+ - "num_inference_steps": Number of inference steps
+ - "guidance_scale": Classifier-free Guidance Scale
+ - ... (any number of additional attributes can be provided. The ModelSpecification::validate method will be
+ invoked with the sample dictionary to validate the sample.)
+ validation_steps (`int`, defaults to `500`):
+ Number of training steps after which a validation step is performed.
+ enable_model_cpu_offload (`bool`, defaults to `False`):
+ Whether or not to offload different modeling components to CPU during validation.
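+
+ Example validation row (illustrative values), shown as the Python dict for one sample:
+
+ {"caption": "A cat playing with a ball of wool", "image_path": "validation/images/cat.png",
+ "height": 512, "width": 768, "num_inference_steps": 30, "guidance_scale": 5.0}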
+
+ MISCELLANEOUS ARGUMENTS
+ -----------------------
+ tracker_name (`str`, defaults to `finetrainers`):
+ Name of the tracker/project to use for logging training metrics.
+ push_to_hub (`bool`, defaults to `False`):
+ Whether or not to push the model to the Hugging Face Hub.
+ hub_token (`str`, defaults to `None`):
+ The API token to use for pushing the model to the Hugging Face Hub.
+ hub_model_id (`str`, defaults to `None`):
+ The model identifier to use for pushing the model to the Hugging Face Hub.
+ output_dir (`str`, defaults to `None`):
+ The directory where the model checkpoints and logs will be stored.
+ logging_dir (`str`, defaults to `logs`):
+ The directory where the logs will be stored.
+ logging_steps (`int`, defaults to `1`):
+ Training logs will be tracked every `logging_steps` steps.
+ init_timeout (`int`, defaults to `300`):
+ Timeout, in seconds, for initializing the distributed backend.
+ nccl_timeout (`int`, defaults to `600`):
+ Timeout, in seconds, for NCCL communication.
+ report_to (`str`, defaults to `wandb`):
+ The name of the logger to use for logging training metrics. Choose between ['none', 'wandb'].
+ verbose (`int`, defaults to `1`):
+ Verbosity level of the logs.
+ - 0: Diffusers/Transformers warning logging on local main process only
+ - 1: Diffusers/Transformers info logging on local main process only
+ - 2: Diffusers/Transformers debug logging on local main process only
+ - 3: Diffusers/Transformers debug logging on all processes
+
+ TORCH CONFIG ARGUMENTS
+ ----------------------
+ compile_modules (`List[str]`, defaults to `[]`):
+ Modules that should be compiled with `torch.compile`.
+ compile_scopes (`List[str]`, defaults to `None`):
+ The scope of compilation for each entry in `--compile_modules`. Choose between ['regional', 'full']. Must have
+ the same length as `--compile_modules`. If `None`, will default to `regional` for all modules.
+ allow_tf32 (`bool`, defaults to `False`):
+ Whether or not to allow the use of TF32 matmul on compatible hardware.
+ float32_matmul_precision (`str`, defaults to `highest`):
+ The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].
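+
+ Example (illustrative): `--compile_modules transformer --compile_scopes regional` applies regional
+ `torch.compile` to the transformer.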
+ """
+
+ # Parallel arguments
+ parallel_backend = ParallelBackendEnum.ACCELERATE
+ pp_degree: int = 1
+ dp_degree: int = 1
+ dp_shards: int = 1
+ cp_degree: int = 1
+ tp_degree: int = 1
+
+ # Model arguments
+ model_name: str = None
+ pretrained_model_name_or_path: str = None
+ revision: Optional[str] = None
+ variant: Optional[str] = None
+ cache_dir: Optional[str] = None
+ tokenizer_id: Optional[str] = None
+ tokenizer_2_id: Optional[str] = None
+ tokenizer_3_id: Optional[str] = None
+ text_encoder_id: Optional[str] = None
+ text_encoder_2_id: Optional[str] = None
+ text_encoder_3_id: Optional[str] = None
+ transformer_id: Optional[str] = None
+ vae_id: Optional[str] = None
+ text_encoder_dtype: torch.dtype = torch.bfloat16
+ text_encoder_2_dtype: torch.dtype = torch.bfloat16
+ text_encoder_3_dtype: torch.dtype = torch.bfloat16
+ transformer_dtype: torch.dtype = torch.bfloat16
+ vae_dtype: torch.dtype = torch.bfloat16
+ layerwise_upcasting_modules: List[str] = []
+ layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
+ # fmt: off
+ layerwise_upcasting_skip_modules_pattern: List[str] = ["patch_embed", "pos_embed", "x_embedder", "context_embedder", "time_embed", "^proj_in$", "^proj_out$", "norm"]
+ # fmt: on
+
+ # Dataset arguments
+ dataset_config: str = None
+ dataset_shuffle_buffer_size: int = 1
+ enable_precomputation: bool = False
+ precomputation_items: int = 512
+ precomputation_dir: Optional[str] = None
+ precomputation_once: bool = False
+ precomputation_reuse: bool = False
+
+ # Dataloader arguments
+ dataloader_num_workers: int = 0
+ pin_memory: bool = False
+
+ # Diffusion arguments
+ flow_resolution_shifting: bool = False
+ flow_base_seq_len: int = 256
+ flow_max_seq_len: int = 4096
+ flow_base_shift: float = 0.5
+ flow_max_shift: float = 1.15
+ flow_shift: float = 1.0
+ flow_weighting_scheme: str = "none"
+ flow_logit_mean: float = 0.0
+ flow_logit_std: float = 1.0
+ flow_mode_scale: float = 1.29
+
+ # Training arguments
+ training_type: str = None
+ seed: int = 42
+ batch_size: int = 1
+ train_steps: int = 1000
+ max_data_samples: int = 2**64
+ gradient_accumulation_steps: int = 1
+ gradient_checkpointing: bool = False
+ checkpointing_steps: int = 500
+ checkpointing_limit: Optional[int] = None
+ resume_from_checkpoint: Optional[str] = None
+ enable_slicing: bool = False
+ enable_tiling: bool = False
+
+ # Optimizer arguments
+ optimizer: str = "adamw"
+ lr: float = 1e-4
+ lr_scheduler: str = "cosine_with_restarts"
+ lr_warmup_steps: int = 0
+ lr_num_cycles: int = 1
+ lr_power: float = 1.0
+ beta1: float = 0.9
+ beta2: float = 0.95
+ beta3: float = 0.999
+ weight_decay: float = 0.0001
+ epsilon: float = 1e-8
+ max_grad_norm: float = 1.0
+
+ # Validation arguments
+ validation_dataset_file: Optional[str] = None
+ validation_steps: int = 500
+ enable_model_cpu_offload: bool = False
+
+ # Miscellaneous arguments
+ tracker_name: str = "finetrainers"
+ push_to_hub: bool = False
+ hub_token: Optional[str] = None
+ hub_model_id: Optional[str] = None
+ output_dir: str = None
+ logging_dir: Optional[str] = "logs"
+ logging_steps: int = 1
+ init_timeout: int = 300 # 5 minutes
+ nccl_timeout: int = 600 # 10 minutes, considering that validation may be performed
+ report_to: str = "wandb"
+ verbose: int = 1
+
+ # Torch config arguments
+ compile_modules: List[str] = []
+ compile_scopes: List[str] = None
+ allow_tf32: bool = False
+ float32_matmul_precision: str = "highest"
+
+ # Attention provider arguments
+ attention_provider_args: AttentionProviderArgs = AttentionProviderArgs()
+
+ _registered_config_mixins: List[ArgsConfigMixin] = []
+ _arg_group_map: Dict[str, ArgsConfigMixin] = {}
+
+ def __init__(self):
+ self._arg_group_map: Dict[str, ArgsConfigMixin] = {
+ "attention_provider_args": self.attention_provider_args,
+ }
+
+ for arg_config_mixin in self._arg_group_map.values():
+ self.register_args(arg_config_mixin)
+
+ def to_dict(self) -> Dict[str, Any]:
+ parallel_arguments = {
+ "pp_degree": self.pp_degree,
+ "dp_degree": self.dp_degree,
+ "dp_shards": self.dp_shards,
+ "cp_degree": self.cp_degree,
+ "tp_degree": self.tp_degree,
+ }
+
+ model_arguments = {
+ "model_name": self.model_name,
+ "pretrained_model_name_or_path": self.pretrained_model_name_or_path,
+ "revision": self.revision,
+ "variant": self.variant,
+ "cache_dir": self.cache_dir,
+ "tokenizer_id": self.tokenizer_id,
+ "tokenizer_2_id": self.tokenizer_2_id,
+ "tokenizer_3_id": self.tokenizer_3_id,
+ "text_encoder_id": self.text_encoder_id,
+ "text_encoder_2_id": self.text_encoder_2_id,
+ "text_encoder_3_id": self.text_encoder_3_id,
+ "transformer_id": self.transformer_id,
+ "vae_id": self.vae_id,
+ "text_encoder_dtype": self.text_encoder_dtype,
+ "text_encoder_2_dtype": self.text_encoder_2_dtype,
+ "text_encoder_3_dtype": self.text_encoder_3_dtype,
+ "transformer_dtype": self.transformer_dtype,
+ "vae_dtype": self.vae_dtype,
+ "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
+ "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
+ "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
+ }
+ model_arguments = get_non_null_items(model_arguments)
+
+ dataset_arguments = {
+ "dataset_config": self.dataset_config,
+ "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
+ "enable_precomputation": self.enable_precomputation,
+ "precomputation_items": self.precomputation_items,
+ "precomputation_dir": self.precomputation_dir,
+ "precomputation_once": self.precomputation_once,
+ "precomputation_reuse": self.precomputation_reuse,
+ }
+ dataset_arguments = get_non_null_items(dataset_arguments)
+
+ dataloader_arguments = {
+ "dataloader_num_workers": self.dataloader_num_workers,
+ "pin_memory": self.pin_memory,
+ }
+
+ diffusion_arguments = {
+ "flow_resolution_shifting": self.flow_resolution_shifting,
+ "flow_base_seq_len": self.flow_base_seq_len,
+ "flow_max_seq_len": self.flow_max_seq_len,
+ "flow_base_shift": self.flow_base_shift,
+ "flow_max_shift": self.flow_max_shift,
+ "flow_shift": self.flow_shift,
+ "flow_weighting_scheme": self.flow_weighting_scheme,
+ "flow_logit_mean": self.flow_logit_mean,
+ "flow_logit_std": self.flow_logit_std,
+ "flow_mode_scale": self.flow_mode_scale,
+ }
+
+ training_arguments = {
+ "training_type": self.training_type,
+ "seed": self.seed,
+ "batch_size": self.batch_size,
+ "train_steps": self.train_steps,
+ "max_data_samples": self.max_data_samples,
+ "gradient_accumulation_steps": self.gradient_accumulation_steps,
+ "gradient_checkpointing": self.gradient_checkpointing,
+ "checkpointing_steps": self.checkpointing_steps,
+ "checkpointing_limit": self.checkpointing_limit,
+ "resume_from_checkpoint": self.resume_from_checkpoint,
+ "enable_slicing": self.enable_slicing,
+ "enable_tiling": self.enable_tiling,
+ }
+ training_arguments = get_non_null_items(training_arguments)
+
+ optimizer_arguments = {
+ "optimizer": self.optimizer,
+ "lr": self.lr,
+ "lr_scheduler": self.lr_scheduler,
+ "lr_warmup_steps": self.lr_warmup_steps,
+ "lr_num_cycles": self.lr_num_cycles,
+ "lr_power": self.lr_power,
+ "beta1": self.beta1,
+ "beta2": self.beta2,
+ "beta3": self.beta3,
+ "weight_decay": self.weight_decay,
+ "epsilon": self.epsilon,
+ "max_grad_norm": self.max_grad_norm,
+ }
+ optimizer_arguments = get_non_null_items(optimizer_arguments)
+
+ validation_arguments = {
+ "validation_dataset_file": self.validation_dataset_file,
+ "validation_steps": self.validation_steps,
+ "enable_model_cpu_offload": self.enable_model_cpu_offload,
+ }
+ validation_arguments = get_non_null_items(validation_arguments)
+
+ miscellaneous_arguments = {
+ "tracker_name": self.tracker_name,
+ "push_to_hub": self.push_to_hub,
+ "hub_token": self.hub_token,
+ "hub_model_id": self.hub_model_id,
+ "output_dir": self.output_dir,
+ "logging_dir": self.logging_dir,
+ "logging_steps": self.logging_steps,
+ "init_timeout": self.init_timeout,
+ "nccl_timeout": self.nccl_timeout,
+ "report_to": self.report_to,
+ "verbose": self.verbose,
+ }
+ miscellaneous_arguments = get_non_null_items(miscellaneous_arguments)
+
+ torch_config_arguments = {
+ "compile_modules": self.compile_modules,
+ "compile_scopes": self.compile_scopes,
+ "allow_tf32": self.allow_tf32,
+ "float32_matmul_precision": self.float32_matmul_precision,
+ }
+
+ additional_arguments = {}
+ for config_mixin in self._registered_config_mixins:
+ additional_arguments[config_mixin.__class__.__name__] = config_mixin.to_dict()
+
+ return {
+ "parallel_arguments": parallel_arguments,
+ "model_arguments": model_arguments,
+ "dataset_arguments": dataset_arguments,
+ "dataloader_arguments": dataloader_arguments,
+ "diffusion_arguments": diffusion_arguments,
+ "training_arguments": training_arguments,
+ "optimizer_arguments": optimizer_arguments,
+ "validation_arguments": validation_arguments,
+ "miscellaneous_arguments": miscellaneous_arguments,
+ "additional_arguments": additional_arguments,
+ "torch_config_arguments": torch_config_arguments,
+ }
+
+ def register_args(self, config: ArgsConfigMixin) -> None:
+ if not hasattr(self, "_extended_add_arguments"):
+ self._extended_add_arguments = []
+ self._extended_add_arguments.append((config.add_args, config.validate_args, config.map_args))
+ self._registered_config_mixins.append(config)
+
+ def parse_args(self):
+ _LIST_MODELS = "--list_models"
+
+ parser = argparse.ArgumentParser()
+
+ special_args = [_LIST_MODELS]
+ if any(arg in sys.argv for arg in special_args):
+ _add_helper_arguments(parser)
+ args = parser.parse_args()
+ _display_helper_messages(args)
+ sys.exit(0)
+ else:
+ _add_args(parser)
+ for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
+ add_fn, _, _ = extended_add_arg_fns
+ add_fn(parser)
+
+ args, remaining_args = parser.parse_known_args()
+ logger.debug(f"Remaining unparsed arguments: {remaining_args}")
+
+ mapped_args = _map_to_args_type(args)
+ for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
+ _, _, map_fn = extended_add_arg_fns
+ map_fn(args, mapped_args)
+
+ _validate_args(mapped_args)
+ for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
+ _, validate_fn, _ = extended_add_arg_fns
+ validate_fn(mapped_args)
+
+ return mapped_args
+
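+ # Attribute lookups/assignments that miss on BaseArgs fall through to the registered argument groups in
+ # `_arg_group_map`, so extension arguments (e.g. attention provider args) behave as if they were defined
+ # directly on BaseArgs.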
+ def __getattribute__(self, name: str):
+ try:
+ return object.__getattribute__(self, name)
+ except AttributeError:
+ for arg_group in self._arg_group_map.values():
+ if hasattr(arg_group, name):
+ return getattr(arg_group, name)
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+ def __setattr__(self, name: str, value: Any):
+ if name in self.__dict__:
+ object.__setattr__(self, name, value)
+ return
+ for arg_group in self._arg_group_map.values():
+ if hasattr(arg_group, name):
+ setattr(arg_group, name, value)
+ return
+ object.__setattr__(self, name, value)
+
+
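+ # Example usage of BaseArgs (sketch; flag values are illustrative):
+ #
+ #     args = BaseArgs().parse_args()
+ #
+ # e.g. when a training script is invoked as:
+ #
+ #     python train.py --model_name wan --training_type lora \
+ #         --pretrained_model_name_or_path <repo-id> --dataset_config dataset.json --output_dir outputs
+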
+def _add_args(parser: argparse.ArgumentParser) -> None:
+ _add_parallel_arguments(parser)
+ _add_model_arguments(parser)
+ _add_dataset_arguments(parser)
+ _add_dataloader_arguments(parser)
+ _add_diffusion_arguments(parser)
+ _add_training_arguments(parser)
+ _add_optimizer_arguments(parser)
+ _add_validation_arguments(parser)
+ _add_miscellaneous_arguments(parser)
+ _add_torch_config_arguments(parser)
+
+
+def _validate_args(args: BaseArgs):
+ _validate_model_args(args)
+ _validate_dataset_args(args)
+ _validate_validation_args(args)
+
+
+def _add_parallel_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument(
+ "--parallel_backend",
+ type=str,
+ default=ParallelBackendEnum.ACCELERATE,
+ choices=[ParallelBackendEnum.ACCELERATE, ParallelBackendEnum.PTD],
+ )
+ parser.add_argument("--pp_degree", type=int, default=1)
+ parser.add_argument("--dp_degree", type=int, default=1)
+ parser.add_argument("--dp_shards", type=int, default=1)
+ parser.add_argument("--cp_degree", type=int, default=1)
+ parser.add_argument("--tp_degree", type=int, default=1)
+
+
+def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument(
+ "--model_name", type=str, required=True, choices=[x.value for x in ModelType.__members__.values()]
+ )
+ parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
+ parser.add_argument("--revision", type=str, default=None, required=False)
+ parser.add_argument("--variant", type=str, default=None)
+ parser.add_argument("--cache_dir", type=str, default=None)
+ parser.add_argument("--tokenizer_id", type=str, default=None)
+ parser.add_argument("--tokenizer_2_id", type=str, default=None)
+ parser.add_argument("--tokenizer_3_id", type=str, default=None)
+ parser.add_argument("--text_encoder_id", type=str, default=None)
+ parser.add_argument("--text_encoder_2_id", type=str, default=None)
+ parser.add_argument("--text_encoder_3_id", type=str, default=None)
+ parser.add_argument("--transformer_id", type=str, default=None)
+ parser.add_argument("--vae_id", type=str, default=None)
+ parser.add_argument("--text_encoder_dtype", type=str, default="bf16")
+ parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16")
+ parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16")
+ parser.add_argument("--transformer_dtype", type=str, default="bf16")
+ parser.add_argument("--vae_dtype", type=str, default="bf16")
+ parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"])
+ parser.add_argument(
+ "--layerwise_upcasting_storage_dtype",
+ type=str,
+ default="float8_e4m3fn",
+ choices=["float8_e4m3fn", "float8_e5m2"],
+ )
+ parser.add_argument(
+ "--layerwise_upcasting_skip_modules_pattern",
+ type=str,
+ default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
+ nargs="+",
+ )
+
+
+def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--dataset_config", type=str, required=True)
+ parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
+ parser.add_argument("--enable_precomputation", action="store_true")
+ parser.add_argument("--precomputation_items", type=int, default=512)
+ parser.add_argument("--precomputation_dir", type=str, default=None)
+ parser.add_argument("--precomputation_once", action="store_true")
+ parser.add_argument("--precomputation_reuse", action="store_true")
+
+
+def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--dataloader_num_workers", type=int, default=0)
+ parser.add_argument("--pin_memory", action="store_true")
+
+
+def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--flow_resolution_shifting", action="store_true")
+ parser.add_argument("--flow_base_seq_len", type=int, default=256)
+ parser.add_argument("--flow_max_seq_len", type=int, default=4096)
+ parser.add_argument("--flow_base_shift", type=float, default=0.5)
+ parser.add_argument("--flow_max_shift", type=float, default=1.15)
+ parser.add_argument("--flow_shift", type=float, default=1.0)
+ parser.add_argument(
+ "--flow_weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ )
+ parser.add_argument("--flow_logit_mean", type=float, default=0.0)
+ parser.add_argument("--flow_logit_std", type=float, default=1.0)
+ parser.add_argument("--flow_mode_scale", type=float, default=1.29)
+
+
+def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument(
+ "--training_type", type=str, choices=[x.value for x in TrainingType.__members__.values()], required=True
+ )
+ parser.add_argument("--seed", type=int, default=None)
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--train_steps", type=int, default=1000)
+ parser.add_argument("--max_data_samples", type=int, default=2**64)
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+ parser.add_argument("--gradient_checkpointing", action="store_true")
+ parser.add_argument("--checkpointing_steps", type=int, default=500)
+ parser.add_argument("--checkpointing_limit", type=int, default=None)
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None)
+ parser.add_argument("--enable_slicing", action="store_true")
+ parser.add_argument("--enable_tiling", action="store_true")
+
+
+def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--lr", type=float, default=1e-4)
+ parser.add_argument("--lr_scheduler", type=str, default="constant")
+ parser.add_argument("--lr_warmup_steps", type=int, default=500)
+ parser.add_argument("--lr_num_cycles", type=int, default=1)
+ parser.add_argument("--lr_power", type=float, default=1.0)
+ parser.add_argument(
+ "--optimizer",
+ type=lambda s: s.lower(),
+ default="adam",
+ choices=["adam", "adamw", "adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"],
+ )
+ parser.add_argument("--beta1", type=float, default=0.9)
+ parser.add_argument("--beta2", type=float, default=0.95)
+ parser.add_argument("--beta3", type=float, default=None)
+ parser.add_argument("--weight_decay", type=float, default=1e-04)
+ parser.add_argument("--epsilon", type=float, default=1e-8)
+ parser.add_argument("--max_grad_norm", default=1.0, type=float)
+
+
+def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--validation_dataset_file", type=str, default=None)
+ parser.add_argument("--validation_steps", type=int, default=500)
+ parser.add_argument("--enable_model_cpu_offload", action="store_true")
+
+
+def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--tracker_name", type=str, default="finetrainers")
+ parser.add_argument("--push_to_hub", action="store_true")
+ parser.add_argument("--hub_token", type=str, default=None)
+ parser.add_argument("--hub_model_id", type=str, default=None)
+ parser.add_argument("--output_dir", type=str, default="finetrainers-training")
+ parser.add_argument("--logging_dir", type=str, default="logs")
+ parser.add_argument("--logging_steps", type=int, default=1)
+ parser.add_argument("--init_timeout", type=int, default=300)
+ parser.add_argument("--nccl_timeout", type=int, default=600)
+ parser.add_argument("--report_to", type=str, default="none", choices=["none", "wandb"])
+ parser.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2, 3])
+
+
+def _add_torch_config_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--compile_modules", type=str, default=[], nargs="+")
+ parser.add_argument("--compile_scopes", type=str, default=None, nargs="+")
+ parser.add_argument("--allow_tf32", action="store_true")
+ parser.add_argument(
+ "--float32_matmul_precision",
+ type=str,
+ default="highest",
+ choices=["highest", "high", "medium"],
+ help="The precision to use for float32 matmul. Choose between ['highest', 'high', 'medium'].",
+ )
+
+
+def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
+ parser.add_argument("--list_models", action="store_true")
+
+
+_DTYPE_MAP = {
+ "bf16": torch.bfloat16,
+ "fp16": torch.float16,
+ "fp32": torch.float32,
+ "float8_e4m3fn": torch.float8_e4m3fn,
+ "float8_e5m2": torch.float8_e5m2,
+}
+
+
+def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs:
+ result_args = BaseArgs()
+
+ # Parallel arguments
+ result_args.parallel_backend = args.parallel_backend
+ result_args.pp_degree = args.pp_degree
+ result_args.dp_degree = args.dp_degree
+ result_args.dp_shards = args.dp_shards
+ result_args.cp_degree = args.cp_degree
+ result_args.tp_degree = args.tp_degree
+
+ # Torch compile scopes are normalized here because they depend on `compile_modules`
+ compile_scopes = args.compile_scopes
+ if len(args.compile_modules) > 0:
+ if compile_scopes is None:
+ compile_scopes = "regional"
+ if isinstance(compile_scopes, list) and len(compile_scopes) == 1:
+ compile_scopes = compile_scopes[0]
+ if isinstance(compile_scopes, str):
+ compile_scopes = [compile_scopes] * len(args.compile_modules)
+ else:
+ compile_scopes = []
+
+ # Model arguments
+ result_args.model_name = args.model_name
+ result_args.pretrained_model_name_or_path = args.pretrained_model_name_or_path
+ result_args.revision = args.revision
+ result_args.variant = args.variant
+ result_args.cache_dir = args.cache_dir
+ result_args.tokenizer_id = args.tokenizer_id
+ result_args.tokenizer_2_id = args.tokenizer_2_id
+ result_args.tokenizer_3_id = args.tokenizer_3_id
+ result_args.text_encoder_id = args.text_encoder_id
+ result_args.text_encoder_2_id = args.text_encoder_2_id
+ result_args.text_encoder_3_id = args.text_encoder_3_id
+ result_args.transformer_id = args.transformer_id
+ result_args.vae_id = args.vae_id
+ result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
+ result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
+ result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
+ result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
+ result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
+ result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules
+ result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype]
+ result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern
+
+ # Dataset arguments
+ result_args.dataset_config = args.dataset_config
+ result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
+ result_args.enable_precomputation = args.enable_precomputation
+ result_args.precomputation_items = args.precomputation_items
+ result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
+ result_args.precomputation_once = args.precomputation_once
+ result_args.precomputation_reuse = args.precomputation_reuse
+
+ # Dataloader arguments
+ result_args.dataloader_num_workers = args.dataloader_num_workers
+ result_args.pin_memory = args.pin_memory
+
+ # Diffusion arguments
+ result_args.flow_resolution_shifting = args.flow_resolution_shifting
+ result_args.flow_base_seq_len = args.flow_base_seq_len
+ result_args.flow_max_seq_len = args.flow_max_seq_len
+ result_args.flow_base_shift = args.flow_base_shift
+ result_args.flow_max_shift = args.flow_max_shift
+ result_args.flow_shift = args.flow_shift
+ result_args.flow_weighting_scheme = args.flow_weighting_scheme
+ result_args.flow_logit_mean = args.flow_logit_mean
+ result_args.flow_logit_std = args.flow_logit_std
+ result_args.flow_mode_scale = args.flow_mode_scale
+
+ # Training arguments
+ result_args.training_type = args.training_type
+ result_args.seed = args.seed
+ result_args.batch_size = args.batch_size
+ result_args.train_steps = args.train_steps
+ result_args.max_data_samples = args.max_data_samples
+ result_args.gradient_accumulation_steps = args.gradient_accumulation_steps
+ result_args.gradient_checkpointing = args.gradient_checkpointing
+ result_args.checkpointing_steps = args.checkpointing_steps
+ result_args.checkpointing_limit = args.checkpointing_limit
+ result_args.resume_from_checkpoint = args.resume_from_checkpoint
+ result_args.enable_slicing = args.enable_slicing
+ result_args.enable_tiling = args.enable_tiling
+
+ # Optimizer arguments
+ result_args.optimizer = args.optimizer or "adamw"
+ result_args.lr = args.lr or 1e-4
+ result_args.lr_scheduler = args.lr_scheduler
+ result_args.lr_warmup_steps = args.lr_warmup_steps
+ result_args.lr_num_cycles = args.lr_num_cycles
+ result_args.lr_power = args.lr_power
+ result_args.beta1 = args.beta1
+ result_args.beta2 = args.beta2
+ result_args.beta3 = args.beta3
+ result_args.weight_decay = args.weight_decay
+ result_args.epsilon = args.epsilon
+ result_args.max_grad_norm = args.max_grad_norm
+
+ # Validation arguments
+ result_args.validation_dataset_file = args.validation_dataset_file
+ result_args.validation_steps = args.validation_steps
+ result_args.enable_model_cpu_offload = args.enable_model_cpu_offload
+
+ # Miscellaneous arguments
+ result_args.tracker_name = args.tracker_name
+ result_args.push_to_hub = args.push_to_hub
+ result_args.hub_token = args.hub_token
+ result_args.hub_model_id = args.hub_model_id
+ result_args.output_dir = args.output_dir
+ result_args.logging_dir = args.logging_dir
+ result_args.logging_steps = args.logging_steps
+ result_args.init_timeout = args.init_timeout
+ result_args.nccl_timeout = args.nccl_timeout
+ result_args.report_to = args.report_to
+ result_args.verbose = args.verbose
+
+ # Torch config arguments
+ result_args.compile_modules = args.compile_modules
+ result_args.compile_scopes = compile_scopes
+ result_args.allow_tf32 = args.allow_tf32
+ result_args.float32_matmul_precision = args.float32_matmul_precision
+
+ return result_args
+
+
+def _validate_model_args(args: BaseArgs):
+ if args.training_type == "full-finetune":
+ assert "transformer" not in args.layerwise_upcasting_modules, (
+ "Layerwise upcasting is not supported for full-finetune training"
+ )
+ if len(args.compile_modules) > 0:
+ assert len(args.compile_modules) == len(args.compile_scopes) and all(
+ x in ["regional", "full"] for x in args.compile_scopes
+ ), (
+ "Compile modules and compile scopes must be of the same length and compile scopes must be either 'regional' or 'full'"
+ )
+
+
+def _validate_dataset_args(args: BaseArgs):
+ dataset_config = pathlib.Path(args.dataset_config)
+ if not dataset_config.exists():
+ raise ValueError(f"Dataset config file {args.dataset_config} does not exist.")
+ if args.dataset_shuffle_buffer_size < 1:
+ raise ValueError("Dataset shuffle buffer size must be greater than 0.")
+ if args.precomputation_items < 1:
+ raise ValueError("Precomputation items must be greater than 0.")
+
+
+def _validate_validation_args(args: BaseArgs):
+ if args.enable_model_cpu_offload:
+ if any(x > 1 for x in [args.pp_degree, args.dp_degree, args.dp_shards, args.cp_degree, args.tp_degree]):
+ raise ValueError("Model CPU offload is not supported on multi-GPU at the moment.")
+
+
+def _display_helper_messages(args: argparse.Namespace):
+ if args.list_models:
+ print("Supported models:")
+ for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()):
+ print(f" {index + 1}. {model_name}")
diff --git a/docs/finetrainers-src-codebase/finetrainers/config.py b/docs/finetrainers-src-codebase/finetrainers/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e713e9b6ea9314a63d994875900d2a5facf3bd
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/config.py
@@ -0,0 +1,75 @@
+from enum import Enum
+from typing import Type
+
+from .models import ModelSpecification
+from .models.cogvideox import CogVideoXModelSpecification
+from .models.cogview4 import CogView4ControlModelSpecification, CogView4ModelSpecification
+from .models.flux import FluxModelSpecification
+from .models.hunyuan_video import HunyuanVideoModelSpecification
+from .models.ltx_video import LTXVideoModelSpecification
+from .models.wan import WanControlModelSpecification, WanModelSpecification
+
+
+class ModelType(str, Enum):
+ COGVIDEOX = "cogvideox"
+ COGVIEW4 = "cogview4"
+ FLUX = "flux"
+ HUNYUAN_VIDEO = "hunyuan_video"
+ LTX_VIDEO = "ltx_video"
+ WAN = "wan"
+
+
+class TrainingType(str, Enum):
+ # SFT
+ LORA = "lora"
+ FULL_FINETUNE = "full-finetune"
+
+ # Control
+ CONTROL_LORA = "control-lora"
+ CONTROL_FULL_FINETUNE = "control-full-finetune"
+
+
+SUPPORTED_MODEL_CONFIGS = {
+ # TODO(aryan): autogenerate this
+ # SFT
+ ModelType.COGVIDEOX: {
+ TrainingType.LORA: CogVideoXModelSpecification,
+ TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
+ },
+ ModelType.COGVIEW4: {
+ TrainingType.LORA: CogView4ModelSpecification,
+ TrainingType.FULL_FINETUNE: CogView4ModelSpecification,
+ TrainingType.CONTROL_LORA: CogView4ControlModelSpecification,
+ TrainingType.CONTROL_FULL_FINETUNE: CogView4ControlModelSpecification,
+ },
+ ModelType.FLUX: {
+ TrainingType.LORA: FluxModelSpecification,
+ TrainingType.FULL_FINETUNE: FluxModelSpecification,
+ },
+ ModelType.HUNYUAN_VIDEO: {
+ TrainingType.LORA: HunyuanVideoModelSpecification,
+ TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
+ },
+ ModelType.LTX_VIDEO: {
+ TrainingType.LORA: LTXVideoModelSpecification,
+ TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
+ },
+ ModelType.WAN: {
+ TrainingType.LORA: WanModelSpecification,
+ TrainingType.FULL_FINETUNE: WanModelSpecification,
+ TrainingType.CONTROL_LORA: WanControlModelSpecification,
+ TrainingType.CONTROL_FULL_FINETUNE: WanControlModelSpecification,
+ },
+}
+
+
+def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]:
+ if model_name not in SUPPORTED_MODEL_CONFIGS:
+ raise ValueError(
+ f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
+ )
+ if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]:
+ raise ValueError(
+ f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
+ )
+ return SUPPORTED_MODEL_CONFIGS[model_name][training_type]
diff --git a/docs/finetrainers-src-codebase/finetrainers/constants.py b/docs/finetrainers-src-codebase/finetrainers/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd45d2925af9541608170b8f73244f27b13471d2
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/constants.py
@@ -0,0 +1,87 @@
+import os
+
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+
+FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")
+FINETRAINERS_ATTN_PROVIDER = os.environ.get("FINETRAINERS_ATTN_PROVIDER", "native")
+FINETRAINERS_ATTN_CHECKS = os.getenv("FINETRAINERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
+FINETRAINERS_ENABLE_TIMING = os.getenv("FINETRAINERS_ENABLE_TIMING", "1") in ENV_VARS_TRUE_VALUES
+
+DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
+DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
+DEFAULT_FRAME_BUCKETS = [49]
+
+DEFAULT_IMAGE_RESOLUTION_BUCKETS = []
+for height in DEFAULT_HEIGHT_BUCKETS:
+ for width in DEFAULT_WIDTH_BUCKETS:
+ DEFAULT_IMAGE_RESOLUTION_BUCKETS.append((height, width))
+
+DEFAULT_VIDEO_RESOLUTION_BUCKETS = []
+for frames in DEFAULT_FRAME_BUCKETS:
+ for height in DEFAULT_HEIGHT_BUCKETS:
+ for width in DEFAULT_WIDTH_BUCKETS:
+ DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width))
+
+PRECOMPUTED_DIR_NAME = "precomputed"
+PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
+PRECOMPUTED_LATENTS_DIR_NAME = "latents"
+
+MODEL_DESCRIPTION = r"""
+\# {model_id} {training_type} finetune
+
+
+
+\#\# Model Description
+
+This model is a {training_type} of the `{model_id}` model.
+
+This model was trained using the `fine-video-trainers` library - a repository containing memory-optimized scripts for training video models with [Diffusers](https://github.com/huggingface/diffusers).
+
+\#\# Download model
+
+[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
+
+\#\# Usage
+
+Requires [🧨 Diffusers](https://github.com/huggingface/diffusers) installed.
+
+```python
+{model_example}
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
+
+\#\# License
+
+Please adhere to the license of the base model.
+""".strip()
+
+_COMMON_BEGINNING_PHRASES = (
+ "This video",
+ "The video",
+ "This clip",
+ "The clip",
+ "The animation",
+ "This image",
+ "The image",
+ "This picture",
+ "The picture",
+)
+_COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents")
+
+COMMON_LLM_START_PHRASES = (
+ "In the video,",
+ "In this video,",
+ "In this video clip,",
+ "In the clip,",
+ "Caption:",
+ *(
+ f"{beginning} {continuation}"
+ for beginning in _COMMON_BEGINNING_PHRASES
+ for continuation in _COMMON_CONTINUATION_WORDS
+ ),
+)
+
+SUPPORTED_IMAGE_FILE_EXTENSIONS = ("jpg", "jpeg", "png")
+SUPPORTED_VIDEO_FILE_EXTENSIONS = ("mp4", "mov")
diff --git a/docs/finetrainers-src-codebase/finetrainers/data/__init__.py b/docs/finetrainers-src-codebase/finetrainers/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2025f19f9fde243cbbc998cbb58d330d70f9544
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/data/__init__.py
@@ -0,0 +1,26 @@
+from ._artifact import ImageArtifact, VideoArtifact
+from .dataloader import DPDataLoader
+from .dataset import (
+ ImageCaptionFilePairDataset,
+ ImageFileCaptionFileListDataset,
+ ImageFolderDataset,
+ ImageWebDataset,
+ ValidationDataset,
+ VideoCaptionFilePairDataset,
+ VideoFileCaptionFileListDataset,
+ VideoFolderDataset,
+ VideoWebDataset,
+ combine_datasets,
+ initialize_dataset,
+ wrap_iterable_dataset_for_preprocessing,
+)
+from .precomputation import (
+ InMemoryDataIterable,
+ InMemoryDistributedDataPreprocessor,
+ InMemoryOnceDataIterable,
+ PrecomputedDataIterable,
+ PrecomputedDistributedDataPreprocessor,
+ PrecomputedOnceDataIterable,
+ initialize_preprocessor,
+)
+from .sampler import ResolutionSampler
diff --git a/docs/finetrainers-src-codebase/finetrainers/data/_artifact.py b/docs/finetrainers-src-codebase/finetrainers/data/_artifact.py
new file mode 100644
index 0000000000000000000000000000000000000000..400f25d143f5062d77ed6391ca9862654d295de7
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/data/_artifact.py
@@ -0,0 +1,29 @@
+# ===== THIS FILE ONLY EXISTS FOR THE TIME BEING SINCE I DID NOT KNOW WHERE TO PUT IT =====
+
+from dataclasses import dataclass
+from typing import Any, List
+
+from PIL.Image import Image
+
+
+@dataclass
+class Artifact:
+ type: str
+ value: Any
+ file_extension: str
+
+
+@dataclass
+class ImageArtifact(Artifact):
+ value: Image
+
+ def __init__(self, value: Image):
+ super().__init__(type="image", value=value, file_extension="png")
+
+
+@dataclass
+class VideoArtifact(Artifact):
+ value: List[Image]
+
+ def __init__(self, value: List[Image]):
+ super().__init__(type="video", value=value, file_extension="mp4")
diff --git a/docs/finetrainers-src-codebase/finetrainers/data/dataloader.py b/docs/finetrainers-src-codebase/finetrainers/data/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..752229489de69a684b395c79e5a2799c3f747596
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/data/dataloader.py
@@ -0,0 +1,40 @@
+import pickle
+from typing import Any, Dict
+
+import torch.distributed.checkpoint.stateful
+import torchdata.stateful_dataloader
+
+from finetrainers.logging import get_logger
+
+
+logger = get_logger()
+
+
+class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(
+ self,
+ rank: int,
+ dataset: torch.utils.data.IterableDataset,
+ batch_size: int = 1,
+ num_workers: int = 0,
+ collate_fn=None,
+ ) -> None:
+ super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)
+
+ self._dp_rank = rank
+ self._rank_id = f"dp_rank_{rank}"
+
+ def state_dict(self) -> Dict[str, Any]:
+ # Store state only for dp rank to avoid replicating the same state across other dimensions
+ return {self._rank_id: pickle.dumps(super().state_dict())}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ # State being empty is valid
+ if not state_dict:
+ return
+
+ if self._rank_id not in state_dict:
+ logger.warning(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}")
+ return
+
+ super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
diff --git a/docs/finetrainers-src-codebase/finetrainers/data/dataset.py b/docs/finetrainers-src-codebase/finetrainers/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5416e6151a52404d174a9939279420d46dbe232e
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/data/dataset.py
@@ -0,0 +1,1040 @@
+import pathlib
+import random
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import datasets
+import datasets.data_files
+import datasets.distributed
+import datasets.exceptions
+import huggingface_hub
+import huggingface_hub.errors
+import numpy as np
+import PIL.Image
+import PIL.JpegImagePlugin
+import torch
+import torch.distributed.checkpoint.stateful
+import torchvision
+from diffusers.utils import load_image, load_video
+from huggingface_hub import list_repo_files, repo_exists, snapshot_download
+from tqdm.auto import tqdm
+
+from finetrainers import constants
+from finetrainers import functional as FF
+from finetrainers.logging import get_logger
+from finetrainers.utils import find_files
+from finetrainers.utils.import_utils import is_datasets_version
+
+
+import decord # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+logger = get_logger()
+
+
+# fmt: off
+MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
+COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
+COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
+COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
+COMMON_WDS_CAPTION_COLUMN_NAMES = ["txt", "text", "caption", "captions", "short_caption", "long_caption", "prompt", "prompts", "short_prompt", "long_prompt", "description", "descriptions", "alt_text", "alt_texts", "alt_caption", "alt_captions", "alt_prompt", "alt_prompts", "alt_description", "alt_descriptions", "image_description", "image_descriptions", "image_caption", "image_captions", "image_prompt", "image_prompts", "image_alt_text", "image_alt_texts", "image_alt_caption", "image_alt_captions", "image_alt_prompt", "image_alt_prompts", "image_alt_description", "image_alt_descriptions", "video_description", "video_descriptions", "video_caption", "video_captions", "video_prompt", "video_prompts", "video_alt_text", "video_alt_texts", "video_alt_caption", "video_alt_captions", "video_alt_prompt", "video_alt_prompts", "video_alt_description"]
+# fmt: on
+
+
+class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(self, root: str, infinite: bool = False) -> None:
+ super().__init__()
+
+ self.root = pathlib.Path(root)
+ self.infinite = infinite
+
+ data = []
+ caption_files = sorted(find_files(self.root.as_posix(), "*.txt", depth=0))
+ for caption_file in caption_files:
+ data_file = self._find_data_file(caption_file)
+ if data_file:
+ data.append(
+ {
+ "caption": (self.root / caption_file).as_posix(),
+ "image": (self.root / data_file).as_posix(),
+ }
+ )
+
+ data = datasets.Dataset.from_list(data)
+ data = data.cast_column("image", datasets.Image(mode="RGB"))
+
+ self._data = data.to_iterable_dataset()
+ self._sample_index = 0
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+ def _get_data_iter(self):
+ if self._sample_index == 0:
+ return iter(self._data)
+ return iter(self._data.skip(self._sample_index))
+
+ def __iter__(self):
+ while True:
+ for sample in self._get_data_iter():
+ self._sample_index += 1
+ sample["caption"] = _read_caption_from_file(sample["caption"])
+ yield sample
+
+ if not self.infinite:
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+ break
+ else:
+ self._sample_index = 0
+
+ def load_state_dict(self, state_dict):
+ self._sample_index = state_dict["sample_index"]
+
+ def state_dict(self):
+ return {"sample_index": self._sample_index}
+
+ def _find_data_file(self, caption_file: str) -> Optional[str]:
+ caption_file = pathlib.Path(caption_file)
+ data_file = None
+ found_data = 0
+
+ for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
+ image_filename = caption_file.with_suffix(f".{extension}")
+ if image_filename.exists():
+ found_data += 1
+ data_file = image_filename
+
+ if found_data == 0:
+ return None
+ elif found_data > 1:
+ raise ValueError(
+ f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
+ f"file per caption file. The following extensions are supported:\n"
+ f" - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n"
+ )
+
+ return data_file.as_posix()
+
+
+class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(self, root: str, infinite: bool = False) -> None:
+ super().__init__()
+
+ self.root = pathlib.Path(root)
+ self.infinite = infinite
+
+ data = []
+ caption_files = sorted(find_files(self.root.as_posix(), "*.txt", depth=0))
+ for caption_file in caption_files:
+ data_file = self._find_data_file(caption_file)
+ if data_file:
+ data.append(
+ {
+ "caption": (self.root / caption_file).as_posix(),
+ "video": (self.root / data_file).as_posix(),
+ }
+ )
+
+ data = datasets.Dataset.from_list(data)
+ data = data.cast_column("video", datasets.Video())
+
+ self._data = data.to_iterable_dataset()
+ self._sample_index = 0
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+ def _get_data_iter(self):
+ if self._sample_index == 0:
+ return iter(self._data)
+ return iter(self._data.skip(self._sample_index))
+
+ def __iter__(self):
+ while True:
+ for sample in self._get_data_iter():
+ self._sample_index += 1
+ sample["caption"] = _read_caption_from_file(sample["caption"])
+ yield sample
+
+ if not self.infinite:
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+ break
+ else:
+ self._sample_index = 0
+
+ def load_state_dict(self, state_dict):
+ self._sample_index = state_dict["sample_index"]
+
+ def state_dict(self):
+ return {"sample_index": self._sample_index}
+
+    def _find_data_file(self, caption_file: str) -> Optional[str]:
+ caption_file = pathlib.Path(caption_file)
+ data_file = None
+ found_data = 0
+
+ for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
+ video_filename = caption_file.with_suffix(f".{extension}")
+ if video_filename.exists():
+ found_data += 1
+ data_file = video_filename
+
+ if found_data == 0:
+            return None
+ elif found_data > 1:
+ raise ValueError(
+ f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
+ f"file per caption file. The following extensions are supported:\n"
+ f" - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n"
+ )
+
+ return data_file.as_posix()
+
+
+class ImageFileCaptionFileListDataset(
+ torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
+):
+ def __init__(self, root: str, infinite: bool = False) -> None:
+ super().__init__()
+
+ VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
+ VALID_IMAGE_FILES = ["image.txt", "images.txt"]
+
+ self.root = pathlib.Path(root)
+ self.infinite = infinite
+
+ data = []
+ existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
+ existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()]
+
+ if len(existing_caption_files) == 0:
+ raise FileNotFoundError(
+ f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
+ )
+ if len(existing_image_files) == 0:
+ raise FileNotFoundError(
+ f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
+ )
+ if len(existing_caption_files) > 1:
+ raise ValueError(
+ f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
+ )
+ if len(existing_image_files) > 1:
+ raise ValueError(
+ f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
+ )
+
+ caption_file = existing_caption_files[0]
+ image_file = existing_image_files[0]
+
+ with open((self.root / caption_file).as_posix(), "r") as f:
+ captions = f.read().splitlines()
+ with open((self.root / image_file).as_posix(), "r") as f:
+ images = f.read().splitlines()
+ images = [(self.root / image).as_posix() for image in images]
+
+ if len(captions) != len(images):
+ raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})")
+
+ for caption, image in zip(captions, images):
+ data.append({"caption": caption, "image": image})
+
+ data = datasets.Dataset.from_list(data)
+ data = data.cast_column("image", datasets.Image(mode="RGB"))
+
+ self._data = data.to_iterable_dataset()
+ self._sample_index = 0
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+ def _get_data_iter(self):
+ if self._sample_index == 0:
+ return iter(self._data)
+ return iter(self._data.skip(self._sample_index))
+
+ def __iter__(self):
+ while True:
+ for sample in self._get_data_iter():
+ self._sample_index += 1
+ yield sample
+
+ if not self.infinite:
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+ break
+ else:
+ self._sample_index = 0
+
+ def load_state_dict(self, state_dict):
+ self._sample_index = state_dict["sample_index"]
+
+ def state_dict(self):
+ return {"sample_index": self._sample_index}
+
+
+class VideoFileCaptionFileListDataset(
+ torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
+):
+ def __init__(self, root: str, infinite: bool = False) -> None:
+ super().__init__()
+
+ VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
+ VALID_VIDEO_FILES = ["video.txt", "videos.txt"]
+
+ self.root = pathlib.Path(root)
+ self.infinite = infinite
+
+ data = []
+ existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
+ existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()]
+
+ if len(existing_caption_files) == 0:
+ raise FileNotFoundError(
+ f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
+ )
+ if len(existing_video_files) == 0:
+ raise FileNotFoundError(
+ f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
+ )
+ if len(existing_caption_files) > 1:
+ raise ValueError(
+ f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
+ )
+ if len(existing_video_files) > 1:
+ raise ValueError(
+ f"Multiple video files found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
+ )
+
+ caption_file = existing_caption_files[0]
+ video_file = existing_video_files[0]
+
+ with open((self.root / caption_file).as_posix(), "r") as f:
+ captions = f.read().splitlines()
+ with open((self.root / video_file).as_posix(), "r") as f:
+ videos = f.read().splitlines()
+ videos = [(self.root / video).as_posix() for video in videos]
+
+ if len(captions) != len(videos):
+ raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})")
+
+ for caption, video in zip(captions, videos):
+ data.append({"caption": caption, "video": video})
+
+ data = datasets.Dataset.from_list(data)
+ data = data.cast_column("video", datasets.Video())
+
+ self._data = data.to_iterable_dataset()
+ self._sample_index = 0
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+ def _get_data_iter(self):
+ if self._sample_index == 0:
+ return iter(self._data)
+ return iter(self._data.skip(self._sample_index))
+
+ def __iter__(self):
+ while True:
+ for sample in self._get_data_iter():
+ self._sample_index += 1
+ yield sample
+
+ if not self.infinite:
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+ break
+ else:
+ self._sample_index = 0
+
+ def load_state_dict(self, state_dict):
+ self._sample_index = state_dict["sample_index"]
+
+ def state_dict(self):
+ return {"sample_index": self._sample_index}
+
+
+class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(self, root: str, infinite: bool = False) -> None:
+ super().__init__()
+
+ self.root = pathlib.Path(root)
+ self.infinite = infinite
+
+ data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train")
+
+ self._data = data.to_iterable_dataset()
+ self._sample_index = 0
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+ def _get_data_iter(self):
+ if self._sample_index == 0:
+ return iter(self._data)
+ return iter(self._data.skip(self._sample_index))
+
+ def __iter__(self):
+ while True:
+ for sample in self._get_data_iter():
+ self._sample_index += 1
+ yield sample
+
+ if not self.infinite:
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+ break
+ else:
+ self._sample_index = 0
+
+ def load_state_dict(self, state_dict):
+ self._sample_index = state_dict["sample_index"]
+
+ def state_dict(self):
+ return {"sample_index": self._sample_index}
+
+
+class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(self, root: str, infinite: bool = False) -> None:
+ super().__init__()
+
+ self.root = pathlib.Path(root)
+ self.infinite = infinite
+
+ data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train")
+
+ self._data = data.to_iterable_dataset()
+ self._sample_index = 0
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+ def _get_data_iter(self):
+ if self._sample_index == 0:
+ return iter(self._data)
+ return iter(self._data.skip(self._sample_index))
+
+ def __iter__(self):
+ while True:
+ for sample in self._get_data_iter():
+ self._sample_index += 1
+ yield sample
+
+ if not self.infinite:
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+ break
+ else:
+ self._sample_index = 0
+
+ def load_state_dict(self, state_dict):
+ self._sample_index = state_dict["sample_index"]
+
+ def state_dict(self):
+ return {"sample_index": self._sample_index}
+
+
+class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(
+ self,
+ dataset_name: str,
+ infinite: bool = False,
+ column_names: Union[str, List[str]] = "__auto__",
+        weights: Union[int, Dict[str, float]] = -1,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ assert weights == -1 or isinstance(weights, dict), (
+ "`weights` must be a dictionary of probabilities for each caption column"
+ )
+
+ self.dataset_name = dataset_name
+ self.infinite = infinite
+
+ data = datasets.load_dataset(dataset_name, split="train", streaming=True)
+
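+        # "__auto__" scans the streamed dataset for any of the common caption column names; explicitly
+        # provided column names (optionally with sampling weights) are validated against the available columns.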
+ if column_names == "__auto__":
+ if weights == -1:
+ caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
+ if len(caption_columns) == 0:
+ raise ValueError(
+ f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}. "
+ f"Available columns are: {data.column_names}"
+ )
+ weights = [1] * len(caption_columns)
+ else:
+ caption_columns = list(weights.keys())
+ weights = list(weights.values())
+ if not all(column in data.column_names for column in caption_columns):
+ raise ValueError(
+ f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
+ )
+ else:
+ if isinstance(column_names, str):
+ if column_names not in data.column_names:
+ raise ValueError(
+ f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
+ )
+ caption_columns = [column_names]
+ weights = [1] if weights == -1 else [weights.get(column_names)]
+ elif isinstance(column_names, list):
+ if not all(column in data.column_names for column in column_names):
+ raise ValueError(
+ f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
+ )
+ caption_columns = column_names
+ weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
+ else:
+ raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
+
+        for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
+            if extension in data.column_names:
+                data = data.cast_column(extension, datasets.Image(mode="RGB"))
+                data = data.rename_column(extension, "image")
+                break
+
+ self._data = data
+ self._sample_index = 0
+ self._precomputable_once = False
+ self._caption_columns = caption_columns
+ self._weights = weights
+
+ def _get_data_iter(self):
+ if self._sample_index == 0:
+ return iter(self._data)
+ return iter(self._data.skip(self._sample_index))
+
+ def __iter__(self):
+ while True:
+ for sample in self._get_data_iter():
+ self._sample_index += 1
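+                # Choose one caption column per sample, weighted by the user-provided probabilities (uniform by default).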
+ caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
+ sample["caption"] = sample[caption_column]
+ yield sample
+
+ if not self.infinite:
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
+ break
+ else:
+ # Reset offset for the next iteration
+ self._sample_index = 0
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
+
+ def load_state_dict(self, state_dict):
+ self._sample_index = state_dict["sample_index"]
+
+ def state_dict(self):
+ return {"sample_index": self._sample_index}
+
+
+class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(
+ self,
+ dataset_name: str,
+ infinite: bool = False,
+ column_names: Union[str, List[str]] = "__auto__",
+        weights: Union[int, Dict[str, float]] = -1,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ assert weights == -1 or isinstance(weights, dict), (
+ "`weights` must be a dictionary of probabilities for each caption column"
+ )
+
+ self.dataset_name = dataset_name
+ self.infinite = infinite
+
+ data = datasets.load_dataset(dataset_name, split="train", streaming=True)
+
+ if column_names == "__auto__":
+ if weights == -1:
+ caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
+ if len(caption_columns) == 0:
+ raise ValueError(
+ f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
+ )
+ weights = [1] * len(caption_columns)
+ else:
+ caption_columns = list(weights.keys())
+ weights = list(weights.values())
+ if not all(column in data.column_names for column in caption_columns):
+ raise ValueError(
+ f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
+ )
+ else:
+ if isinstance(column_names, str):
+ if column_names not in data.column_names:
+ raise ValueError(
+ f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
+ )
+ caption_columns = [column_names]
+ weights = [1] if weights == -1 else [weights.get(column_names)]
+ elif isinstance(column_names, list):
+ if not all(column in data.column_names for column in column_names):
+ raise ValueError(
+ f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
+ )
+ caption_columns = column_names
+ weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
+ else:
+ raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
+
+        for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
+            if extension in data.column_names:
+                data = data.cast_column(extension, datasets.Video())
+                data = data.rename_column(extension, "video")
+                break
+
+ self._data = data
+ self._sample_index = 0
+ self._precomputable_once = False
+ self._caption_columns = caption_columns
+ self._weights = weights
+
+ def _get_data_iter(self):
+ if self._sample_index == 0:
+ return iter(self._data)
+ return iter(self._data.skip(self._sample_index))
+
+ def __iter__(self):
+ while True:
+ for sample in self._get_data_iter():
+ self._sample_index += 1
+ caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
+ sample["caption"] = sample[caption_column]
+ yield sample
+
+ if not self.infinite:
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
+ break
+ else:
+ # Reset offset for the next iteration
+ self._sample_index = 0
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
+
+ def load_state_dict(self, state_dict):
+ self._sample_index = state_dict["sample_index"]
+
+ def state_dict(self):
+ return {"sample_index": self._sample_index}
+
+
+class ValidationDataset(torch.utils.data.IterableDataset):
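+    # Reads a validation table from a .csv/.json/.parquet/.arrow file. A "caption" column is required; optional
+    # "image_path"/"video_path"/"control_image_path"/"control_video_path" columns are resolved with
+    # `load_image`/`load_video` when the path exists locally or is an http(s) URL.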
+ def __init__(self, filename: str):
+ super().__init__()
+
+ self.filename = pathlib.Path(filename)
+
+ if not self.filename.exists():
+ raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist")
+
+ if self.filename.suffix == ".csv":
+ data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train")
+ elif self.filename.suffix == ".json":
+ data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data")
+ elif self.filename.suffix == ".parquet":
+ data = datasets.load_dataset("parquet", data_files=self.filename.as_posix(), split="train")
+ elif self.filename.suffix == ".arrow":
+ data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train")
+ else:
+ _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
+ raise ValueError(
+ f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
+ )
+
+ self._data = data.to_iterable_dataset()
+
+ def __iter__(self):
+ for sample in self._data:
+ # For consistency reasons, we mandate that "caption" is always present in the validation dataset.
+ # However, since the model specifications use "prompt", we create an alias here.
+ sample["prompt"] = sample["caption"]
+
+ # Load image or video if the path is provided
+ # TODO(aryan): need to handle custom columns here for control conditions
+ sample["image"] = None
+ sample["video"] = None
+
+ if sample.get("image_path", None) is not None:
+ image_path = sample["image_path"]
+ if not pathlib.Path(image_path).is_file() and not image_path.startswith("http"):
+                logger.warning(f"Image file {image_path} does not exist.")
+ else:
+ sample["image"] = load_image(sample["image_path"])
+
+ if sample.get("video_path", None) is not None:
+ video_path = sample["video_path"]
+ if not pathlib.Path(video_path).is_file() and not video_path.startswith("http"):
+                logger.warning(f"Video file {video_path} does not exist.")
+ else:
+ sample["video"] = load_video(sample["video_path"])
+
+ if sample.get("control_image_path", None) is not None:
+ control_image_path = sample["control_image_path"]
+ if not pathlib.Path(control_image_path).is_file() and not control_image_path.startswith("http"):
+                logger.warning(f"Control Image file {control_image_path} does not exist.")
+ else:
+ sample["control_image"] = load_image(sample["control_image_path"])
+
+ if sample.get("control_video_path", None) is not None:
+ control_video_path = sample["control_video_path"]
+ if not pathlib.Path(control_video_path).is_file() and not control_video_path.startswith("http"):
+                logger.warning(f"Control Video file {control_video_path} does not exist.")
+ else:
+ sample["control_video"] = load_video(sample["control_video_path"])
+
+ sample = {k: v for k, v in sample.items() if v is not None}
+ yield sample
+
+
+class IterableDatasetPreprocessingWrapper(
+ torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
+):
+ def __init__(
+ self,
+ dataset: torch.utils.data.IterableDataset,
+ dataset_type: str,
+ id_token: Optional[str] = None,
+ image_resolution_buckets: List[Tuple[int, int]] = None,
+ video_resolution_buckets: List[Tuple[int, int, int]] = None,
+ rename_columns: Optional[Dict[str, str]] = None,
+ drop_columns: Optional[List[str]] = None,
+ reshape_mode: str = "bicubic",
+ remove_common_llm_caption_prefixes: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.dataset = dataset
+ self.dataset_type = dataset_type
+ self.id_token = id_token
+ self.image_resolution_buckets = image_resolution_buckets
+ self.video_resolution_buckets = video_resolution_buckets
+ self.rename_columns = rename_columns or {}
+ self.drop_columns = drop_columns or []
+ self.reshape_mode = reshape_mode
+ self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes
+
+ logger.info(
+ f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n"
+ f" - Dataset Type: {dataset_type}\n"
+ f" - ID Token: {id_token}\n"
+ f" - Image Resolution Buckets: {image_resolution_buckets}\n"
+ f" - Video Resolution Buckets: {video_resolution_buckets}\n"
+ f" - Rename Columns: {rename_columns}\n"
+ f" - Reshape Mode: {reshape_mode}\n"
+ f" - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n"
+ )
+
+ def __iter__(self):
+ logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset")
+ for sample in iter(self.dataset):
+ for column in self.drop_columns:
+ sample.pop(column, None)
+
+ sample = {self.rename_columns.get(k, k): v for k, v in sample.items()}
+
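+            # Decode PIL images and video readers into float tensors scaled to [-1, 1].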
+ for key in sample.keys():
+ if isinstance(sample[key], PIL.Image.Image):
+ sample[key] = _preprocess_image(sample[key])
+ elif isinstance(sample[key], (decord.VideoReader, torchvision.io.video_reader.VideoReader)):
+ sample[key] = _preprocess_video(sample[key])
+
+ if self.dataset_type == "image":
+ if self.image_resolution_buckets:
+ sample["_original_num_frames"] = 1
+ sample["_original_height"] = sample["image"].size(1)
+ sample["_original_width"] = sample["image"].size(2)
+ sample["image"] = FF.resize_to_nearest_bucket_image(
+ sample["image"], self.image_resolution_buckets, self.reshape_mode
+ )
+ elif self.dataset_type == "video":
+ if self.video_resolution_buckets:
+ sample["_original_num_frames"] = sample["video"].size(0)
+ sample["_original_height"] = sample["video"].size(2)
+ sample["_original_width"] = sample["video"].size(3)
+ sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
+ sample["video"], self.video_resolution_buckets, self.reshape_mode
+ )
+ if _first_frame_only:
+                        msg = (
+                            "The number of frames in the video is less than the minimum bucket size "
+                            "specified. The first frame is being used as a single-frame video. This "
+                            "message is logged at the first occurrence and for every 128th occurrence "
+                            "after that."
+                        )
+ logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128)
+ sample["video"] = sample["video"][:1]
+
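+            # Normalize the caption: take the first entry if a list is provided, decode stringified bytes,
+            # optionally strip common LLM prefixes, and prepend `id_token` when one is configured.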
+ caption = sample["caption"]
+ if isinstance(caption, list):
+ caption = caption[0]
+ if caption.startswith("b'") and caption.endswith("'"):
+ caption = FF.convert_byte_str_to_str(caption)
+ if self.remove_common_llm_caption_prefixes:
+ caption = FF.remove_prefix(caption, constants.COMMON_LLM_START_PHRASES)
+ if self.id_token is not None:
+ caption = f"{self.id_token} {caption}"
+ sample["caption"] = caption
+
+ yield sample
+
+ def load_state_dict(self, state_dict):
+ self.dataset.load_state_dict(state_dict["dataset"])
+
+ def state_dict(self):
+ return {"dataset": self.dataset.state_dict()}
+
+
+class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False):
+ super().__init__()
+
+ self.datasets = datasets
+ self.buffer_size = buffer_size
+ self.shuffle = shuffle
+
+ logger.info(
+ f"Initializing IterableCombinedDataset with the following configuration:\n"
+ f" - Number of Datasets: {len(datasets)}\n"
+ f" - Buffer Size: {buffer_size}\n"
+ f" - Shuffle: {shuffle}\n"
+ )
+
+ def __iter__(self):
+ logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets")
+ iterators = [iter(dataset) for dataset in self.datasets]
+ buffer = []
+ per_iter = max(1, self.buffer_size // len(iterators))
+
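+        # Prime the buffer with an equal share of samples from each dataset, then keep it topped up: every
+        # yielded sample is replaced by the next item from the same iterator until all iterators are exhausted.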
+ for index, it in enumerate(iterators):
+ for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"):
+ try:
+ buffer.append((it, next(it)))
+ except StopIteration:
+ continue
+
+ while len(buffer) > 0:
+ idx = 0
+ if self.shuffle:
+ idx = random.randint(0, len(buffer) - 1)
+ current_it, sample = buffer.pop(idx)
+ yield sample
+ try:
+ buffer.append((current_it, next(current_it)))
+ except StopIteration:
+ pass
+
+ def load_state_dict(self, state_dict):
+ for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]):
+ dataset.load_state_dict(dataset_state_dict)
+
+ def state_dict(self):
+ return {"datasets": [dataset.state_dict() for dataset in self.datasets]}
+
+
+# TODO(aryan): maybe write a test for this
+def initialize_dataset(
+ dataset_name_or_root: str,
+ dataset_type: str = "video",
+ streaming: bool = True,
+ infinite: bool = False,
+ *,
+ _caption_options: Optional[Dict[str, Any]] = None,
+) -> torch.utils.data.IterableDataset:
+ assert dataset_type in ["image", "video"]
+
+ try:
+ does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset")
+ except huggingface_hub.errors.HFValidationError:
+ does_repo_exist_on_hub = False
+
+ if does_repo_exist_on_hub:
+ return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options)
+ else:
+ return _initialize_local_dataset(
+ dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options
+ )
+
+
+def combine_datasets(
+ datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False
+) -> torch.utils.data.IterableDataset:
+ return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle)
+
+
+def wrap_iterable_dataset_for_preprocessing(
+ dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any]
+) -> torch.utils.data.IterableDataset:
+ return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)
+
+
+def _initialize_local_dataset(
+ dataset_name_or_root: str,
+ dataset_type: str,
+ infinite: bool = False,
+ *,
+ _caption_options: Optional[Dict[str, Any]] = None,
+):
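+    # Detection order: metadata file (imagefolder/videofolder) -> webdataset/parquet shards ->
+    # caption/data file pairs -> caption-list + file-list layout.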
+ root = pathlib.Path(dataset_name_or_root)
+ supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
+ metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
+ metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]
+
+ if len(metadata_files) > 1:
+ raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")
+
+ if len(metadata_files) == 1:
+ if dataset_type == "image":
+ dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
+ else:
+ dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
+ return dataset
+
+ file_list = find_files(root.as_posix(), "*", depth=100)
+ has_tar_or_parquet_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in file_list)
+ if has_tar_or_parquet_files:
+ return _initialize_webdataset(root.as_posix(), dataset_type, infinite, _caption_options=_caption_options)
+
+ if _has_data_caption_file_pairs(root, remote=False):
+ if dataset_type == "image":
+ dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
+ else:
+ dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
+ elif _has_data_file_caption_file_lists(root, remote=False):
+ if dataset_type == "image":
+ dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
+ else:
+ dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
+ else:
+ raise ValueError(
+ f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
+ f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
+ f"help you set it up."
+ )
+
+ return dataset
+
+
+def _initialize_hub_dataset(
+ dataset_name: str, dataset_type: str, infinite: bool = False, *, _caption_options: Optional[Dict[str, Any]] = None
+):
+ repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
+ if _has_data_caption_file_pairs(repo_file_list, remote=True):
+ return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
+ elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
+ return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
+
+ has_tar_or_parquet_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list)
+ if has_tar_or_parquet_files:
+ return _initialize_webdataset(dataset_name, dataset_type, infinite, _caption_options=_caption_options)
+
+ # TODO(aryan): This should be improved
+ caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
+ if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
+ try:
+ dataset_root = snapshot_download(dataset_name, repo_type="dataset")
+ if dataset_type == "image":
+ dataset = ImageFolderDataset(dataset_root, infinite=infinite)
+ else:
+ dataset = VideoFolderDataset(dataset_root, infinite=infinite)
+ return dataset
+ except Exception:
+ pass
+
+ raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub")
+
+
+def _initialize_data_caption_file_dataset_from_hub(
+ dataset_name: str, dataset_type: str, infinite: bool = False
+) -> torch.utils.data.IterableDataset:
+ logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
+ dataset_root = snapshot_download(dataset_name, repo_type="dataset")
+ if dataset_type == "image":
+ return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
+ else:
+ return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)
+
+
+def _initialize_data_file_caption_file_dataset_from_hub(
+ dataset_name: str, dataset_type: str, infinite: bool = False
+) -> torch.utils.data.IterableDataset:
+ logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
+ dataset_root = snapshot_download(dataset_name, repo_type="dataset")
+ if dataset_type == "image":
+ return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite)
+ else:
+ return VideoFileCaptionFileListDataset(dataset_root, infinite=infinite)
+
+
+def _initialize_webdataset(
+ dataset_name: str, dataset_type: str, infinite: bool = False, _caption_options: Optional[Dict[str, Any]] = None
+) -> torch.utils.data.IterableDataset:
+ logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
+ _caption_options = _caption_options or {}
+ if dataset_type == "image":
+ return ImageWebDataset(dataset_name, infinite=infinite, **_caption_options)
+ else:
+ return VideoWebDataset(dataset_name, infinite=infinite, **_caption_options)
+
+
+def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
+ # TODO(aryan): this logic can be improved
+ if not remote:
+ caption_files = find_files(root.as_posix(), "*.txt", depth=0)
+ for caption_file in caption_files:
+ caption_file = pathlib.Path(caption_file)
+ for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
+ data_filename = caption_file.with_suffix(f".{extension}")
+ if data_filename.exists():
+ return True
+ return False
+ else:
+ caption_files = [file for file in root if file.endswith(".txt")]
+ for caption_file in caption_files:
+ caption_file = pathlib.Path(caption_file)
+ for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
+ data_filename = caption_file.with_suffix(f".{extension}").name
+ if data_filename in root:
+ return True
+ return False
+
+
+def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
+ # TODO(aryan): this logic can be improved
+ if not remote:
+ file_list = {x.name for x in root.iterdir()}
+ has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
+ has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
+ has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
+ return has_caption_files and (has_video_files or has_image_files)
+ else:
+ has_caption_files = any(file in root for file in COMMON_CAPTION_FILES)
+ has_video_files = any(file in root for file in COMMON_VIDEO_FILES)
+ has_image_files = any(file in root for file in COMMON_IMAGE_FILES)
+ return has_caption_files and (has_video_files or has_image_files)
+
+
+def _read_caption_from_file(filename: str) -> str:
+ with open(filename, "r") as f:
+ return f.read().strip()
+
+
+def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.float32)
+ image = torch.from_numpy(image)
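+    # HWC in [0, 255] -> CHW float in [-1, 1]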
+ image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
+ return image
+
+
+if is_datasets_version("<", "3.4.0"):
+
+ def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
+ video = video.get_batch(list(range(len(video))))
+ video = video.permute(0, 3, 1, 2).contiguous()
+ video = video.float() / 127.5 - 1.0
+ return video
+
+else:
+ # Hardcode max frames for now. Ideally, we should allow user to set this and handle it in IterableDatasetPreprocessingWrapper
+ MAX_FRAMES = 4096
+
+ def _preprocess_video(video: torchvision.io.video_reader.VideoReader) -> torch.Tensor:
+ frames = []
+ # Error driven data loading! torchvision does not expose length of video
+ try:
+ for _ in range(MAX_FRAMES):
+ frames.append(next(video)["data"])
+ except StopIteration:
+ pass
+ video = torch.stack(frames)
+ video = video.float() / 127.5 - 1.0
+ return video
diff --git a/docs/finetrainers-src-codebase/finetrainers/data/precomputation.py b/docs/finetrainers-src-codebase/finetrainers/data/precomputation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a33a80603558befd85ca43abf0a1ce39c6d94cd
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/data/precomputation.py
@@ -0,0 +1,420 @@
+import pathlib
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+
+import torch
+from tqdm.auto import tqdm
+
+from finetrainers.logging import get_logger
+from finetrainers.utils import delete_files
+
+
+logger = get_logger()
+
+PRECOMPUTED_DATA_DIR = "finetrainers-precomputed-data"
+
+
+def initialize_preprocessor(
+ rank: int,
+ world_size: int,
+ num_items: int,
+ processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
+ save_dir: Optional[str] = None,
+ enable_precomputation: bool = False,
+ enable_reuse: bool = False,
+) -> Union["InMemoryDistributedDataPreprocessor", "PrecomputedDistributedDataPreprocessor"]:
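+    # With precomputation enabled, processed items are written to (and optionally reused from) disk;
+    # otherwise everything is kept in an in-memory buffer on the local rank.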
+ if enable_precomputation:
+ return PrecomputedDistributedDataPreprocessor(
+ rank, world_size, num_items, processor_fn, save_dir, enable_reuse
+ )
+ return InMemoryDistributedDataPreprocessor(rank, num_items, processor_fn)
+
+
+class DistributedDataProcessorMixin:
+ def consume(self, *args, **kwargs):
+ raise NotImplementedError("DistributedDataProcessorMixin::consume must be implemented by the subclass.")
+
+ def consume_once(self, *args, **kwargs):
+ raise NotImplementedError("DistributedDataProcessorMixin::consume_once must be implemented by the subclass.")
+
+ @property
+ def requires_data(self):
+ raise NotImplementedError("DistributedDataProcessorMixin::requires_data must be implemented by the subclass.")
+
+
+class InMemoryDistributedDataPreprocessor(DistributedDataProcessorMixin):
+ def __init__(
+ self, rank: int, num_items: int, processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]]
+ ) -> None:
+ super().__init__()
+
+ self._rank = rank
+ self._num_items = num_items
+ self._processor_fn = processor_fn
+
+ self._cached_samples = []
+ self._buffer = InMemoryDataBuffer(num_items)
+ self._preprocessed_iterator: Union["InMemoryDataIterable", "InMemoryOnceDataIterable"] = None
+
+ def consume(
+ self,
+ data_type: str,
+ components: Dict[str, Any],
+ data_iterator,
+ generator: Optional[torch.Generator] = None,
+ cache_samples: bool = False,
+ use_cached_samples: bool = False,
+ drop_samples: bool = False,
+ ) -> Iterable[Dict[str, Any]]:
+ if data_type not in self._processor_fn.keys():
+ raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
+ if cache_samples:
+ if use_cached_samples:
+ raise ValueError("Cannot cache and use cached samples at the same time.")
+ if drop_samples:
+ raise ValueError("Cannot cache and drop samples at the same time.")
+
+ for i in range(self._num_items):
+ if use_cached_samples:
+ item = self._cached_samples[i]
+ else:
+ item = next(data_iterator)
+ if cache_samples:
+ self._cached_samples.append(item)
+ item = self._processor_fn[data_type](**item, **components, generator=generator)
+ self._buffer.add(data_type, item)
+
+ if drop_samples:
+ del self._cached_samples
+ self._cached_samples = []
+
+ self._preprocessed_iterator = InMemoryDataIterable(self._rank, data_type, self._buffer)
+ return iter(self._preprocessed_iterator)
+
+ def consume_once(
+ self,
+ data_type: str,
+ components: Dict[str, Any],
+ data_iterator,
+ generator: Optional[torch.Generator] = None,
+ cache_samples: bool = False,
+ use_cached_samples: bool = False,
+ drop_samples: bool = False,
+ ) -> Iterable[Dict[str, Any]]:
+ if data_type not in self._processor_fn.keys():
+ raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
+ if cache_samples:
+ if use_cached_samples:
+ raise ValueError("Cannot cache and use cached samples at the same time.")
+ if drop_samples:
+ raise ValueError("Cannot cache and drop samples at the same time.")
+
+ for i in range(self._num_items):
+ if use_cached_samples:
+ item = self._cached_samples[i]
+ else:
+ item = next(data_iterator)
+ if cache_samples:
+ self._cached_samples.append(item)
+ item = self._processor_fn[data_type](**item, **components, generator=generator)
+ self._buffer.add(data_type, item)
+
+ if drop_samples:
+ del self._cached_samples
+ self._cached_samples = []
+
+ self._preprocessed_iterator = InMemoryOnceDataIterable(self._rank, data_type, self._buffer)
+ return iter(self._preprocessed_iterator)
+
+ @property
+ def requires_data(self):
+ if self._preprocessed_iterator is None:
+ return True
+ return self._preprocessed_iterator.requires_data
+
+
+class PrecomputedDistributedDataPreprocessor(DistributedDataProcessorMixin):
+ def __init__(
+ self,
+ rank: int,
+ world_size: int,
+ num_items: int,
+ processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
+ save_dir: str,
+ enable_reuse: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self._rank = rank
+ self._world_size = world_size
+ self._num_items = num_items
+ self._processor_fn = processor_fn
+ self._save_dir = pathlib.Path(save_dir) / PRECOMPUTED_DATA_DIR
+ self._enable_reuse = enable_reuse
+
+ self._cached_samples = []
+ self._preprocessed_iterator: Union["PrecomputedDataIterable", "PrecomputedOnceDataIterable"] = None
+
+ if enable_reuse:
+ if not self._save_dir.exists() or not self._save_dir.is_dir():
+ raise RuntimeError(
+ f"The directory '{self._save_dir}' does not exist or is not a directory, but is required when enabling reuse of precomputed data."
+ )
+ logger.info(f"Reusing precomputed data from {self._save_dir}.")
+ else:
+ subdirectories = [] if not self._save_dir.exists() else [f for f in self._save_dir.iterdir() if f.is_dir()]
+ if len(subdirectories) > 0:
+ raise RuntimeError(
+                    f"The precomputed data directory '{self._save_dir}' contains subdirectories. Please remove them or enable reuse of precomputed data."
+ )
+ # NOTE: this should be safe since we are adding `PRECOMPUTED_DATA_DIR` to the path, but be careful since
+ # delete_files can seriously mess up your filesystem if used incorrectly.
+ delete_files([self._save_dir])
+ self._save_dir.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Cleaned up any existing precomputed data in {self._save_dir} and created a fresh directory.")
+
+ def consume(
+ self,
+ data_type: str,
+ components: Dict[str, Any],
+ data_iterator,
+ generator: Optional[torch.Generator] = None,
+ cache_samples: bool = False,
+ use_cached_samples: bool = False,
+ drop_samples: bool = False,
+ ) -> Iterable[Dict[str, Any]]:
+ if data_type not in self._processor_fn.keys():
+ raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
+ if cache_samples:
+ if use_cached_samples:
+ raise ValueError("Cannot cache and use cached samples at the same time.")
+ if drop_samples:
+ raise ValueError("Cannot cache and drop samples at the same time.")
+
+ if not self._enable_reuse:
+ for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
+ if use_cached_samples:
+ item = self._cached_samples[i]
+ else:
+ item = next(data_iterator)
+ if cache_samples:
+ self._cached_samples.append(item)
+ item = self._processor_fn[data_type](**item, **components, generator=generator)
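+                # Files are sharded by rank: rank r writes indices [r * num_items, (r + 1) * num_items).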
+ index = self._rank * self._num_items + i
+ _save_item(item, index, self._save_dir, data_type)
+
+ if drop_samples:
+ del self._cached_samples
+ self._cached_samples = []
+
+ if self._enable_reuse:
+ data_iterator = PrecomputedOnceDataIterable(self._rank, self._world_size, self._save_dir, data_type)
+ else:
+ data_iterator = PrecomputedDataIterable(self._rank, self._world_size, self._save_dir, data_type)
+ self._preprocessed_iterator = data_iterator
+ return iter(data_iterator)
+
+ def consume_once(
+ self,
+ data_type: str,
+ components: Dict[str, Any],
+ data_iterator,
+ generator: Optional[torch.Generator] = None,
+ cache_samples: bool = False,
+ use_cached_samples: bool = False,
+ drop_samples: bool = False,
+ ) -> Iterable[Dict[str, Any]]:
+ if data_type not in self._processor_fn.keys():
+ raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
+ if cache_samples:
+ if use_cached_samples:
+ raise ValueError("Cannot cache and use cached samples at the same time.")
+ if drop_samples:
+ raise ValueError("Cannot cache and drop samples at the same time.")
+
+ if not self._enable_reuse:
+ for i in tqdm(range(self._num_items), desc=f"Processing data on rank {self._rank}", total=self._num_items):
+ if use_cached_samples:
+ item = self._cached_samples[i]
+ else:
+ item = next(data_iterator)
+ if cache_samples:
+ self._cached_samples.append(item)
+ item = self._processor_fn[data_type](**item, **components, generator=generator)
+ index = self._rank * self._num_items + i
+ _save_item(item, index, self._save_dir, data_type)
+
+ if drop_samples:
+ del self._cached_samples
+ self._cached_samples = []
+
+ self._preprocessed_iterator = PrecomputedOnceDataIterable(
+ self._rank, self._world_size, self._save_dir, data_type
+ )
+ return iter(self._preprocessed_iterator)
+
+ @property
+ def requires_data(self):
+ if self._preprocessed_iterator is None:
+ return True
+ return self._preprocessed_iterator.requires_data
+
+
+class InMemoryDataIterable:
+ """
+ An iterator that loads data items from an in-memory buffer. Once all the data is consumed,
+    `requires_data` is set to True, indicating that more data is required and that the preprocessor's
+    `consume` method should be called again.
+ """
+
+ def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
+ self._rank = rank
+ self._data_type = data_type
+ self._buffer = buffer
+
+ self._requires_data = False
+
+ def __iter__(self) -> Iterable[Dict[str, Any]]:
+ while (length := self._buffer.get_length(self._data_type)) > 0:
+ if length <= 1:
+ self._requires_data = True
+ yield self._buffer.get(self._data_type)
+
+ def __len__(self) -> int:
+ return self._buffer.get_length(self._data_type)
+
+ @property
+ def requires_data(self):
+ return self._requires_data
+
+
+class InMemoryOnceDataIterable:
+ """
+ An iterator that loads data items from an in-memory buffer. This iterator will never set
+ `requires_data` to True, as it is assumed that all the data was configured to be preprocessed
+    by the user. The data is cycled from the buffer indefinitely.
+ """
+
+ def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
+ self._rank = rank
+ self._data_type = data_type
+ self._buffer = buffer
+
+ self._requires_data = False
+
+ def __iter__(self) -> Iterable[Dict[str, Any]]:
+ assert len(self) > 0, "No data available in the buffer."
+ while True:
+ item = self._buffer.get(self._data_type)
+ yield item
+ self._buffer.add(self._data_type, item)
+
+ def __len__(self) -> int:
+ return self._buffer.get_length(self._data_type)
+
+ @property
+ def requires_data(self):
+ return self._requires_data
+
+
+class PrecomputedDataIterable:
+ """
+    An iterator that loads a preconfigured number of data items from disk. Once all the data is
+    loaded, `requires_data` is set to True, indicating that more data is required and that
+    the preprocessor's `consume` method should be called again.
+ """
+
+ def __init__(self, rank: int, world_size: int, save_dir: str, data_type: str) -> None:
+ self._rank = rank
+ self._world_size = world_size
+ self._save_dir = pathlib.Path(save_dir)
+ self._data_type = data_type
+ self._requires_data = False
+
+ self._num_items = len(list(self._save_dir.glob(f"{data_type}-*.pt")))
+
+ def __iter__(self) -> Iterable[Dict[str, Any]]:
+ map_location = torch.device(self._rank)
+ for i in range(self._num_items):
+ if i == self._num_items - 1:
+ self._requires_data = True
+ index = self._rank * self._num_items + i
+ yield _load_item(index, self._save_dir, self._data_type, map_location)
+
+ def __len__(self) -> int:
+ return self._num_items
+
+ @property
+ def requires_data(self):
+ return self._requires_data
+
+
+class PrecomputedOnceDataIterable:
+ """
+ An infinite iterator that loads preprocessed data from disk. Once initialized, this iterator
+ will never set `requires_data` to True, as it is assumed that all the data was configured to
+ be preprocessed by the user.
+ """
+
+ def __init__(self, rank: int, world_size: int, save_dir: str, data_type: str) -> None:
+ self._rank = rank
+ self._world_size = world_size
+ self._save_dir = pathlib.Path(save_dir)
+ self._data_type = data_type
+ self._requires_data = False
+
+ self._num_items = len(list(self._save_dir.glob(f"{data_type}-*.pt")))
+ if self._num_items <= self._rank:
+ raise ValueError(
+ f"Precomputed data directory is empty or does not contain enough items (required {self._rank + 1}, found {self._num_items})."
+ )
+ self._num_items_per_rank = max(1, self._num_items // world_size)
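+        # Each rank owns a contiguous shard of `_num_items_per_rank` files and cycles over it indefinitely.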
+
+ def __iter__(self) -> Iterable[Dict[str, Any]]:
+ map_location = torch.device(self._rank)
+ i = 0
+ while True:
+ index = self._rank * self._num_items_per_rank + i
+ yield _load_item(index, self._save_dir, self._data_type, map_location)
+ i = (i + 1) % self._num_items_per_rank
+
+ def __len__(self) -> int:
+ return self._num_items_per_rank
+
+ @property
+ def requires_data(self):
+ return self._requires_data
+
+
+class InMemoryDataBuffer:
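+    # A per-data-type FIFO; once `max_limit` is reached, the oldest item is dropped to bound memory usage.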
+ def __init__(self, max_limit: int = -1) -> None:
+ self.max_limit = max_limit
+        self.buffer: Dict[str, List[Dict[str, Any]]] = {}
+
+ def add(self, data_type: str, item: Dict[str, Any]) -> None:
+ if data_type not in self.buffer:
+ self.buffer[data_type] = []
+ if self.max_limit != -1 and len(self.buffer[data_type]) >= self.max_limit:
+ logger.log_freq(
+ "WARN",
+ "IN_MEMORY_DATA_BUFFER_FULL",
+ "Buffer is full. Dropping the oldest item. This message will be logged every 64th time this happens.",
+ 64,
+ )
+ self.buffer[data_type].pop(0)
+ self.buffer[data_type].append(item)
+
+ def get(self, data_type: str) -> Dict[str, Any]:
+ return self.buffer[data_type].pop(0)
+
+ def get_length(self, data_type: str) -> int:
+ return len(self.buffer[data_type])
+
+
+def _save_item(item: Dict[str, Any], index: int, directory: pathlib.Path, data_type: str) -> None:
+ filename = directory / f"{data_type}-{index}.pt"
+ torch.save(item, filename.as_posix())
+
+
+def _load_item(index: int, directory: pathlib.Path, data_type: str, map_location=None) -> Dict[str, Any]:
+ filename = directory / f"{data_type}-{index}.pt"
+ return torch.load(filename.as_posix(), map_location=map_location, weights_only=True)
diff --git a/docs/finetrainers-src-codebase/finetrainers/data/sampler.py b/docs/finetrainers-src-codebase/finetrainers/data/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d9d650e1d610e8ce91b4168a9960479cfcfe8f7
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/data/sampler.py
@@ -0,0 +1,58 @@
+from typing import Any, Dict, List, Tuple
+
+import torch
+
+
+class ResolutionSampler:
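+    # Groups incoming samples into buckets keyed by the shape of a single "leader" tensor and exposes a
+    # batch only once `batch_size` samples share the same bucket dimensions.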
+ def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None:
+ self.batch_size = batch_size
+ self.dim_keys = dim_keys
+ assert dim_keys is not None, "dim_keys must be provided"
+
+ self._chosen_leader_key = None
+ self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Dict[Any, Any]]] = {}
+ self._satisfied_buckets: List[Dict[Any, Any]] = []
+
+ def consume(self, *dict_items: Dict[Any, Any]) -> None:
+ if self._chosen_leader_key is None:
+ self._determine_leader_item(*dict_items)
+ self._update_buckets(*dict_items)
+
+ def get_batch(self) -> List[Dict[str, Any]]:
+ return list(zip(*self._satisfied_buckets.pop(-1)))
+
+ @property
+ def is_ready(self) -> bool:
+ return len(self._satisfied_buckets) > 0
+
+ def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None:
+ num_observed = 0
+ for dict_item in dict_items:
+ for key in self.dim_keys.keys():
+ if key in dict_item.keys():
+ self._chosen_leader_key = key
+ if not torch.is_tensor(dict_item[key]):
+ raise ValueError(f"Leader key {key} must be a tensor")
+ num_observed += 1
+ if num_observed > 1:
+ raise ValueError(
+ f"Only one leader key is allowed in provided list of data dictionaries. Found {num_observed} leader keys"
+ )
+ if self._chosen_leader_key is None:
+ raise ValueError("No leader key found in provided list of data dictionaries")
+
+ def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None:
+ chosen_value = [
+ dict_item[self._chosen_leader_key]
+ for dict_item in dict_items
+ if self._chosen_leader_key in dict_item.keys()
+ ]
+ if len(chosen_value) == 0:
+ raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries")
+ chosen_value = chosen_value[0]
+ dims = tuple(chosen_value.size(x) for x in self.dim_keys[self._chosen_leader_key])
+ if dims not in self._unsatisfied_buckets:
+ self._unsatisfied_buckets[dims] = []
+ self._unsatisfied_buckets[dims].append(dict_items)
+ if len(self._unsatisfied_buckets[dims]) == self.batch_size:
+ self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims))
diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/__init__.py b/docs/finetrainers-src-codebase/finetrainers/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca7b9b22ae0578612e1ebc54b550d86b04eba99c
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/functional/__init__.py
@@ -0,0 +1,17 @@
+from .diffusion import flow_match_target, flow_match_xt
+from .image import (
+ bicubic_resize_image,
+ center_crop_image,
+ find_nearest_resolution_image,
+ resize_crop_image,
+ resize_to_nearest_bucket_image,
+)
+from .normalization import normalize
+from .text import convert_byte_str_to_str, dropout_caption, dropout_embeddings_to_zero, remove_prefix
+from .video import (
+ bicubic_resize_video,
+ center_crop_video,
+ find_nearest_video_resolution,
+ resize_crop_video,
+ resize_to_nearest_bucket_video,
+)
diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/diffusion.py b/docs/finetrainers-src-codebase/finetrainers/functional/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9d553895c2fb251abf80f01f284049acf84f87d
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/functional/diffusion.py
@@ -0,0 +1,11 @@
+import torch
+
+
+def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ r"""Forward process of flow matching."""
+ return (1.0 - t) * x0 + t * n
+
+
+def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
+ r"""Loss target for flow matching."""
+ return n - x0
diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/image.py b/docs/finetrainers-src-codebase/finetrainers/functional/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..be2b024be001045171ec897064ab51433f875e0e
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/functional/image.py
@@ -0,0 +1,56 @@
+from typing import List, Literal, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def center_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+ num_channels, height, width = image.shape
+ crop_h, crop_w = size
+ if height < crop_h or width < crop_w:
+ raise ValueError(f"Image size {(height, width)} is smaller than the target size {size}.")
+ top = (height - crop_h) // 2
+ left = (width - crop_w) // 2
+ return image[:, top : top + crop_h, left : left + crop_w]
+
+
+def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+ num_channels, height, width = image.shape
+ target_h, target_w = size
+ scale = max(target_h / height, target_w / width)
+ new_h, new_w = int(height * scale), int(width * scale)
+    image = F.interpolate(image.unsqueeze(0), size=(new_h, new_w), mode="bilinear", align_corners=False)[0]  # interpolate expects a batched NCHW input
+ return center_crop_image(image, size)
+
+
+def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+ return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0]
+
+
+def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
+ num_channels, height, width = image.shape
+ aspect_ratio = width / height
+
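+    # Pick the bucket whose aspect ratio is closest to the image's; ties are broken in favor of larger buckets.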
+ def aspect_ratio_diff(bucket):
+ return abs((bucket[1] / bucket[0]) - aspect_ratio), (-bucket[0], -bucket[1])
+
+ return min(resolution_buckets, key=aspect_ratio_diff)
+
+
+def resize_to_nearest_bucket_image(
+ image: torch.Tensor,
+ resolution_buckets: List[Tuple[int, int]],
+ resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
+) -> torch.Tensor:
+ target_size = find_nearest_resolution_image(image, resolution_buckets)
+
+ if resize_mode == "center_crop":
+ return center_crop_image(image, target_size)
+ elif resize_mode == "resize_crop":
+ return resize_crop_image(image, target_size)
+ elif resize_mode == "bicubic":
+ return bicubic_resize_image(image, target_size)
+ else:
+ raise ValueError(
+ f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
+ )
diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/normalization.py b/docs/finetrainers-src-codebase/finetrainers/functional/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3433b7636dff5c3d89d17fb94e487d936658741
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/functional/normalization.py
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+
+def normalize(x: torch.Tensor, min: float = -1.0, max: float = 1.0, dim: Optional[int] = None) -> torch.Tensor:
+ """
+    Normalize a tensor to the range `[min, max]`.
+
+ Args:
+ x (`torch.Tensor`):
+ The input tensor to normalize.
+ min (`float`, defaults to `-1.0`):
+ The minimum value of the normalized range.
+ max (`float`, defaults to `1.0`):
+ The maximum value of the normalized range.
+ dim (`int`, *optional*):
+ The dimension along which to normalize. If `None`, the entire tensor is normalized.
+
+ Returns:
+ The normalized tensor of the same shape as `x`.
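+
+    Example:
+        `normalize(torch.tensor([0.0, 1.0, 2.0]))` maps the values to `[-1.0, 0.0, 1.0]`.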
+ """
+ if dim is None:
+ x_min = x.min()
+ x_max = x.max()
+ if torch.isclose(x_min, x_max).any():
+ x = torch.full_like(x, min)
+ else:
+ x = min + (max - min) * (x - x_min) / (x_max - x_min)
+ else:
+ x_min = x.amin(dim=dim, keepdim=True)
+ x_max = x.amax(dim=dim, keepdim=True)
+ if torch.isclose(x_min, x_max).any():
+ x = torch.full_like(x, min)
+ else:
+ x = min + (max - min) * (x - x_min) / (x_max - x_min)
+ return x
diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/text.py b/docs/finetrainers-src-codebase/finetrainers/functional/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd319aba5437730be6b5b4d20c0de4de2ae9173c
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/functional/text.py
@@ -0,0 +1,40 @@
+import random
+from typing import List, Union
+
+import torch
+
+
+def convert_byte_str_to_str(s: str, encoding: str = "utf-8") -> str:
+ """
+ Extracts the actual string from a stringified bytes array (common in some webdatasets).
+
+ Example: "b'hello world'" -> "hello world"
+ """
+ try:
+ s = s[2:-1]
+ s = s.encode("utf-8").decode(encoding)
+ except (UnicodeDecodeError, UnicodeEncodeError, IndexError):
+ pass
+ return s
+
+
+def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]:
+ if random.random() >= dropout_p:
+ return caption
+ if isinstance(caption, str):
+ return ""
+ return [""] * len(caption)
+
+
+def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor:
+ if random.random() >= dropout_p:
+ return embed
+ embed = torch.zeros_like(embed)
+ return embed
+
+
+def remove_prefix(text: str, prefixes: List[str]) -> str:
+ for prefix in prefixes:
+ if text.startswith(prefix):
+ return text.removeprefix(prefix).strip()
+ return text
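+
+
+# Illustrative usage (comments only, not part of the original module):
+#
+#   convert_byte_str_to_str("b'hello world'")       # -> "hello world"
+#   remove_prefix("Caption: a cat", ["Caption:"])   # -> "a cat"
+#   dropout_caption("a cat", dropout_p=0.1)         # -> "" roughly 10% of the time, else "a cat"
+#
+# `remove_prefix` relies on `str.removeprefix`, which requires Python 3.9+.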
diff --git a/docs/finetrainers-src-codebase/finetrainers/functional/video.py b/docs/finetrainers-src-codebase/finetrainers/functional/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fadf66373554b749bdb1d68e455932f158ad9b9
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/functional/video.py
@@ -0,0 +1,96 @@
+from typing import List, Literal, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def center_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+ num_frames, num_channels, height, width = video.shape
+ crop_h, crop_w = size
+ if height < crop_h or width < crop_w:
+ raise ValueError(f"Video size {(height, width)} is smaller than the target size {size}.")
+ top = (height - crop_h) // 2
+ left = (width - crop_w) // 2
+ return video[:, :, top : top + crop_h, left : left + crop_w]
+
+
+def resize_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+ num_frames, num_channels, height, width = video.shape
+ target_h, target_w = size
+ scale = max(target_h / height, target_w / width)
+ new_h, new_w = int(height * scale), int(width * scale)
+ video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False)
+ return center_crop_video(video, size)
+
+
+def bicubic_resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+ num_frames, num_channels, height, width = video.shape
+ video = F.interpolate(video, size=size, mode="bicubic", align_corners=False)
+ return video
+
+
+def find_nearest_video_resolution(
+ video: torch.Tensor, resolution_buckets: List[Tuple[int, int, int]]
+) -> Tuple[int, int, int]:
+ num_frames, num_channels, height, width = video.shape
+ aspect_ratio = width / height
+ possible_buckets = [b for b in resolution_buckets if b[0] <= num_frames]
+
+ if not possible_buckets:
+ best_frame_match = min(resolution_buckets, key=lambda b: abs(b[0] - num_frames))
+ else:
+ best_frame_match = max(possible_buckets, key=lambda b: b[0])
+
+ frame_filtered_buckets = [b for b in resolution_buckets if b[0] == best_frame_match[0]]
+
+ def aspect_ratio_diff(bucket):
+ return abs((bucket[2] / bucket[1]) - aspect_ratio), (-bucket[1], -bucket[2])
+
+ return min(frame_filtered_buckets, key=aspect_ratio_diff)
+
+
+def resize_to_nearest_bucket_video(
+ video: torch.Tensor,
+ resolution_buckets: List[Tuple[int, int, int]],
+ resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
+) -> Tuple[torch.Tensor, bool]:
+ """
+ Resizes a video tensor to the nearest resolution bucket using the specified mode.
+ - It first selects the largest frame-count bucket with at most `num_frames` frames
+ (falling back to the closest frame count if none qualifies).
+ - Then, among buckets with that frame count, it selects the height/width bucket with the
+ closest aspect ratio.
+
+ Args:
+ video (`torch.Tensor`):
+ Input video tensor of shape `(num_frames, num_channels, height, width)`.
+ resolution_buckets (`List[Tuple[int, int, int]]`):
+ Available (num_frames, height, width) resolution buckets.
+ resize_mode (`str`):
+ One of ["center_crop", "resize_crop", "bicubic"].
+
+ Returns:
+ `Tuple[torch.Tensor, bool]`:
+ The video resized to the nearest bucket resolution, and the `_first_frame_only` flag.
+ """
+ target_frames, target_h, target_w = find_nearest_video_resolution(video, resolution_buckets)
+
+ # Adjust frame count: only interpolate frames if no lesser/equal frame count exists
+ num_frames, num_channels, height, width = video.shape
+ _first_frame_only = False
+ if num_frames > target_frames:
+ # Downsample: Select frames evenly
+ indices = torch.linspace(0, num_frames - 1, target_frames).long()
+ video = video[indices, :, :, :]
+ elif num_frames < target_frames:
+ _first_frame_only = False
+
+ # Resize spatial resolution
+ if resize_mode == "center_crop":
+ return center_crop_video(video, (target_h, target_w)), _first_frame_only
+ elif resize_mode == "resize_crop":
+ return resize_crop_video(video, (target_h, target_w)), _first_frame_only
+ elif resize_mode == "bicubic":
+ return bicubic_resize_video(video, (target_h, target_w)), _first_frame_only
+ else:
+ raise ValueError(
+ f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
+ )
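+
+
+# Illustrative example (comments only, not part of the original module): with buckets
+# [(49, 480, 832), (81, 480, 832)] and a 60-frame 720x1280 video, the largest bucket with at
+# most 60 frames is (49, 480, 832); 49 frames are then selected evenly via `linspace` and each
+# frame is resized to 480x832. The function returns the resized video together with the
+# `_first_frame_only` flag.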
diff --git a/docs/finetrainers-src-codebase/finetrainers/logging.py b/docs/finetrainers-src-codebase/finetrainers/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..66cf41b4cb943486067f74c550e1c53811e4290d
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/logging.py
@@ -0,0 +1,139 @@
+import logging
+import os
+from typing import TYPE_CHECKING, Union
+
+import diffusers
+import transformers
+
+from .constants import FINETRAINERS_LOG_LEVEL
+
+
+if TYPE_CHECKING:
+ from .parallel import ParallelBackendType
+
+
+class FinetrainersLoggerAdapter(logging.LoggerAdapter):
+ def __init__(self, logger: logging.Logger, parallel_backend: "ParallelBackendType" = None) -> None:
+ super().__init__(logger, {})
+ self.parallel_backend = parallel_backend
+ self._log_freq = {}
+ self._log_freq_counter = {}
+
+ def log(
+ self,
+ level,
+ msg,
+ *args,
+ main_process_only: bool = False,
+ local_main_process_only: bool = True,
+ in_order: bool = False,
+ **kwargs,
+ ):
+ # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
+ kwargs.setdefault("stacklevel", 2)
+
+ if not self.isEnabledFor(level):
+ return
+
+ if self.parallel_backend is None:
+ if int(os.environ.get("RANK", 0)) == 0:
+ msg, kwargs = self.process(msg, kwargs)
+ self.logger.log(level, msg, *args, **kwargs)
+ return
+
+ if (main_process_only or local_main_process_only) and in_order:
+ raise ValueError(
+ "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
+ )
+
+ if (main_process_only and self.parallel_backend.is_main_process) or (
+ local_main_process_only and self.parallel_backend.is_local_main_process
+ ):
+ msg, kwargs = self.process(msg, kwargs)
+ self.logger.log(level, msg, *args, **kwargs)
+ return
+
+ if in_order:
+ for i in range(self.parallel_backend.world_size):
+ if self.parallel_backend.rank == i:
+ msg, kwargs = self.process(msg, kwargs)
+ self.logger.log(level, msg, *args, **kwargs)
+ self.parallel_backend.wait_for_everyone()
+ return
+
+ if not main_process_only and not local_main_process_only:
+ msg, kwargs = self.process(msg, kwargs)
+ self.logger.log(level, msg, *args, **kwargs)
+ return
+
+ def log_freq(
+ self,
+ level: str,
+ name: str,
+ msg: str,
+ frequency: int,
+ *,
+ main_process_only: bool = False,
+ local_main_process_only: bool = True,
+ in_order: bool = False,
+ **kwargs,
+ ) -> None:
+ if frequency <= 0:
+ return
+ if name not in self._log_freq_counter:
+ self._log_freq[name] = frequency
+ self._log_freq_counter[name] = 0
+ if self._log_freq_counter[name] % self._log_freq[name] == 0:
+ self.log(
+ level,
+ msg,
+ main_process_only=main_process_only,
+ local_main_process_only=local_main_process_only,
+ in_order=in_order,
+ **kwargs,
+ )
+ self._log_freq_counter[name] += 1
+
+
+def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
+ global _logger
+ return _logger
+
+
+def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter:
+ _logger.parallel_backend = parallel_backend
+ return _logger
+
+
+_logger = logging.getLogger("finetrainers")
+_logger.setLevel(FINETRAINERS_LOG_LEVEL)
+_console_handler = logging.StreamHandler()
+_console_handler.setLevel(FINETRAINERS_LOG_LEVEL)
+_formatter = logging.Formatter("%(asctime)s - [%(levelname)s] - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
+_console_handler.setFormatter(_formatter)
+_logger.addHandler(_console_handler)
+_logger.propagate = False
+_logger = FinetrainersLoggerAdapter(_logger)
+
+
+def set_dependency_log_level(verbose: int = 0, is_local_main_process: bool = False) -> None:
+ transformers_log_level = transformers.utils.logging.set_verbosity_error
+ diffusers_log_level = diffusers.utils.logging.set_verbosity_error
+
+ if verbose == 0:
+ if is_local_main_process:
+ transformers_log_level = transformers.utils.logging.set_verbosity_warning
+ diffusers_log_level = diffusers.utils.logging.set_verbosity_warning
+ elif verbose == 1:
+ if is_local_main_process:
+ transformers_log_level = transformers.utils.logging.set_verbosity_info
+ diffusers_log_level = diffusers.utils.logging.set_verbosity_info
+ elif verbose == 2:
+ if is_local_main_process:
+ transformers_log_level = transformers.utils.logging.set_verbosity_debug
+ diffusers_log_level = diffusers.utils.logging.set_verbosity_debug
+ else:
+ transformers_log_level = transformers.utils.logging.set_verbosity_debug
+ diffusers_log_level = diffusers.utils.logging.set_verbosity_debug
+
+ transformers_log_level()
+ diffusers_log_level()
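+
+
+# Illustrative usage (comments only, not part of the original module):
+#
+#   logger = get_logger()
+#   logger.info("loading checkpoint")  # logged on rank 0 only until a parallel backend is set
+#   logger.log_freq("WARNING", "SLOW_STEP", "step is slow", frequency=512)
+#
+# `_set_parallel_backend` attaches a backend to the adapter so that `main_process_only`,
+# `local_main_process_only`, and `in_order` are honored across ranks.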
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec474ff2a5a786085cc5df72f295a590deee08fc
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/__init__.py
@@ -0,0 +1,8 @@
+from .attention_dispatch import AttentionProvider, attention_dispatch, attention_provider
+from .modeling_utils import ControlModelSpecification, ModelSpecification
+
+
+from ._metadata.transformer import register_transformer_metadata # isort: skip
+
+
+register_transformer_metadata()
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/_metadata/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/_metadata/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d5baeb109014d64f9bff2f102c28ee0a3da40f8
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/_metadata/__init__.py
@@ -0,0 +1 @@
+from .transformer import register_transformer_metadata
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/_metadata/transformer.py b/docs/finetrainers-src-codebase/finetrainers/models/_metadata/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..33de3148ca0b025d41be8602dc0c17c9b4eed4aa
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/_metadata/transformer.py
@@ -0,0 +1,86 @@
+from diffusers import (
+ CogVideoXTransformer3DModel,
+ CogView4Transformer2DModel,
+ FluxTransformer2DModel,
+ WanTransformer3DModel,
+)
+
+from finetrainers._metadata import CPInput, CPOutput, ParamId, TransformerMetadata, TransformerRegistry
+from finetrainers.logging import get_logger
+
+
+logger = get_logger()
+
+
+def register_transformer_metadata():
+ # CogVideoX
+ TransformerRegistry.register(
+ model_class=CogVideoXTransformer3DModel,
+ metadata=TransformerMetadata(
+ cp_plan={
+ "": {
+ ParamId("image_rotary_emb", 5): [CPInput(0, 2), CPInput(0, 2)],
+ },
+ "transformer_blocks.0": {
+ ParamId("hidden_states", 0): CPInput(1, 3),
+ ParamId("encoder_hidden_states", 1): CPInput(1, 3),
+ },
+ "proj_out": [CPOutput(1, 3)],
+ }
+ ),
+ )
+
+ # CogView4
+ TransformerRegistry.register(
+ model_class=CogView4Transformer2DModel,
+ metadata=TransformerMetadata(
+ cp_plan={
+ "patch_embed": {
+ ParamId(index=0): CPInput(1, 3, split_output=True),
+ ParamId(index=1): CPInput(1, 3, split_output=True),
+ },
+ "rope": {
+ ParamId(index=0): CPInput(0, 2, split_output=True),
+ ParamId(index=1): CPInput(0, 2, split_output=True),
+ },
+ "proj_out": [CPOutput(1, 3)],
+ }
+ ),
+ )
+
+ # Flux
+ TransformerRegistry.register(
+ model_class=FluxTransformer2DModel,
+ metadata=TransformerMetadata(
+ cp_plan={
+ "": {
+ ParamId("hidden_states", 0): CPInput(1, 3),
+ ParamId("encoder_hidden_states", 1): CPInput(1, 3),
+ ParamId("img_ids", 4): CPInput(0, 2),
+ ParamId("txt_ids", 5): CPInput(0, 2),
+ },
+ "proj_out": [CPOutput(1, 3)],
+ }
+ ),
+ )
+
+ # Wan2.1
+ TransformerRegistry.register(
+ model_class=WanTransformer3DModel,
+ metadata=TransformerMetadata(
+ cp_plan={
+ "rope": {
+ ParamId(index=0): CPInput(2, 4, split_output=True),
+ },
+ "blocks.*": {
+ ParamId("encoder_hidden_states", 1): CPInput(1, 3),
+ },
+ "blocks.0": {
+ ParamId("hidden_states", 0): CPInput(1, 3),
+ },
+ "proj_out": [CPOutput(1, 3)],
+ }
+ ),
+ )
+
+ logger.debug("Metadata for transformer registered")
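+
+
+# Illustrative sketch (comments only, not part of the original module): registering
+# context-parallel metadata for a hypothetical `MyTransformer3DModel` would follow the same
+# pattern, sharding the sequence dimension of the first block's inputs and gathering at
+# `proj_out`:
+#
+#   TransformerRegistry.register(
+#       model_class=MyTransformer3DModel,
+#       metadata=TransformerMetadata(
+#           cp_plan={
+#               "blocks.0": {ParamId("hidden_states", 0): CPInput(1, 3)},
+#               "proj_out": [CPOutput(1, 3)],
+#           }
+#       ),
+#   )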
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/attention_dispatch.py b/docs/finetrainers-src-codebase/finetrainers/models/attention_dispatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1438674db1f7bf242df1af0fd385a09a081eda6
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/attention_dispatch.py
@@ -0,0 +1,1812 @@
+import contextlib
+import functools
+import inspect
+from enum import Enum
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+
+# Since we will be patching the `scaled_dot_product_attention` function with `attention_dispatch` to take
+# control of dispatching to different attention providers, we need to import the original function
+# so that we can still call it without going into infinite recursion when the dispatcher calls `scaled_dot_product_attention`.
+import torch.autograd
+from diffusers.utils.import_utils import OptionalDependencyNotAvailable
+from torch.nn.functional import scaled_dot_product_attention as native_sdpa
+
+from finetrainers.constants import FINETRAINERS_ATTN_CHECKS, FINETRAINERS_ATTN_PROVIDER
+from finetrainers.logging import get_logger
+from finetrainers.utils.import_utils import (
+ is_flash_attn_available,
+ is_flash_attn_version,
+ is_sageattention_available,
+ is_sageattention_version,
+ is_torch_version,
+ is_xformers_available,
+ is_xformers_version,
+)
+
+
+if is_flash_attn_available():
+ if is_flash_attn_version("<", "2.6.3"):
+ raise OptionalDependencyNotAvailable(
+ "The `flash-attn` library version is too old. Please update it to at least 2.6.3."
+ )
+
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
+else:
+ flash_attn_func = None
+ flash_attn_varlen_func = None
+ _flash_attn_forward = None
+ _flash_attn_backward = None
+
+
+if is_sageattention_available():
+ if is_sageattention_version("<", "2.1.1"):
+ raise OptionalDependencyNotAvailable(
+ "The `sageattention` library version is too old. Please update it to at least 2.1.1."
+ )
+
+ from sageattention import (
+ sageattn,
+ sageattn_qk_int8_pv_fp8_cuda,
+ sageattn_qk_int8_pv_fp8_cuda_sm90,
+ sageattn_qk_int8_pv_fp16_cuda,
+ sageattn_qk_int8_pv_fp16_triton,
+ sageattn_varlen,
+ )
+else:
+ sageattn = None
+ sageattn_qk_int8_pv_fp16_cuda = None
+ sageattn_qk_int8_pv_fp16_triton = None
+ sageattn_qk_int8_pv_fp8_cuda = None
+ sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+ sageattn_varlen = None
+
+
+if is_torch_version(">=", "2.5.0"):
+ import torch.nn.attention.flex_attention as flex_attention
+
+
+if is_torch_version(">=", "2.6.0"):
+ from torch.distributed.tensor.experimental._attention import (
+ _AttentionOp,
+ _cp_options,
+ _templated_ring_attention,
+ _templated_ring_attention_backward,
+ set_rotate_method,
+ )
+else:
+ _cp_options = None
+ _templated_ring_attention = None
+ _templated_ring_attention_backward = None
+ set_rotate_method = None
+
+ class _AttentionOp:
+ def __init__(self, *args, **kwargs):
+ raise OptionalDependencyNotAvailable(
+ "The `torch.distributed.tensor.experimental._attention` module is not available. Please update PyTorch to at least 2.6.0."
+ )
+
+
+if is_xformers_available():
+ if is_xformers_version("<", "0.0.29"):
+ raise OptionalDependencyNotAvailable(
+ "The `xformers` library version is too old. Please update it to at least 0.0.29."
+ )
+
+ import xformers.ops as xops
+else:
+ xops = None
+
+
+logger = get_logger()
+
+_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
+_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
+_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
+
+
+# ===== Custom operator implementations/wrappers =====
+
+
+def _finetrainers_scaled_dot_product_efficient_attention_forward(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_bias: Optional[torch.Tensor] = None,
+ compute_log_sumexp: bool = False,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Wrapper for https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946
+ # See: https://github.com/pytorch/pytorch/issues/152942
+ seqlen_q = query.shape[-2]
+ out, lse, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_bias,
+ compute_log_sumexp=compute_log_sumexp,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ )
+
+ # LSE is aligned to the next nearest multiple of 32. This is a workaround to return the lse without alignment so that pytorch
+ # ring attention does not error out with shape mismatch
+ if compute_log_sumexp:
+ assert lse.ndim == 3
+ lse = lse[:, :, :seqlen_q] # .contiguous()
+
+ return out, lse, philox_seed, philox_offset
+
+
+# aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
+def _finetrainers_scaled_dot_product_efficient_attention_backward(
+ grad_out_: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_bias: torch.Tensor,
+ out: torch.Tensor,
+ logsumexp: torch.Tensor,
+ philox_seed: torch.Tensor,
+ philox_offset: torch.Tensor,
+ dropout_p: float,
+ grad_input_mask: List[bool],
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ assert len(grad_input_mask) == 4
+ # https://github.com/pytorch/pytorch/blob/bb9fbb294af385057a72e5b1386cf40f86aadbec/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h#L113
+ kAlignLSE = 32
+
+ logsumexp = torch.nn.functional.pad(
+ logsumexp, (0, kAlignLSE - (logsumexp.shape[-1] % kAlignLSE)), value=float("inf")
+ )
+
+ grad_query, grad_key, grad_value, grad_attn_bias = torch.ops.aten._scaled_dot_product_efficient_attention_backward(
+ grad_out_=grad_out_,
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_bias,
+ out=out,
+ logsumexp=logsumexp,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ dropout_p=dropout_p,
+ grad_input_mask=grad_input_mask,
+ is_causal=is_causal,
+ scale=scale,
+ )
+
+ return grad_query, grad_key, grad_value, grad_attn_bias
+
+
+# This function wraps the actual _flash_attn_forward call to return LSE at index 1 to be compatible with pytorch's native ring attention
+def _finetrainers_flash_attn_forward(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ return_softmax: bool = False,
+):
+ query, key, value = (
+ x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)
+ ) # [B, N, S, D] -> [B, S, N, D]
+ out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
+ query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, return_softmax
+ )
+ out = out.permute(0, 2, 1, 3).contiguous() # [B, S, N, D] -> [B, N, S, D]
+ return out, softmax_lse, q, k, v, out_padded, S_dmask, rng_state
+
+
+# This function wraps the actual _flash_attn_backward call as the counterpart of the _finetrainers_flash_attn_forward function
+def _finetrainers_flash_attn_backward(
+ grad_out: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ out: torch.Tensor,
+ logsumexp: torch.Tensor, # Needs a different name than the one used in flash-attn because _templated_ring_attention_backward assumes the name is logsumexp
+ dropout_p: float,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ rng_state: Optional[torch.Tensor] = None,
+ _permute_outputs: bool = True,
+):
+ dq, dk, dv = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+ grad_out = grad_out.permute(0, 2, 1, 3).contiguous() # [B, N, S, D] -> [B, S, N, D]
+
+ dq, dk, dv, softmax_d = _flash_attn_backward(
+ grad_out,
+ query,
+ key,
+ value,
+ out,
+ logsumexp,
+ dq,
+ dk,
+ dv,
+ dropout_p,
+ scale,
+ is_causal,
+ window_size,
+ softcap,
+ alibi_slopes,
+ deterministic,
+ rng_state,
+ )
+
+ # Head dimension may have been padded
+ dq = dq[..., : grad_out.shape[-1]]
+ dk = dk[..., : grad_out.shape[-1]]
+ dv = dv[..., : grad_out.shape[-1]]
+
+ if _permute_outputs:
+ dq, dk, dv = (x.permute(0, 2, 1, 3).contiguous() for x in (dq, dk, dv)) # [B, S, N, D] -> [B, N, S, D]
+ return dq, dk, dv
+
+
+# ===== Attention provider =====
+
+
+class AttentionProvider(str, Enum):
+ # EAGER = "eager"
+
+ # `flash-attn`
+ FLASH = "flash"
+ FLASH_VARLEN = "flash_varlen"
+
+ # PyTorch native
+ FLEX = "flex"
+ NATIVE = "native"
+ _NATIVE_CUDNN = "_native_cudnn"
+ _NATIVE_EFFICIENT = "_native_efficient"
+ _NATIVE_FLASH = "_native_flash"
+ _NATIVE_MATH = "_native_math"
+
+ # `sageattention`
+ SAGE = "sage"
+ SAGE_VARLEN = "sage_varlen"
+ _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
+ _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
+ _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
+ _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
+ # TODO: let's not add support for Sparge Attention now because it requires tuning per model
+ # We can look into supporting some form of autotuning in the future
+ # SPARGE = "sparge"
+
+ # `xformers`
+ XFORMERS = "xformers"
+
+
+class _AttentionProviderRegistry:
+ _providers = {}
+ _constraints = {}
+ _supports_cp = {}
+ _supported_arg_names = {}
+
+ _active_provider = AttentionProvider(FINETRAINERS_ATTN_PROVIDER)
+ _checks_enabled = FINETRAINERS_ATTN_CHECKS
+
+ # Context parallel attributes
+ _mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _convert_to_fp32: bool = None
+ _rotate_method: Literal["allgather", "alltoall"] = None
+
+ @classmethod
+ def register(
+ cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None, supports_cp: bool = False
+ ):
+ logger.debug(f"Registering attention provider: {provider}")
+
+ def decorator(func):
+ cls._providers[provider] = func
+ cls._constraints[provider] = constraints or []
+ cls._supports_cp[provider] = supports_cp
+ cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys())
+ return func
+
+ return decorator
+
+ @classmethod
+ def get_active_provider(cls):
+ return cls._active_provider, cls._providers[cls._active_provider]
+
+ @classmethod
+ def list_providers(cls):
+ return list(cls._providers.keys())
+
+ @classmethod
+ def supports_context_parallel(cls, provider: AttentionProvider):
+ if provider not in cls._providers:
+ raise ValueError(f"Provider {provider} is not registered.")
+ return cls._supports_cp.get(provider, False)
+
+ @classmethod
+ def context_parallel_enabled(cls):
+ return cls._mesh is not None
+
+ @classmethod
+ def _set_context_parallel(
+ cls,
+ mesh: torch.distributed.device_mesh.DeviceMesh = None,
+ convert_to_fp32: bool = None,
+ rotate_method: str = None,
+ *,
+ reset: bool = False,
+ ):
+ if reset:
+ mesh = convert_to_fp32 = rotate_method = None
+ cls._mesh = mesh
+ cls._convert_to_fp32 = convert_to_fp32
+ cls._rotate_method = rotate_method
+
+ @classmethod
+ def _raise_cp_error_if_mesh_not_set(cls):
+ if cls._mesh is None:
+ raise ValueError(
+ "`_AttentionProviderRegistry._mesh` is None. It must be set before calling context parallel attention methods."
+ )
+
+
+@contextlib.contextmanager
+def attention_provider(
+ provider: AttentionProvider = AttentionProvider.NATIVE,
+ *,
+ mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+ convert_to_fp32: bool = True,
+ rotate_method: str = "allgather",
+):
+ """Context manager to set the active attention provider and possibly enable context parallelism."""
+
+ if provider not in _AttentionProviderRegistry._providers:
+ raise ValueError(f"Provider {provider} is not registered.")
+ if mesh is not None and not _AttentionProviderRegistry.supports_context_parallel(provider):
+ raise ValueError(f"Provider {provider} does not support context parallelism.")
+
+ old_provider = _AttentionProviderRegistry._active_provider
+ _AttentionProviderRegistry._active_provider = provider
+
+ _AttentionProviderRegistry._mesh = mesh
+ _AttentionProviderRegistry._convert_to_fp32 = convert_to_fp32
+ _AttentionProviderRegistry._rotate_method = rotate_method
+ if mesh is not None:
+ _convert_to_f32 = _cp_options.convert_to_f32
+ _enable_load_balance = _cp_options.enable_load_balance
+ _rotate_method = _cp_options.rotate_method
+
+ try:
+ yield
+ finally:
+ _AttentionProviderRegistry._active_provider = old_provider
+
+ _AttentionProviderRegistry._mesh = None
+ _AttentionProviderRegistry._convert_to_fp32 = None
+ _AttentionProviderRegistry._rotate_method = None
+ if mesh is not None:
+ _cp_options.convert_to_f32 = _convert_to_f32
+ _cp_options.enable_load_balance = _enable_load_balance
+ _cp_options.rotate_method = _rotate_method
+
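+# Illustrative usage (comments only, not part of the original module): temporarily switch the
+# active provider for a forward pass; `transformer` is a placeholder for any module whose
+# attention layers route through `attention_dispatch`:
+#
+#   with attention_provider(AttentionProvider.FLASH):
+#       output = transformer(hidden_states)
+#
+# Passing a `mesh` additionally enables ring attention for providers registered with
+# `supports_cp=True`.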
+
+def attention_dispatch(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> torch.Tensor:
+ attention_kwargs = attention_kwargs or {}
+ provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider()
+ kwargs = {
+ "query": query,
+ "key": key,
+ "value": value,
+ "attn_mask": attn_mask,
+ "dropout_p": dropout_p,
+ "is_causal": is_causal,
+ "scale": scale,
+ "enable_gqa": enable_gqa,
+ **attention_kwargs,
+ }
+
+ if _AttentionProviderRegistry._checks_enabled:
+ removed_kwargs = set(kwargs) - set(_AttentionProviderRegistry._supported_arg_names[provider_name])
+ if removed_kwargs:
+ log_freq = 512
+ msg = (
+ f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}. This "
+ f"message will be logged every {log_freq} calls."
+ )
+ logger.log_freq("WARNING", "REMOVING_ATTN_UNSUPPORTED_KWARGS", msg, log_freq)
+ for check in _AttentionProviderRegistry._constraints.get(provider_name):
+ check(**kwargs)
+
+ kwargs = {k: v for k, v in kwargs.items() if k in _AttentionProviderRegistry._supported_arg_names[provider_name]}
+
+ if _AttentionProviderRegistry.context_parallel_enabled():
+ _set_context_parallel_options(**kwargs)
+
+ return provider_fn(**kwargs)
+
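+# `attention_dispatch` mirrors the signature of `torch.nn.functional.scaled_dot_product_attention`
+# (plus `attention_kwargs`), so it can be monkey-patched in its place, e.g. (illustrative only):
+#
+#   import torch.nn.functional as F
+#   F.scaled_dot_product_attention = attention_dispatch
+#
+# The original function is kept above as `native_sdpa` to avoid infinite recursion.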
+
+# ===== Helper functions =====
+
+
+# @torch.compiler.assume_constant_result
+def _set_context_parallel_options(is_causal: bool, **kwargs):
+ _cp_options.enable_load_balance = is_causal
+ _cp_options.convert_to_f32 = _AttentionProviderRegistry._convert_to_fp32
+ set_rotate_method(_AttentionProviderRegistry._rotate_method)
+
+
+def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None:
+ if attn_mask is not None:
+ raise ValueError("Attention mask must be None for this provider.")
+
+
+def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
+ if attn_mask is not None and is_causal:
+ raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
+
+
+def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.device != key.device or query.device != value.device:
+ raise ValueError("Query, key, and value must be on the same device.")
+ if query.dtype != key.dtype or query.dtype != value.dtype:
+ raise ValueError("Query, key, and value must have the same dtype.")
+
+
+def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device(query, key, value)
+ if query.device.type != "cuda":
+ raise ValueError("Query, key, and value must be on a CUDA device.")
+
+
+def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
+ def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device_cuda(query, key, value)
+ if torch.cuda.get_device_capability(query.device) < (major, minor):
+ raise ValueError(
+ f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
+ )
+
+ return check_device_cuda
+
+
+def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.dtype != key.dtype:
+ raise ValueError("Query and key must have the same dtype.")
+ if query.dtype != value.dtype:
+ raise ValueError("Query and value must have the same dtype.")
+
+
+def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_qkv_dtype_match(query, key, value)
+ if query.dtype not in (torch.bfloat16, torch.float16):
+ raise ValueError("Query, key, and value must be either bfloat16 or float16.")
+
+
+def _check_shape(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+) -> None:
+ if query.shape[-1] != key.shape[-1]:
+ raise ValueError("Query and key must have the same last dimension.")
+ if query.shape[-2] != value.shape[-2]:
+ raise ValueError("Query and value must have the same second to last dimension.")
+ if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
+ raise ValueError("Attention mask must match the key's second to last dimension.")
+
+
+def _prepare_for_flash_attn_or_sage_varlen(
+ batch_size: int,
+ seq_len_q: int,
+ seq_len_kv: int,
+ attn_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+ if attn_mask is None:
+ seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
+ else:
+ seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
+ cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+ max_seqlen_q = seqlens_q.max().item()
+ max_seqlen_k = seqlens_k.max().item()
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
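+# Shape note (illustrative): for batch_size=2, seq_len_q=4 and a key-padding mask with valid
+# lengths [4, 2], this returns seqlens_k=[4, 2], cu_seqlens_q=[0, 4, 8], cu_seqlens_k=[0, 4, 6]
+# and max_seqlen_k=4, i.e. the cumulative-sequence-length format expected by
+# `flash_attn_varlen_func` and `sageattn_varlen`.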
+
+def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
+ """
+ Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in
+ FlashAttention/Sage varlen.
+
+ Supports 1D to 4D shapes and common broadcasting patterns.
+ """
+ if attn_mask.dtype != torch.bool:
+ raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
+
+ if attn_mask.ndim == 1:
+ # [seq_len_k] -> broadcast across batch
+ attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 2:
+ # [batch_size, seq_len_k]. Maybe broadcast across batch
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 3:
+ # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
+ )
+ attn_mask = attn_mask.any(dim=1)
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 4:
+ # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K]
+ attn_mask = attn_mask.any(dim=(1, 2)) # [B, K]
+
+ else:
+ raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
+
+ if attn_mask.shape != (batch_size, seq_len_k):
+ raise ValueError(
+ f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
+ )
+
+ return attn_mask
+
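+# Illustrative example: a boolean mask of shape (2, 8, 16, 32) ([batch, heads, seq_q, seq_k])
+# is reduced over the head and query dimensions to a (2, 32) key-padding mask, which is all
+# `_prepare_for_flash_attn_or_sage_varlen` needs in order to infer `seqlens_k`.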
+
+def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+
+# ===== Attention provider implementations =====
+
+
+# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
+class _flash_attn_flash_attention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ dropout_p: float = 0.0,
+ softmax_scale: Optional[float] = None,
+ causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_softmax: bool = False,
+ ):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+
+ ctx.dropout_p = dropout_p
+ ctx.softmax_scale = softmax_scale
+ ctx.causal = causal
+ ctx.window_size = window_size
+ ctx.softcap = softcap
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+
+ out, lse, q, k, v, out_padded, S_dmask, rng_state = _finetrainers_flash_attn_forward(
+ query=q,
+ key=k,
+ value=v,
+ dropout_p=dropout_p,
+ scale=softmax_scale,
+ is_causal=causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ return_softmax=return_softmax,
+ )
+
+ ctx.save_for_backward(q, k, v, out_padded, lse, rng_state)
+
+ return (out, lse) if return_softmax else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args: torch.Tensor,
+ ):
+ q, k, v, out, lse, rng_state = ctx.saved_tensors
+
+ grad_query, grad_key, grad_value = _finetrainers_flash_attn_backward(
+ grad_out=grad_out,
+ query=q,
+ key=k,
+ value=v,
+ out=out,
+ logsumexp=lse,
+ dropout_p=ctx.dropout_p,
+ scale=ctx.softmax_scale,
+ is_causal=ctx.causal,
+ window_size=ctx.window_size,
+ softcap=ctx.softcap,
+ alibi_slopes=ctx.alibi_slopes,
+ deterministic=ctx.deterministic,
+ rng_state=rng_state,
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+
+
+# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
+class _native_ring_flash_attn_flash_attention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ dropout_p: float = 0.0,
+ softmax_scale: Optional[float] = None,
+ causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_softmax: bool = False,
+ ):
+ if softmax_scale is None:
+ softmax_scale = q.shape[-1] ** (-0.5)
+
+ # For ring flash attention using the flash-attn repo, we want the LSE but flash-attn only supports it if dropout_p > 0
+ dropout_p = dropout_p if dropout_p > 0 else 1e-30
+
+ ctx.dropout_p = dropout_p
+ ctx.softmax_scale = softmax_scale
+ ctx.causal = causal
+ ctx.window_size = window_size
+ ctx.softcap = softcap
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+
+ out, lse, q, k, v, out_padded, S_dmask, rng_state = _templated_ring_attention(
+ mesh=_AttentionProviderRegistry._mesh,
+ seq_dim=2,
+ op=_finetrainers_flash_attn_forward,
+ query=q,
+ key=k,
+ value=v,
+ dropout_p=dropout_p,
+ scale=softmax_scale,
+ is_causal=causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ return_softmax=True,
+ )
+
+ ctx.save_for_backward(q, k, v, out_padded, lse, rng_state)
+
+ return (out, lse) if return_softmax else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args: torch.Tensor,
+ ):
+ q, k, v, out, lse, rng_state = ctx.saved_tensors
+ lse = lse.permute(0, 2, 1).contiguous() # [B, N, S] -> [B, S, N]
+
+ grad_query, grad_key, grad_value = _templated_ring_attention_backward(
+ mesh=_AttentionProviderRegistry._mesh,
+ # This needs to be 1 because q, k, v, out_padded returned from forward are BSND instead of BNSD
+ # The grad_out permutation is handled in _finetrainers_flash_attn_backward, and the outputs from that are expected to have
+ # shape BSND instead of BNSD (requirement of _templated_ring_attention_backward), so we need to set seq_dim=1 and permute the
+ # returned outputs
+ seq_dim=1,
+ op=functools.partial(_finetrainers_flash_attn_backward, _permute_outputs=False),
+ grad_out=grad_out,
+ grad_out_name="grad_out",
+ query=q,
+ key=k,
+ value=v,
+ out=out,
+ logsumexp=lse,
+ dropout_p=ctx.dropout_p,
+ scale=ctx.softmax_scale,
+ is_causal=ctx.causal,
+ window_size=ctx.window_size,
+ softcap=ctx.softcap,
+ alibi_slopes=ctx.alibi_slopes,
+ deterministic=ctx.deterministic,
+ rng_state=rng_state,
+ )
+ grad_query, grad_key, grad_value = (
+ x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider.FLASH,
+ constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_cp=True,
+)
+def flash_attn_flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ dispatch_fn = (
+ _native_ring_flash_attn_flash_attention
+ if _AttentionProviderRegistry.context_parallel_enabled()
+ else _flash_attn_flash_attention
+ )
+ return dispatch_fn.apply(
+ query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, deterministic, return_lse
+ )
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider.FLASH_VARLEN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_cp=False,
+)
+def _flash_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, _, seq_len_q, _ = query.shape
+ _, _, seq_len_kv, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ if _AttentionProviderRegistry.context_parallel_enabled():
+ return_attn_probs = True
+
+ out = flash_attn_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ )
+
+ rest = None
+ if return_attn_probs:
+ out, *rest = out
+ out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous()
+ if return_attn_probs:
+ return out, *rest[:1]
+ return out
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider.FLEX,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+ supports_cp=False,
+)
+def _native_flex_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ kernel_options: Optional[Dict[str, Any]] = None,
+) -> torch.Tensor:
+ # TODO: should we LRU cache the block mask creation?
+ score_mod = None
+ block_mask = None
+ batch_size, num_heads, seq_len_q, _ = query.shape
+ _, _, seq_len_kv, _ = key.shape
+
+ if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
+ block_mask = attn_mask
+ elif is_causal:
+ block_mask = flex_attention.create_block_mask(
+ _flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, query.device
+ )
+ elif torch.is_tensor(attn_mask):
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+
+ attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
+
+ if attn_mask.dtype == torch.bool:
+ # TODO: this probably does not work but verify!
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
+ block_mask = flex_attention.create_block_mask(
+ mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
+ )
+ else:
+
+ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+ return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+ else:
+ raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
+
+ return flex_attention.flex_attention(
+ query=query,
+ key=key,
+ value=value,
+ score_mod=score_mod,
+ block_mask=block_mask,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ return_lse=return_lse,
+ kernel_options=kernel_options,
+ )
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider.NATIVE,
+ constraints=[_check_device, _check_shape],
+ supports_cp=False,
+)
+def _native_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ return native_sdpa(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+
+
+class _native_cudnn_attention(torch.autograd.Function):
+ # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
+ # forward declaration:
+ # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+ # backward declaration:
+ # aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ ):
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.attn_mask = attn_mask
+
+ out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+ torch.ops.aten._scaled_dot_product_cudnn_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_mask,
+ compute_log_sumexp=True,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ return_debug_mask=False,
+ scale=scale,
+ )
+ )
+
+ ctx.max_q = max_q
+ ctx.max_k = max_k
+ ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args: torch.Tensor,
+ ):
+ query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+ grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
+ grad_out=grad_out,
+ query=query,
+ key=key,
+ value=value,
+ out=out,
+ logsumexp=lse,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ attn_bias=ctx.attn_mask,
+ cum_seq_q=cum_seq_q,
+ cum_seq_k=cum_seq_k,
+ max_q=ctx.max_q,
+ max_k=ctx.max_k,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None
+
+
+class _native_ring_native_cudnn_attention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ ):
+ _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.attn_mask = attn_mask
+
+ out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+ _templated_ring_attention(
+ mesh=_AttentionProviderRegistry._mesh,
+ seq_dim=2,
+ op=torch.ops.aten._scaled_dot_product_cudnn_attention,
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_mask,
+ compute_log_sumexp=True,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ return_debug_mask=False,
+ scale=scale,
+ )
+ )
+
+ ctx.max_q = max_q
+ ctx.max_k = max_k
+ ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args: torch.Tensor,
+ ):
+ _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
+ query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+ grad_query, grad_key, grad_value = _templated_ring_attention_backward(
+ mesh=_AttentionProviderRegistry._mesh,
+ seq_dim=2,
+ op=torch.ops.aten._scaled_dot_product_cudnn_attention_backward,
+ grad_out=grad_out,
+ grad_out_name="grad_out",
+ query=query,
+ key=key,
+ value=value,
+ out=out,
+ logsumexp=lse,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ attn_bias=ctx.attn_mask,
+ cum_seq_q=cum_seq_q,
+ cum_seq_k=cum_seq_k,
+ max_q=ctx.max_q,
+ max_k=ctx.max_k,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider._NATIVE_CUDNN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_cp=True,
+)
+def native_cudnn_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ dispatch_fn = (
+ _native_ring_native_cudnn_attention
+ if _AttentionProviderRegistry.context_parallel_enabled()
+ else _native_cudnn_attention
+ )
+ return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)
+
+
+class _native_efficient_attention(torch.autograd.Function):
+ # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946
+ # forward declaration:
+ # aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
+ # backward declaration:
+ # aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
+
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ ):
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.attn_mask = attn_mask
+
+ # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
+ out, lse, philox_seed, philox_offset = _finetrainers_scaled_dot_product_efficient_attention_forward(
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_mask,
+ compute_log_sumexp=True,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ )
+
+ ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args: torch.Tensor,
+ ):
+ query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors
+
+ # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
+ grad_query, grad_key, grad_value, grad_attn_bias = (
+ _finetrainers_scaled_dot_product_efficient_attention_backward(
+ grad_out_=grad_out,
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=ctx.attn_mask,
+ out=out,
+ logsumexp=lse,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ dropout_p=ctx.dropout_p,
+ grad_input_mask=[True, True, True, False],
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ )
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None
+
+
+class _native_ring_native_efficient_attention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ ):
+ _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.attn_mask = attn_mask
+
+ # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
+ out, lse, philox_seed, philox_offset = _templated_ring_attention(
+ mesh=_AttentionProviderRegistry._mesh,
+ seq_dim=2,
+ op=_finetrainers_scaled_dot_product_efficient_attention_forward,
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_mask,
+ compute_log_sumexp=True,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ )
+
+ ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args: torch.Tensor,
+ ):
+ _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
+ query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors
+
+ # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
+ grad_query, grad_key, grad_value, grad_attn_bias = _templated_ring_attention_backward(
+ mesh=_AttentionProviderRegistry._mesh,
+ seq_dim=2,
+ op=_finetrainers_scaled_dot_product_efficient_attention_backward,
+ grad_out=grad_out,
+ grad_out_name="grad_out_",
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=ctx.attn_mask,
+ out=out,
+ logsumexp=lse,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ dropout_p=ctx.dropout_p,
+ grad_input_mask=[True, True, True, False],
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider._NATIVE_EFFICIENT,
+ constraints=[_check_device, _check_shape],
+ supports_cp=True,
+)
+def native_efficient_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+) -> torch.Tensor:
+ dispatch_fn = (
+ _native_ring_native_efficient_attention
+ if _AttentionProviderRegistry.context_parallel_enabled()
+ else _native_efficient_attention
+ )
+ return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale)
+
+
+class _native_flash_attention(torch.autograd.Function):
+ # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14910
+ # forward declaration:
+ # aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+ # backward declaration:
+ # aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ ):
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+
+ out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+ torch.ops.aten._scaled_dot_product_flash_attention(
+ query=query,
+ key=key,
+ value=value,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ return_debug_mask=False,
+ scale=scale,
+ )
+ )
+
+ ctx.max_q = max_q
+ ctx.max_k = max_k
+ ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args: torch.Tensor,
+ ):
+ query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+ grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward(
+ grad_out=grad_out,
+ query=query,
+ key=key,
+ value=value,
+ out=out,
+ logsumexp=lse,
+ cum_seq_q=cum_seq_q,
+ cum_seq_k=cum_seq_k,
+ max_q=ctx.max_q,
+ max_k=ctx.max_k,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ scale=ctx.scale,
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None
+
+
+class _native_ring_native_flash_attention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ ):
+ _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+
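+ # Ring attention: each rank holds a shard of the sequence (seq_dim=2) and exchanges key/value
+ # blocks around the context-parallel mesh, merging partial outputs via the logsumexp.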
+ out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+ _templated_ring_attention(
+ mesh=_AttentionProviderRegistry._mesh,
+ seq_dim=2,
+ op=torch.ops.aten._scaled_dot_product_flash_attention,
+ query=query,
+ key=key,
+ value=value,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ )
+ )
+
+ ctx.max_q = max_q
+ ctx.max_k = max_k
+ ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args: torch.Tensor,
+ ):
+ _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
+ query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+ grad_query, grad_key, grad_value, *_ = _templated_ring_attention_backward(
+ mesh=_AttentionProviderRegistry._mesh,
+ seq_dim=2,
+ op=torch.ops.aten._scaled_dot_product_flash_attention_backward,
+ grad_out=grad_out,
+ grad_out_name="grad_out",
+ query=query,
+ key=key,
+ value=value,
+ out=out,
+ logsumexp=lse,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ cum_seq_q=cum_seq_q,
+ cum_seq_k=cum_seq_k,
+ max_q=ctx.max_q,
+ max_k=ctx.max_k,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider._NATIVE_FLASH,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_cp=True,
+)
+def native_flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ dispatch_fn = (
+ _native_ring_native_flash_attention
+ if _AttentionProviderRegistry.context_parallel_enabled()
+ else _native_flash_attention
+ )
+ return dispatch_fn.apply(query, key, value, dropout_p, is_causal, scale, return_lse)
+
+
+# class _native_math_attention(torch.autograd.Function):
+# # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14901
+# # forward declaration:
+# # aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0., bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
+# # backward declaration:
+# # does not exist
+# @staticmethod
+# def forward(
+# ctx: torch.autograd.function.FunctionCtx,
+# query: torch.Tensor,
+# key: torch.Tensor,
+# value: torch.Tensor,
+# attn_mask: Optional[torch.Tensor] = None,
+# dropout_p: float = 0.0,
+# is_causal: bool = False,
+# dropout_mask: Optional[torch.Tensor] = None,
+# scale: Optional[float] = None,
+# enable_gqa: bool = False,
+# return_scores: bool = False,
+# ):
+# ctx.dropout_p = dropout_p
+# ctx.is_causal = is_causal
+# ctx.scale = scale
+# ctx.enable_gqa = enable_gqa
+
+# print(f"query.shape: {query.shape}")
+# with torch.enable_grad():
+# out, scores = torch.ops.aten._scaled_dot_product_attention_math(
+# query=query,
+# key=key,
+# value=value,
+# attn_mask=attn_mask,
+# dropout_p=dropout_p,
+# is_causal=is_causal,
+# dropout_mask=dropout_mask,
+# scale=scale,
+# enable_gqa=enable_gqa,
+# )
+
+# ctx.save_for_backward(query, key, value, out)
+
+# return (out, scores) if return_scores else out
+
+# @staticmethod
+# def backward(
+# ctx: torch.autograd.function.FunctionCtx,
+# grad_out: torch.Tensor,
+# ):
+# raise NotImplementedError("Backward pass for native math attention is not implemented.")
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider._NATIVE_MATH,
+ constraints=[_check_device, _check_shape],
+ supports_cp=False,
+)
+def native_math_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+ return native_sdpa(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider.SAGE,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_cp=False,
+)
+def _sage_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+) -> torch.Tensor:
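+ # When context parallelism is enabled, the logsumexp is needed to merge partial attention results
+ # across ranks, so it is always requested from the kernel.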
+ if _AttentionProviderRegistry.context_parallel_enabled():
+ return_lse = True
+
+ kwargs = {
+ "q": query,
+ "k": key,
+ "v": value,
+ "tensor_layout": "HND",
+ "is_causal": is_causal,
+ "sm_scale": scale,
+ "return_lse": return_lse,
+ }
+ out = sageattn(**kwargs)
+
+ if return_lse:
+ out, lse, *_ = out
+ return out, lse
+ return out
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider.SAGE_VARLEN,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ smooth_k: bool = True,
+ attn_mask: Optional[torch.Tensor] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ batch_size, _, seq_len_q, _ = query.shape
+ _, _, seq_len_kv, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if enable_gqa:
+ # TODO
+ pass
+
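+ # Derive the varlen metadata (cumulative and per-sample key lengths) from the attention mask when
+ # the caller did not provide it; otherwise assume densely packed sequences of length max_seqlen_k.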
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
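+ # sageattn_varlen expects packed [total_tokens, heads, head_dim] inputs, so switch to a
+ # [B, S, H, D] layout and keep only the valid (unpadded) key/value tokens per sample.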
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = sageattn_varlen(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ is_causal=is_causal,
+ sm_scale=scale,
+ smooth_k=smooth_k,
+ )
+ out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # .contiguous()
+
+ return out
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+ supports_cp=False,
+)
+def _sage_qk_int8_pv_fp8_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ smooth_v: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ smooth_v=smooth_v,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+ supports_cp=False,
+)
+def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+ supports_cp=False,
+)
+def _sage_qk_int8_pv_fp16_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ smooth_v: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ smooth_v=smooth_v,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+ supports_cp=False,
+)
+def _sage_qk_int8_pv_fp16_triton_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
+ smooth_k: bool = True,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_triton(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="HND",
+ quantization_backend=quantization_backend,
+ is_causal=is_causal,
+ sm_scale=scale,
+ smooth_k=smooth_k,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionProviderRegistry.register(
+ AttentionProvider.XFORMERS,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _xformers_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ batch_size, num_heads_q, seq_len_q, _ = query.shape
+ _, num_heads_kv, seq_len_kv, _ = key.shape
+
+ # TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns
+ if is_causal:
+ attn_mask = xops.LowerTriangularMask()
+ elif attn_mask is not None:
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+ elif attn_mask.ndim != 4:
+ raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
+ attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+
+ # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers
+ # query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
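+ # xformers handles grouped-query attention with a 5D [batch, seq, kv_heads, group, head_dim]
+ # layout, so the key/value heads are broadcast across the query groups here.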
+ if enable_gqa:
+ if num_heads_q % num_heads_kv != 0:
+ raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
+ num_heads_per_group = num_heads_q // num_heads_kv
+ query = query.unflatten(2, (num_heads_kv, -1))
+ key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+ value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+
+ out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
+ if enable_gqa:
+ out = out.flatten(2, 3)
+
+ out = out.permute(0, 2, 1, 3) # .contiguous()
+ return out
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f9a84073541b0e764877bac0335637f03d32ca
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/__init__.py
@@ -0,0 +1 @@
+from .base_specification import CogVideoXModelSpecification
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c0e6210c47f448cfda570f3f07d324f7f980a71
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/base_specification.py
@@ -0,0 +1,410 @@
+import functools
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+ AutoencoderKLCogVideoX,
+ CogVideoXDDIMScheduler,
+ CogVideoXImageToVideoPipeline,
+ CogVideoXPipeline,
+ CogVideoXTransformer3DModel,
+)
+from PIL.Image import Image
+from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
+
+from finetrainers.data import VideoArtifact
+from finetrainers.logging import get_logger
+from finetrainers.models.modeling_utils import ModelSpecification
+from finetrainers.models.utils import DiagonalGaussianDistribution
+from finetrainers.processors import ProcessorMixin, T5Processor
+from finetrainers.typing import ArtifactType, SchedulerType
+from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function
+
+from .utils import prepare_rotary_positional_embeddings
+
+
+logger = get_logger()
+
+
+class CogVideoXLatentEncodeProcessor(ProcessorMixin):
+ r"""
+ Processor to encode image/video into latents using the CogVideoX VAE.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor returns. The outputs are in the following order:
+ - latents: The latents of the input image/video.
+ """
+
+ def __init__(self, output_names: List[str]):
+ super().__init__()
+ self.output_names = output_names
+ assert len(self.output_names) == 1
+
+ def forward(
+ self,
+ vae: AutoencoderKLCogVideoX,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ device = vae.device
+ dtype = vae.dtype
+
+ if image is not None:
+ video = image.unsqueeze(1)
+
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+ video = video.to(device=device, dtype=vae.dtype)
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
+
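+ # When compute_posterior is False, the raw encoder moments are stored instead of a sample so the
+ # posterior can be re-sampled later during training.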
+ if compute_posterior:
+ latents = vae.encode(video).latent_dist.sample(generator=generator)
+ latents = latents.to(dtype=dtype)
+ else:
+ if vae.use_slicing and video.shape[0] > 1:
+ encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+ moments = torch.cat(encoded_slices)
+ else:
+ moments = vae._encode(video)
+ latents = moments.to(dtype=dtype)
+
+ latents = latents.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] -> [B, F, C, H, W]
+ return {self.output_names[0]: latents}
+
+
+class CogVideoXModelSpecification(ModelSpecification):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b",
+ tokenizer_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ tokenizer_id=tokenizer_id,
+ text_encoder_id=text_encoder_id,
+ transformer_id=transformer_id,
+ vae_id=vae_id,
+ text_encoder_dtype=text_encoder_dtype,
+ transformer_dtype=transformer_dtype,
+ vae_dtype=vae_dtype,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ if condition_model_processors is None:
+ condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
+ if latent_model_processors is None:
+ latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])]
+
+ self.condition_model_processors = condition_model_processors
+ self.latent_model_processors = latent_model_processors
+
+ @property
+ def _resolution_dim_keys(self):
+ return {"latents": (1, 3, 4)}
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.tokenizer_id is not None:
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+ else:
+ tokenizer = T5Tokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+ )
+
+ if self.text_encoder_id is not None:
+ text_encoder = AutoModel.from_pretrained(
+ self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+ )
+ else:
+ text_encoder = T5EncoderModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ torch_dtype=self.text_encoder_dtype,
+ **common_kwargs,
+ )
+
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.vae_id is not None:
+ vae = AutoencoderKLCogVideoX.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+ else:
+ vae = AutoencoderKLCogVideoX.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+ )
+
+ return {"vae": vae}
+
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.transformer_id is not None:
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
+ self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+ )
+ else:
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=self.transformer_dtype,
+ **common_kwargs,
+ )
+
+ scheduler = CogVideoXDDIMScheduler.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="scheduler", **common_kwargs
+ )
+
+ return {"transformer": transformer, "scheduler": scheduler}
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[T5Tokenizer] = None,
+ text_encoder: Optional[T5EncoderModel] = None,
+ transformer: Optional[CogVideoXTransformer3DModel] = None,
+ vae: Optional[AutoencoderKLCogVideoX] = None,
+ scheduler: Optional[CogVideoXDDIMScheduler] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> CogVideoXPipeline:
+ components = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ }
+ components = get_non_null_items(components)
+
+ pipe = CogVideoXPipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ pipe.text_encoder.to(self.text_encoder_dtype)
+ pipe.vae.to(self.vae_dtype)
+
+ _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+ if not training:
+ pipe.transformer.to(self.transformer_dtype)
+ if enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+ return pipe
+
+ @torch.no_grad()
+ def prepare_conditions(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ caption: str,
+ max_sequence_length: int = 226,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ conditions = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "caption": caption,
+ "max_sequence_length": max_sequence_length,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_conditions(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ conditions.pop("prompt_attention_mask", None)
+ return conditions
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ vae: AutoencoderKLCogVideoX,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ conditions = {
+ "vae": vae,
+ "image": image,
+ "video": video,
+ "generator": generator,
+ "compute_posterior": compute_posterior,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_latents(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ def forward(
+ self,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: CogVideoXDDIMScheduler,
+ condition_model_conditions: Dict[str, torch.Tensor],
+ latent_model_conditions: Dict[str, torch.Tensor],
+ sigmas: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
+ VAE_SPATIAL_SCALE_FACTOR = 8
+ rope_base_height = self.transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
+ rope_base_width = self.transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR
+ patch_size = self.transformer_config.patch_size
+ patch_size_t = getattr(self.transformer_config, "patch_size_t", None)
+
+ if compute_posterior:
+ latents = latent_model_conditions.pop("latents")
+ else:
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"), _dim=2)
+ latents = posterior.sample(generator=generator)
+ del posterior
+
+ if not getattr(self.vae_config, "invert_scale_latents", False):
+ latents = latents * self.vae_config.scaling_factor
+
+ if patch_size_t is not None:
+ latents = self._pad_frames(latents, patch_size_t)
+
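+ # Continuous sigmas in [0, 1) are mapped onto the scheduler's integer timestep grid by scaling by 1000.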
+ timesteps = (sigmas.flatten() * 1000.0).long()
+
+ noise = torch.zeros_like(latents).normal_(generator=generator)
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
+
+ batch_size, num_frames, num_channels, height, width = latents.shape
+ ofs_emb = (
+ None
+ if getattr(self.transformer_config, "ofs_embed_dim", None) is None
+ else latents.new_full((batch_size,), fill_value=2.0)
+ )
+
+ image_rotary_emb = (
+ prepare_rotary_positional_embeddings(
+ height=height * VAE_SPATIAL_SCALE_FACTOR,
+ width=width * VAE_SPATIAL_SCALE_FACTOR,
+ num_frames=num_frames,
+ vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ attention_head_dim=self.transformer_config.attention_head_dim,
+ device=transformer.device,
+ base_height=rope_base_height,
+ base_width=rope_base_width,
+ )
+ if self.transformer_config.use_rotary_positional_embeddings
+ else None
+ )
+
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+ latent_model_conditions["image_rotary_emb"] = image_rotary_emb
+ latent_model_conditions["ofs"] = ofs_emb
+
+ velocity = transformer(
+ **latent_model_conditions,
+ **condition_model_conditions,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+ # For CogVideoX, the transformer predicts the velocity. The denoised output is recovered by reusing the
+ # same code path as scheduler.get_velocity(), which can be confusing at first glance.
+ pred = scheduler.get_velocity(velocity, noisy_latents, timesteps)
+ target = latents
+
+ return pred, target, sigmas
+
+ def validation(
+ self,
+ pipeline: CogVideoXPipeline,
+ prompt: str,
+ image: Optional[Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ **kwargs,
+ ) -> List[ArtifactType]:
+ # TODO(aryan): add support for more parameters
+ if image is not None:
+ pipeline = CogVideoXImageToVideoPipeline.from_pipe(pipeline)
+
+ generation_kwargs = {
+ "prompt": prompt,
+ "image": image,
+ "height": height,
+ "width": width,
+ "num_frames": num_frames,
+ "num_inference_steps": num_inference_steps,
+ "generator": generator,
+ "return_dict": True,
+ "output_type": "pil",
+ }
+ generation_kwargs = get_non_null_items(generation_kwargs)
+ video = pipeline(**generation_kwargs).frames[0]
+ return [VideoArtifact(value=video)]
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ CogVideoXPipeline.save_lora_weights(
+ directory,
+ transformer_state_dict,
+ save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+ safe_serialization=True,
+ )
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: CogVideoXTransformer3DModel,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ with init_empty_weights():
+ transformer_copy = CogVideoXTransformer3DModel.from_config(transformer.config)
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ @staticmethod
+ def _pad_frames(latents: torch.Tensor, patch_size_t: int) -> torch.Tensor:
+ num_frames = latents.size(1)
+ additional_frames = (-num_frames) % patch_size_t
+ if additional_frames > 0:
+ last_frame = latents[:, -1:]
+ padding_frames = last_frame.expand(-1, additional_frames, -1, -1, -1)
+ latents = torch.cat([latents, padding_frames], dim=1)
+ return latents
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/utils.py b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd98c1f3653dbe23a6f53fa54dfe3e7073ea9b99
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/cogvideox/utils.py
@@ -0,0 +1,51 @@
+from typing import Optional, Tuple
+
+import torch
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
+
+
+def prepare_rotary_positional_embeddings(
+ height: int,
+ width: int,
+ num_frames: int,
+ vae_scale_factor_spatial: int = 8,
+ patch_size: int = 2,
+ patch_size_t: Optional[int] = None,
+ attention_head_dim: int = 64,
+ device: Optional[torch.device] = None,
+ base_height: int = 480,
+ base_width: int = 720,
+) -> Tuple[torch.Tensor, torch.Tensor]:
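+ # Grid sizes are measured in transformer patches: spatial dimensions are divided by the VAE
+ # downscale factor and then by the patch size before computing the 3D RoPE frequencies.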
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
+
+ if patch_size_t is None:
+ # CogVideoX 1.0
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+ else:
+ # CogVideoX 1.5
+ base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
+
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=attention_head_dim,
+ crops_coords=None,
+ grid_size=(grid_height, grid_width),
+ temporal_size=base_num_frames,
+ grid_type="slice",
+ max_size=(base_size_height, base_size_width),
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogview4/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b6afde3a5f16247cc6f47fc16561186e31a22ad
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/__init__.py
@@ -0,0 +1,2 @@
+from .base_specification import CogView4ModelSpecification
+from .control_specification import CogView4ControlModelSpecification
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogview4/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..f89eb21d878f475f361d6def12591c5037b248cc
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/base_specification.py
@@ -0,0 +1,385 @@
+import functools
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+ AutoencoderKL,
+ CogView4Pipeline,
+ CogView4Transformer2DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from transformers import AutoTokenizer, GlmModel
+
+import finetrainers.functional as FF
+from finetrainers.data import ImageArtifact
+from finetrainers.logging import get_logger
+from finetrainers.models.modeling_utils import ModelSpecification
+from finetrainers.processors import CogView4GLMProcessor, ProcessorMixin
+from finetrainers.typing import ArtifactType, SchedulerType
+from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function
+
+
+logger = get_logger()
+
+
+class CogView4LatentEncodeProcessor(ProcessorMixin):
+ r"""
+ Processor to encode image/video into latents using the CogView4 VAE.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor returns. The outputs are in the following order:
+ - latents: The latents of the input image/video.
+ - original_size: The original size of the input image/video.
+ - target_size: The target size of the input image/video.
+ - crop_coords: The top-left crop coordinates of the input image/video.
+ """
+
+ def __init__(self, output_names: List[str]):
+ super().__init__()
+
+ self.output_names = output_names
+ assert len(self.output_names) == 4
+
+ def forward(
+ self,
+ vae: AutoencoderKL,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ _original_height: Optional[int] = None,
+ _original_width: Optional[int] = None,
+ ) -> Dict[str, torch.Tensor]:
+ device = vae.device
+ dtype = vae.dtype
+
+ if video is not None:
+ # TODO(aryan): perhaps better would be to flatten(0, 1), but need to account for reshaping sigmas accordingly
+ image = video[:, 0] # [B, F, C, H, W] -> [B, C, H, W]
+
+ assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
+ image = image.to(device=device, dtype=vae.dtype)
+
+ if compute_posterior:
+ latents = vae.encode(image).latent_dist.sample(generator=generator)
+ latents = latents.to(dtype=dtype)
+ else:
+ if vae.use_slicing and image.shape[0] > 1:
+ encoded_slices = [vae._encode(x_slice) for x_slice in image.split(1)]
+ moments = torch.cat(encoded_slices)
+ else:
+ moments = vae._encode(image)
+ latents = moments.to(dtype=dtype)
+
+ batch_size = latents.size(0)
+ target_height = image.size(2)
+ target_width = image.size(3)
+ original_size = torch.tensor([(_original_height, _original_width)], device=device, dtype=dtype).repeat(
+ batch_size, 1
+ )
+ target_size = torch.tensor([(target_height, target_width)], device=device, dtype=dtype).repeat(batch_size, 1)
+ crop_coords = torch.tensor([(0, 0)], device=device, dtype=dtype).repeat(batch_size, 1)
+
+ return {
+ self.output_names[0]: latents,
+ self.output_names[1]: original_size,
+ self.output_names[2]: target_size,
+ self.output_names[3]: crop_coords,
+ }
+
+
+class CogView4ModelSpecification(ModelSpecification):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str = "THUDM/CogView4-6B",
+ tokenizer_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ tokenizer_id=tokenizer_id,
+ text_encoder_id=text_encoder_id,
+ transformer_id=transformer_id,
+ vae_id=vae_id,
+ text_encoder_dtype=text_encoder_dtype,
+ transformer_dtype=transformer_dtype,
+ vae_dtype=vae_dtype,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ if condition_model_processors is None:
+ condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])]
+ if latent_model_processors is None:
+ latent_model_processors = [
+ CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"])
+ ]
+
+ self.condition_model_processors = condition_model_processors
+ self.latent_model_processors = latent_model_processors
+
+ @property
+ def _resolution_dim_keys(self):
+ return {"latents": (2, 3)}
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.tokenizer_id is not None:
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+ )
+
+ if self.text_encoder_id is not None:
+ text_encoder = GlmModel.from_pretrained(
+ self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+ )
+ else:
+ text_encoder = GlmModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ torch_dtype=self.text_encoder_dtype,
+ **common_kwargs,
+ )
+
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.vae_id is not None:
+ vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+ else:
+ vae = AutoencoderKL.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+ )
+
+ return {"vae": vae}
+
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.transformer_id is not None:
+ transformer = CogView4Transformer2DModel.from_pretrained(
+ self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+ )
+ else:
+ transformer = CogView4Transformer2DModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=self.transformer_dtype,
+ **common_kwargs,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {"transformer": transformer, "scheduler": scheduler}
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[AutoTokenizer] = None,
+ text_encoder: Optional[GlmModel] = None,
+ transformer: Optional[CogView4Transformer2DModel] = None,
+ vae: Optional[AutoencoderKL] = None,
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> CogView4Pipeline:
+ components = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "transformer": transformer,
+ "vae": vae,
+ # Load the scheduler based on CogView4's config instead of using the default initialization being used for training
+ # "scheduler": scheduler,
+ }
+ components = get_non_null_items(components)
+
+ pipe = CogView4Pipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ pipe.text_encoder.to(self.text_encoder_dtype)
+ pipe.vae.to(self.vae_dtype)
+
+ _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+ if not training:
+ pipe.transformer.to(self.transformer_dtype)
+ if enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+ return pipe
+
+ @torch.no_grad()
+ def prepare_conditions(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: GlmModel,
+ caption: str,
+ max_sequence_length: int = 1024,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ conditions = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "caption": caption,
+ "max_sequence_length": max_sequence_length,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_conditions(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ vae: AutoencoderKL,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ _original_height: Optional[int] = None,
+ _original_width: Optional[int] = None,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ conditions = {
+ "vae": vae,
+ "image": image,
+ "video": video,
+ "generator": generator,
+ "compute_posterior": compute_posterior,
+ "_original_height": _original_height,
+ "_original_width": _original_width,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_latents(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ def forward(
+ self,
+ transformer: CogView4Transformer2DModel,
+ condition_model_conditions: Dict[str, torch.Tensor],
+ latent_model_conditions: Dict[str, torch.Tensor],
+ sigmas: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ base_image_sequence_length = 256
+ base_shift = 0.25
+ max_shift = 0.75
+
+ if compute_posterior:
+ latents = latent_model_conditions.pop("latents")
+ else:
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+ latents = posterior.sample(generator=generator)
+ del posterior
+
+ if getattr(self.vae_config, "shift_factor", None) is not None:
+ latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
+ else:
+ latents = latents * self.vae_config.scaling_factor
+
+ noise = torch.zeros_like(latents).normal_(generator=generator)
+ timesteps = (sigmas.flatten() * 1000.0).long()
+
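+ # Resolution-dependent timestep shift: mu grows with the image token sequence length, and the
+ # flow-matching sigmas are shifted accordingly before constructing the noisy latents.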
+ image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2
+ mu = (image_sequence_length / base_image_sequence_length) ** 0.5
+ mu = mu * max_shift + base_shift
+ shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0)
+ noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas)
+
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+
+ pred = transformer(
+ **latent_model_conditions,
+ **condition_model_conditions,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+ target = FF.flow_match_target(noise, latents)
+
+ # NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation
+ # but let's keep it this way for now. Longer training runs should reveal more insights.
+ # return pred, target, sigmas
+ return pred, target, shifted_sigmas
+
+ def validation(
+ self,
+ pipeline: CogView4Pipeline,
+ prompt: str,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ **kwargs,
+ ) -> List[ArtifactType]:
+ generation_kwargs = {
+ "prompt": prompt,
+ "height": height,
+ "width": width,
+ "num_inference_steps": num_inference_steps,
+ "generator": generator,
+ "return_dict": True,
+ "output_type": "pil",
+ }
+ generation_kwargs = get_non_null_items(generation_kwargs)
+ image = pipeline(**generation_kwargs).images[0]
+ return [ImageArtifact(value=image)]
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ CogView4Pipeline.save_lora_weights(
+ directory,
+ transformer_state_dict,
+ save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+ safe_serialization=True,
+ )
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: CogView4Transformer2DModel,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ with init_empty_weights():
+ transformer_copy = CogView4Transformer2DModel.from_config(transformer.config)
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/cogview4/control_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/control_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f359fa4a59f6b218aebfa217c5a22bfd6afdb2
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/cogview4/control_specification.py
@@ -0,0 +1,375 @@
+import functools
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from transformers import AutoTokenizer, GlmModel
+
+import finetrainers.functional as FF
+from finetrainers.data import ImageArtifact
+from finetrainers.models.modeling_utils import ControlModelSpecification
+from finetrainers.models.utils import DiagonalGaussianDistribution, _expand_linear_with_zeroed_weights
+from finetrainers.patches.dependencies.diffusers.control import control_channel_concat
+from finetrainers.processors import CogView4GLMProcessor, ProcessorMixin
+from finetrainers.typing import ArtifactType, SchedulerType
+from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function
+
+from .base_specification import CogView4LatentEncodeProcessor
+
+
+class CogView4ControlModelSpecification(ControlModelSpecification):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str = "THUDM/CogView4-6B",
+ tokenizer_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ control_model_processors: List[ProcessorMixin] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ tokenizer_id=tokenizer_id,
+ text_encoder_id=text_encoder_id,
+ transformer_id=transformer_id,
+ vae_id=vae_id,
+ text_encoder_dtype=text_encoder_dtype,
+ transformer_dtype=transformer_dtype,
+ vae_dtype=vae_dtype,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ if condition_model_processors is None:
+ condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])]
+ if latent_model_processors is None:
+ latent_model_processors = [
+ CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"])
+ ]
+ if control_model_processors is None:
+ control_model_processors = [
+ CogView4LatentEncodeProcessor(["control_latents", "original_size", "target_size", "crop_coords"])
+ ]
+
+ self.condition_model_processors = condition_model_processors
+ self.latent_model_processors = latent_model_processors
+ self.control_model_processors = control_model_processors
+
+ @property
+ def control_injection_layer_name(self):
+ return "patch_embed.proj"
+
+ @property
+ def _resolution_dim_keys(self):
+ return {"latents": (2, 3)}
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.tokenizer_id is not None:
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+ )
+
+ if self.text_encoder_id is not None:
+ text_encoder = GlmModel.from_pretrained(
+ self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+ )
+ else:
+ text_encoder = GlmModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ torch_dtype=self.text_encoder_dtype,
+ **common_kwargs,
+ )
+
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.vae_id is not None:
+ vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+ else:
+ vae = AutoencoderKL.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+ )
+
+ return {"vae": vae}
+
+ def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.transformer_id is not None:
+ transformer = CogView4Transformer2DModel.from_pretrained(
+ self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+ )
+ else:
+ transformer = CogView4Transformer2DModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=self.transformer_dtype,
+ **common_kwargs,
+ )
+
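+ # The patch embedding consumes patch_size**2 values per latent channel, so the expanded linear
+ # layer needs new_in_features * patch_size**2 inputs; the added weights are zero-initialized so
+ # the pretrained mapping is unchanged until the control channels are trained.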
+ actual_new_in_features = new_in_features * transformer.config.patch_size**2
+ transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
+ transformer.patch_embed.proj, new_in_features=actual_new_in_features
+ )
+ transformer.register_to_config(in_channels=new_in_features)
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {"transformer": transformer, "scheduler": scheduler}
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[AutoTokenizer] = None,
+ text_encoder: Optional[GlmModel] = None,
+ transformer: Optional[CogView4Transformer2DModel] = None,
+ vae: Optional[AutoencoderKL] = None,
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> CogView4Pipeline:
+ components = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "transformer": transformer,
+ "vae": vae,
+ # Load the scheduler based on CogView4's config instead of using the default initialization being used for training
+ # "scheduler": scheduler,
+ }
+ components = get_non_null_items(components)
+
+ pipe = CogView4Pipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ pipe.text_encoder.to(self.text_encoder_dtype)
+ pipe.vae.to(self.vae_dtype)
+
+ _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+ if not training:
+ pipe.transformer.to(self.transformer_dtype)
+ if enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+
+ return pipe
+
+ @torch.no_grad()
+ def prepare_conditions(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: GlmModel,
+ caption: str,
+ max_sequence_length: int = 1024,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ conditions = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "caption": caption,
+ "max_sequence_length": max_sequence_length,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_conditions(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ vae: AutoencoderKL,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ control_image: Optional[torch.Tensor] = None,
+ control_video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ _original_height: Optional[int] = None,
+ _original_width: Optional[int] = None,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ common_kwargs = {
+ "vae": vae,
+ "generator": generator,
+ "compute_posterior": compute_posterior,
+ "_original_height": _original_height,
+ "_original_width": _original_width,
+ **kwargs,
+ }
+ conditions = {"image": image, "video": video, **common_kwargs}
+ input_keys = set(conditions.keys())
+ conditions = ControlModelSpecification.prepare_latents(self, self.latent_model_processors, **conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+
+ control_conditions = {"image": control_image, "video": control_video, **common_kwargs}
+ input_keys = set(control_conditions.keys())
+ control_conditions = ControlModelSpecification.prepare_latents(
+ self, self.control_model_processors, **control_conditions
+ )
+ control_conditions = {k: v for k, v in control_conditions.items() if k not in input_keys}
+
+ return {**control_conditions, **conditions}
+
+ def forward(
+ self,
+ transformer: CogView4Transformer2DModel,
+ condition_model_conditions: Dict[str, torch.Tensor],
+ latent_model_conditions: Dict[str, torch.Tensor],
+ sigmas: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ base_image_sequence_length = 256
+ base_shift = 0.25
+ max_shift = 0.75
+
+ if compute_posterior:
+ latents = latent_model_conditions.pop("latents")
+ control_latents = latent_model_conditions.pop("control_latents")
+ else:
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+ latents = posterior.sample(generator=generator)
+ del posterior
+
+ control_posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("control_latents"))
+ control_latents = control_posterior.sample(generator=generator)
+ del control_posterior
+
+ if getattr(self.vae_config, "shift_factor", None) is not None:
+ latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
+ control_latents = (control_latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
+
+ noise = torch.zeros_like(latents).normal_(generator=generator)
+ timesteps = (sigmas.flatten() * 1000.0).long()
+
+ image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2
+ mu = (image_sequence_length / base_image_sequence_length) ** 0.5
+ mu = mu * max_shift + base_shift
+ shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0)
+ noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas)
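+ # Channel-wise control conditioning: the clean control latents are concatenated to the noisy
+ # latents, matching the zero-initialized expansion of patch_embed.proj in load_diffusion_models.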
+ noisy_latents = torch.cat([noisy_latents, control_latents], dim=1)
+
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+
+ pred = transformer(
+ **latent_model_conditions,
+ **condition_model_conditions,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+ target = FF.flow_match_target(noise, latents)
+
+ # NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation
+ # but let's keep it this way for now. Longer training runs should reveal more insights.
+ # return pred, target, sigmas
+ return pred, target, shifted_sigmas
+
+ def validation(
+ self,
+ pipeline: CogView4Pipeline,
+ prompt: str,
+ control_image: torch.Tensor,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ **kwargs,
+ ) -> List[ArtifactType]:
+ with torch.no_grad():
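+ # Encode the control image with the VAE and inject its latents at inference time by concatenating
+ # them along the channel dim of hidden_states via the control_channel_concat context manager.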
+ dtype = pipeline.vae.dtype
+ device = pipeline._execution_device
+ in_channels = self.transformer_config.in_channels # We need to use the original in_channels
+ latents = pipeline.prepare_latents(1, in_channels, height, width, dtype, device, generator)
+ control_image = pipeline.image_processor.preprocess(control_image, height=height, width=width)
+ control_image = control_image.to(device=device, dtype=dtype)
+ control_latents = pipeline.vae.encode(control_image).latent_dist.sample(generator=generator)
+ if getattr(self.vae_config, "shift_factor", None) is not None:
+ control_latents = (control_latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
+
+ generation_kwargs = {
+ "latents": latents,
+ "prompt": prompt,
+ "height": height,
+ "width": width,
+ "num_inference_steps": num_inference_steps,
+ "generator": generator,
+ "return_dict": True,
+ "output_type": "pil",
+ }
+ generation_kwargs = get_non_null_items(generation_kwargs)
+
+ with control_channel_concat(pipeline.transformer, ["hidden_states"], [control_latents], dims=[1]):
+ image = pipeline(**generation_kwargs).images[0]
+
+ return [ImageArtifact(value=image)]
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ norm_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ CogView4Pipeline.save_lora_weights(
+ directory,
+ transformer_state_dict,
+ save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+ safe_serialization=True,
+ )
+ if norm_state_dict is not None:
+ safetensors.torch.save_file(norm_state_dict, os.path.join(directory, "norm_state_dict.safetensors"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: CogView4Transformer2DModel,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ with init_empty_weights():
+ transformer_copy = CogView4Transformer2DModel.from_config(transformer.config)
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ @property
+ def _original_control_layer_in_features(self):
+ return self.transformer_config.in_channels
+
+ @property
+ def _original_control_layer_out_features(self):
+ return self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim
+
+ @property
+ def _qk_norm_identifiers(self):
+ return ["attn1.norm_q", "attn1.norm_k"]
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/flux/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/flux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d1d172114ae4f79a2f89e5196ecbdc8be279e3a
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/flux/__init__.py
@@ -0,0 +1 @@
+from .base_specification import FluxModelSpecification
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/flux/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/flux/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e3ea1e167cfd72e5ee7c158a9a715ee6b7ad09f
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/flux/base_specification.py
@@ -0,0 +1,411 @@
+import functools
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+import finetrainers.functional as FF
+from finetrainers.data import ImageArtifact
+from finetrainers.logging import get_logger
+from finetrainers.models.modeling_utils import ModelSpecification
+from finetrainers.processors import CLIPPooledProcessor, ProcessorMixin, T5Processor
+from finetrainers.typing import ArtifactType, SchedulerType
+from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function
+
+
+logger = get_logger()
+
+
+class FluxLatentEncodeProcessor(ProcessorMixin):
+ r"""
+ Processor to encode image/video into latents using the Flux VAE.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor returns. The outputs are in the following order:
+ - latents: The latents of the input image/video.
+ """
+
+ def __init__(self, output_names: List[str]):
+ super().__init__()
+
+ self.output_names = output_names
+ assert len(self.output_names) == 1
+
+ def forward(
+ self,
+ vae: AutoencoderKL,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ device = vae.device
+ dtype = vae.dtype
+
+ if video is not None:
+ # TODO(aryan): perhaps better would be to flatten(0, 1), but need to account for reshaping sigmas accordingly
+ image = video[:, 0]  # [B, F, C, H, W] -> [B, C, H, W]
+
+ assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
+ image = image.to(device=device, dtype=vae.dtype)
+
+ if compute_posterior:
+ latents = vae.encode(image).latent_dist.sample(generator=generator)
+ latents = latents.to(dtype=dtype)
+ else:
+ if vae.use_slicing and image.shape[0] > 1:
+ encoded_slices = [vae._encode(x_slice) for x_slice in image.split(1)]
+ moments = torch.cat(encoded_slices)
+ else:
+ moments = vae._encode(image)
+ latents = moments.to(dtype=dtype)
+
+ return {self.output_names[0]: latents}
+
+
+class FluxModelSpecification(ModelSpecification):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str = "black-forest-labs/FLUX.1-dev",
+ tokenizer_id: Optional[str] = None,
+ tokenizer_2_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ text_encoder_2_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ tokenizer_id=tokenizer_id,
+ tokenizer_2_id=tokenizer_2_id,
+ text_encoder_id=text_encoder_id,
+ text_encoder_2_id=text_encoder_2_id,
+ transformer_id=transformer_id,
+ vae_id=vae_id,
+ text_encoder_dtype=text_encoder_dtype,
+ transformer_dtype=transformer_dtype,
+ vae_dtype=vae_dtype,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ if condition_model_processors is None:
+ condition_model_processors = [
+ CLIPPooledProcessor(["pooled_projections"]),
+ T5Processor(
+ ["encoder_hidden_states", "prompt_attention_mask"],
+ input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"},
+ ),
+ ]
+ if latent_model_processors is None:
+ latent_model_processors = [FluxLatentEncodeProcessor(["latents"])]
+
+ self.condition_model_processors = condition_model_processors
+ self.latent_model_processors = latent_model_processors
+
+ @property
+ def _resolution_dim_keys(self):
+ return {"latents": (2, 3)}
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.tokenizer_id is not None:
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+ else:
+ tokenizer = CLIPTokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+ )
+
+ if self.tokenizer_2_id is not None:
+ tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs)
+ else:
+ tokenizer_2 = AutoTokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs
+ )
+
+ if self.text_encoder_id is not None:
+ text_encoder = CLIPTextModel.from_pretrained(
+ self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+ )
+ else:
+ text_encoder = CLIPTextModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ torch_dtype=self.text_encoder_dtype,
+ **common_kwargs,
+ )
+
+ if self.text_encoder_2_id is not None:
+ text_encoder_2 = T5EncoderModel.from_pretrained(
+ self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs
+ )
+ else:
+ text_encoder_2 = T5EncoderModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder_2",
+ torch_dtype=self.text_encoder_2_dtype,
+ **common_kwargs,
+ )
+
+ return {
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ }
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.vae_id is not None:
+ vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+ else:
+ vae = AutoencoderKL.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+ )
+
+ return {"vae": vae}
+
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.transformer_id is not None:
+ transformer = FluxTransformer2DModel.from_pretrained(
+ self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+ )
+ else:
+ transformer = FluxTransformer2DModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=self.transformer_dtype,
+ **common_kwargs,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {"transformer": transformer, "scheduler": scheduler}
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[CLIPTokenizer] = None,
+ tokenizer_2: Optional[AutoTokenizer] = None,
+ text_encoder: Optional[CLIPTextModel] = None,
+ text_encoder_2: Optional[T5EncoderModel] = None,
+ transformer: Optional[FluxTransformer2DModel] = None,
+ vae: Optional[AutoencoderKL] = None,
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> FluxPipeline:
+ components = {
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "transformer": transformer,
+ "vae": vae,
+ # Let from_pretrained load the scheduler from Flux's own config instead of reusing the default-initialized training scheduler
+ # "scheduler": scheduler,
+ }
+ components = get_non_null_items(components)
+
+ pipe = FluxPipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ pipe.text_encoder.to(self.text_encoder_dtype)
+ pipe.text_encoder_2.to(self.text_encoder_2_dtype)
+ pipe.vae.to(self.vae_dtype)
+
+ _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+ if not training:
+ pipe.transformer.to(self.transformer_dtype)
+ if enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+ return pipe
+
+ @torch.no_grad()
+ def prepare_conditions(
+ self,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: AutoTokenizer,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: T5EncoderModel,
+ caption: str,
+ max_sequence_length: int = 512,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ conditions = {
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "caption": caption,
+ "max_sequence_length": max_sequence_length,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_conditions(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ vae: AutoencoderKL,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ conditions = {
+ "vae": vae,
+ "image": image,
+ "video": video,
+ "generator": generator,
+ "compute_posterior": compute_posterior,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_latents(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ def forward(
+ self,
+ transformer: FluxTransformer2DModel,
+ condition_model_conditions: Dict[str, torch.Tensor],
+ latent_model_conditions: Dict[str, torch.Tensor],
+ sigmas: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ if compute_posterior:
+ latents = latent_model_conditions.pop("latents")
+ else:
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+ latents = posterior.sample(generator=generator)
+ del posterior
+
+ if getattr(self.vae_config, "shift_factor", None) is not None:
+ latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
+ else:
+ latents = latents * self.vae_config.scaling_factor
+
+ noise = torch.zeros_like(latents).normal_(generator=generator)
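+ # In-place normal_ keeps the noise on the latents' device/dtype while still honoring the generator
+ # (torch.randn_like does not accept a generator argument).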
+ timesteps = (sigmas.flatten() * 1000.0).long()
+ img_ids = FluxPipeline._prepare_latent_image_ids(
+ latents.size(0), latents.size(2) // 2, latents.size(3) // 2, latents.device, latents.dtype
+ )
+ text_ids = latents.new_zeros(condition_model_conditions["encoder_hidden_states"].shape[1], 3)
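+ # img_ids index the packed 2x2 latent patches for Flux's rotary position embedding; every text token
+ # shares position (0, 0, 0), hence the zero text_ids.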
+
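+ # Guidance-distilled checkpoints (e.g. FLUX.1-dev) embed a guidance value; training keeps it fixed at 1.0.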
+ if self.transformer_config.guidance_embeds:
+ guidance_scale = 1.0
+ guidance = latents.new_full((1,), guidance_scale).expand(latents.shape[0])
+ else:
+ guidance = None
+
+ noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+ noisy_latents = FluxPipeline._pack_latents(noisy_latents, *latents.shape)
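+ # _pack_latents folds 2x2 latent patches into tokens: [B, C, H, W] -> [B, (H/2) * (W/2), 4 * C].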
+
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+ condition_model_conditions.pop("prompt_attention_mask", None)
+
+ spatial_compression_ratio = 2 ** len(self.vae_config.block_out_channels)
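+ # For the Flux VAE this is 16: the 8x spatial downsampling combined with the 2x patch packing. Passing
+ # latent_size * 16 together with a scale factor of 16 to _unpack_latents recovers pred at the unpacked
+ # latent resolution, matching `target`.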
+ pred = transformer(
+ **latent_model_conditions,
+ **condition_model_conditions,
+ timestep=timesteps / 1000.0,
+ guidance=guidance,
+ img_ids=img_ids,
+ txt_ids=text_ids,
+ return_dict=False,
+ )[0]
+ pred = FluxPipeline._unpack_latents(
+ pred,
+ latents.size(2) * spatial_compression_ratio,
+ latents.size(3) * spatial_compression_ratio,
+ spatial_compression_ratio,
+ )
+ target = FF.flow_match_target(noise, latents)
+
+ return pred, target, sigmas
+
+ def validation(
+ self,
+ pipeline: FluxPipeline,
+ prompt: str,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 3.5,
+ generator: Optional[torch.Generator] = None,
+ **kwargs,
+ ) -> List[ArtifactType]:
+ generation_kwargs = {
+ "prompt": prompt,
+ "height": height,
+ "width": width,
+ "num_inference_steps": num_inference_steps,
+ "guidance_scale": guidance_scale,
+ "generator": generator,
+ "return_dict": True,
+ "output_type": "pil",
+ }
+ generation_kwargs = get_non_null_items(generation_kwargs)
+ image = pipeline(**generation_kwargs).images[0]
+ return [ImageArtifact(value=image)]
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ FluxPipeline.save_lora_weights(
+ directory,
+ transformer_state_dict,
+ save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+ safe_serialization=True,
+ )
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: FluxTransformer2DModel,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ with init_empty_weights():
+ transformer_copy = FluxTransformer2DModel.from_config(transformer.config)
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..518a42865f0cee30a534da458ec63b08c1a8d7e4
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/__init__.py
@@ -0,0 +1 @@
+from .base_specification import HunyuanVideoModelSpecification
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d02c931dc54fe8e5578a383b239cb0091d26f2
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/hunyuan_video/base_specification.py
@@ -0,0 +1,391 @@
+import functools
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideoPipeline,
+ HunyuanVideoTransformer3DModel,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel
+
+import finetrainers.functional as FF
+from finetrainers.data import VideoArtifact
+from finetrainers.logging import get_logger
+from finetrainers.models.modeling_utils import ModelSpecification
+from finetrainers.processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin
+from finetrainers.typing import ArtifactType, SchedulerType
+from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function
+
+
+logger = get_logger()
+
+
+class HunyuanLatentEncodeProcessor(ProcessorMixin):
+ r"""
+ Processor to encode image/video into latents using the HunyuanVideo VAE.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor returns. The outputs are in the following order:
+ - latents: The latents of the input image/video.
+ """
+
+ def __init__(self, output_names: List[str]):
+ super().__init__()
+ self.output_names = output_names
+ assert len(self.output_names) == 1
+
+ def forward(
+ self,
+ vae: AutoencoderKLHunyuanVideo,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ device = vae.device
+ dtype = vae.dtype
+
+ if image is not None:
+ video = image.unsqueeze(1)
+
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+ video = video.to(device=device, dtype=vae.dtype)
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
+
+ if compute_posterior:
+ latents = vae.encode(video).latent_dist.sample(generator=generator)
+ latents = latents.to(dtype=dtype)
+ else:
+ if vae.use_slicing and video.shape[0] > 1:
+ encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+ moments = torch.cat(encoded_slices)
+ else:
+ moments = vae._encode(video)
+ latents = moments.to(dtype=dtype)
+
+ return {self.output_names[0]: latents}
+
+
+class HunyuanVideoModelSpecification(ModelSpecification):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo",
+ tokenizer_id: Optional[str] = None,
+ tokenizer_2_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ text_encoder_2_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ tokenizer_id=tokenizer_id,
+ tokenizer_2_id=tokenizer_2_id,
+ text_encoder_id=text_encoder_id,
+ text_encoder_2_id=text_encoder_2_id,
+ transformer_id=transformer_id,
+ vae_id=vae_id,
+ text_encoder_dtype=text_encoder_dtype,
+ transformer_dtype=transformer_dtype,
+ vae_dtype=vae_dtype,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ if condition_model_processors is None:
+ condition_model_processors = [
+ LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]),
+ CLIPPooledProcessor(
+ ["pooled_projections"],
+ input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"},
+ ),
+ ]
+ if latent_model_processors is None:
+ latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])]
+
+ self.condition_model_processors = condition_model_processors
+ self.latent_model_processors = latent_model_processors
+
+ @property
+ def _resolution_dim_keys(self):
+ return {"latents": (2, 3, 4)}
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.tokenizer_id is not None:
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+ )
+
+ if self.tokenizer_2_id is not None:
+ tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs)
+ else:
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs
+ )
+
+ if self.text_encoder_id is not None:
+ text_encoder = LlamaModel.from_pretrained(
+ self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+ )
+ else:
+ text_encoder = LlamaModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ torch_dtype=self.text_encoder_dtype,
+ **common_kwargs,
+ )
+
+ if self.text_encoder_2_id is not None:
+ text_encoder_2 = CLIPTextModel.from_pretrained(
+ self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs
+ )
+ else:
+ text_encoder_2 = CLIPTextModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder_2",
+ torch_dtype=self.text_encoder_2_dtype,
+ **common_kwargs,
+ )
+
+ return {
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ }
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.vae_id is not None:
+ vae = AutoencoderKLHunyuanVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+ else:
+ vae = AutoencoderKLHunyuanVideo.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+ )
+
+ return {"vae": vae}
+
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.transformer_id is not None:
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+ )
+ else:
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=self.transformer_dtype,
+ **common_kwargs,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {"transformer": transformer, "scheduler": scheduler}
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[AutoTokenizer] = None,
+ tokenizer_2: Optional[CLIPTokenizer] = None,
+ text_encoder: Optional[LlamaModel] = None,
+ text_encoder_2: Optional[CLIPTextModel] = None,
+ transformer: Optional[HunyuanVideoTransformer3DModel] = None,
+ vae: Optional[AutoencoderKLHunyuanVideo] = None,
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> HunyuanVideoPipeline:
+ components = {
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ }
+ components = get_non_null_items(components)
+
+ pipe = HunyuanVideoPipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ pipe.text_encoder.to(self.text_encoder_dtype)
+ pipe.text_encoder_2.to(self.text_encoder_2_dtype)
+ pipe.vae.to(self.vae_dtype)
+
+ _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+ if not training:
+ pipe.transformer.to(self.transformer_dtype)
+ if enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+ return pipe
+
+ @torch.no_grad()
+ def prepare_conditions(
+ self,
+ tokenizer: AutoTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ text_encoder: LlamaModel,
+ text_encoder_2: CLIPTextModel,
+ caption: str,
+ max_sequence_length: int = 256,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ conditions = {
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "caption": caption,
+ "max_sequence_length": max_sequence_length,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_conditions(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ vae: AutoencoderKLHunyuanVideo,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ conditions = {
+ "vae": vae,
+ "image": image,
+ "video": video,
+ "generator": generator,
+ "compute_posterior": compute_posterior,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_latents(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ def forward(
+ self,
+ transformer: HunyuanVideoTransformer3DModel,
+ condition_model_conditions: Dict[str, torch.Tensor],
+ latent_model_conditions: Dict[str, torch.Tensor],
+ sigmas: torch.Tensor,
+ guidance: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ if compute_posterior:
+ latents = latent_model_conditions.pop("latents")
+ else:
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+ latents = posterior.sample(generator=generator)
+ del posterior
+
+ latents = latents * self.vae_config.scaling_factor
+ noise = torch.zeros_like(latents).normal_(generator=generator)
+ noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+
+ timesteps = (sigmas.flatten() * 1000.0).long()
+ guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0
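+ # The guidance (CFG) scale is embedded directly in the HunyuanVideo transformer; like the timesteps, it is scaled by 1000.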
+
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+ latent_model_conditions["guidance"] = guidance
+
+ pred = transformer(
+ **latent_model_conditions,
+ **condition_model_conditions,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+ target = FF.flow_match_target(noise, latents)
+
+ return pred, target, sigmas
+
+ def validation(
+ self,
+ pipeline: HunyuanVideoPipeline,
+ prompt: str,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ **kwargs,
+ ) -> List[ArtifactType]:
+ generation_kwargs = {
+ "prompt": prompt,
+ "height": height,
+ "width": width,
+ "num_frames": num_frames,
+ "num_inference_steps": num_inference_steps,
+ "generator": generator,
+ "return_dict": True,
+ "output_type": "pil",
+ }
+ generation_kwargs = get_non_null_items(generation_kwargs)
+ video = pipeline(**generation_kwargs).frames[0]
+ return [VideoArtifact(value=video)]
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ HunyuanVideoPipeline.save_lora_weights(
+ directory,
+ transformer_state_dict,
+ save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+ safe_serialization=True,
+ )
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: HunyuanVideoTransformer3DModel,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ with init_empty_weights():
+ transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config)
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff4e3550d54bb33fac80dd2d075ad2846eeeed46
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/__init__.py
@@ -0,0 +1 @@
+from .base_specification import LTXVideoModelSpecification
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8eaa5e420a604e2f13d9f3adee07e7ef1d02dee
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/ltx_video/base_specification.py
@@ -0,0 +1,504 @@
+import functools
+import os
+import random
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+ AutoencoderKLLTXVideo,
+ FlowMatchEulerDiscreteScheduler,
+ LTXImageToVideoPipeline,
+ LTXPipeline,
+ LTXVideoTransformer3DModel,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from PIL.Image import Image
+from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
+
+import finetrainers.functional as FF
+from finetrainers.data import VideoArtifact
+from finetrainers.logging import get_logger
+from finetrainers.models.modeling_utils import ModelSpecification
+from finetrainers.parallel import ParallelBackendEnum
+from finetrainers.processors import ProcessorMixin, T5Processor
+from finetrainers.typing import ArtifactType, SchedulerType
+from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function
+
+
+logger = get_logger()
+
+
+class LTXLatentEncodeProcessor(ProcessorMixin):
+ r"""
+ Processor to encode image/video into latents using the LTX VAE.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor returns. The outputs are in the following order:
+ - latents: The latents of the input image/video.
+ - num_frames: The number of frames in the input video.
+ - height: The height of the input image/video.
+ - width: The width of the input image/video.
+ - latents_mean: The latent channel means from the VAE state dict.
+ - latents_std: The latent channel standard deviations from the VAE state dict.
+ """
+
+ def __init__(self, output_names: List[str]):
+ super().__init__()
+ self.output_names = output_names
+ assert len(self.output_names) == 6
+
+ def forward(
+ self,
+ vae: AutoencoderKLLTXVideo,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ device = vae.device
+ dtype = vae.dtype
+
+ if image is not None:
+ video = image.unsqueeze(1)
+
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+ video = video.to(device=device, dtype=vae.dtype)
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
+
+ if compute_posterior:
+ latents = vae.encode(video).latent_dist.sample(generator=generator)
+ latents = latents.to(dtype=dtype)
+ else:
+ if vae.use_slicing and video.shape[0] > 1:
+ encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+ moments = torch.cat(encoded_slices)
+ else:
+ moments = vae._encode(video)
+ latents = moments.to(dtype=dtype)
+
+ _, _, num_frames, height, width = latents.shape
+
+ return {
+ self.output_names[0]: latents,
+ self.output_names[1]: num_frames,
+ self.output_names[2]: height,
+ self.output_names[3]: width,
+ self.output_names[4]: vae.latents_mean,
+ self.output_names[5]: vae.latents_std,
+ }
+
+
+class LTXVideoModelSpecification(ModelSpecification):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str = "Lightricks/LTX-Video",
+ tokenizer_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ tokenizer_id=tokenizer_id,
+ text_encoder_id=text_encoder_id,
+ transformer_id=transformer_id,
+ vae_id=vae_id,
+ text_encoder_dtype=text_encoder_dtype,
+ transformer_dtype=transformer_dtype,
+ vae_dtype=vae_dtype,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ if condition_model_processors is None:
+ condition_model_processors = [T5Processor(["encoder_hidden_states", "encoder_attention_mask"])]
+ if latent_model_processors is None:
+ latent_model_processors = [
+ LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
+ ]
+
+ self.condition_model_processors = condition_model_processors
+ self.latent_model_processors = latent_model_processors
+
+ @property
+ def _resolution_dim_keys(self):
+ return {"latents": (2, 3, 4)}
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.tokenizer_id is not None:
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+ else:
+ tokenizer = T5Tokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+ )
+
+ if self.text_encoder_id is not None:
+ text_encoder = AutoModel.from_pretrained(
+ self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+ )
+ else:
+ text_encoder = T5EncoderModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ torch_dtype=self.text_encoder_dtype,
+ **common_kwargs,
+ )
+
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.vae_id is not None:
+ vae = AutoencoderKLLTXVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+ else:
+ vae = AutoencoderKLLTXVideo.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+ )
+
+ return {"vae": vae}
+
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.transformer_id is not None:
+ transformer = LTXVideoTransformer3DModel.from_pretrained(
+ self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+ )
+ else:
+ transformer = LTXVideoTransformer3DModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=self.transformer_dtype,
+ **common_kwargs,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {"transformer": transformer, "scheduler": scheduler}
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[T5Tokenizer] = None,
+ text_encoder: Optional[T5EncoderModel] = None,
+ transformer: Optional[LTXVideoTransformer3DModel] = None,
+ vae: Optional[AutoencoderKLLTXVideo] = None,
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> LTXPipeline:
+ components = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ }
+ components = get_non_null_items(components)
+
+ pipe = LTXPipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ pipe.text_encoder.to(self.text_encoder_dtype)
+ pipe.vae.to(self.vae_dtype)
+
+ _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+ if not training:
+ pipe.transformer.to(self.transformer_dtype)
+ if enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+ return pipe
+
+ @torch.no_grad()
+ def prepare_conditions(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ caption: str,
+ max_sequence_length: int = 128,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ conditions = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "caption": caption,
+ "max_sequence_length": max_sequence_length,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_conditions(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ vae: AutoencoderKLLTXVideo,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ conditions = {
+ "vae": vae,
+ "image": image,
+ "video": video,
+ "generator": generator,
+ "compute_posterior": compute_posterior,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_latents(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ def forward(
+ self,
+ transformer: LTXVideoTransformer3DModel,
+ condition_model_conditions: Dict[str, torch.Tensor],
+ latent_model_conditions: Dict[str, torch.Tensor],
+ sigmas: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ # TODO(aryan): make this configurable? Should it be?
+ first_frame_conditioning_p = 0.1
+ min_first_frame_sigma = 0.25
+
+ if compute_posterior:
+ latents = latent_model_conditions.pop("latents")
+ else:
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+ latents = posterior.sample(generator=generator)
+ del posterior
+
+ latents_mean = latent_model_conditions.pop("latents_mean")
+ latents_std = latent_model_conditions.pop("latents_std")
+
+ latents = self._normalize_latents(latents, latents_mean, latents_std)
+ noise = torch.zeros_like(latents).normal_(generator=generator)
+
+ if random.random() < first_frame_conditioning_p:
+ # Section 2.4 of the LTX-Video paper notes that the first-frame timestep should be a small random value,
+ # so we cap the first-frame sigma at `min_first_frame_sigma` (0.25).
+ # torch.rand_like returns values in [0, 1); multiplying by sigmas rescales the range to [0, sigmas), which
+ # also guarantees the first-frame sigma never exceeds the sigmas used for the remaining frames.
+ first_frame_sigma = torch.rand_like(sigmas) * sigmas
+ first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, min_first_frame_sigma))
+
+ latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:]
+ noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma)
+ noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas)
+ noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2)
+ else:
+ noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+
+ patch_size = self.transformer_config.patch_size
+ patch_size_t = self.transformer_config.patch_size_t
+
+ latents = self._pack_latents(latents, patch_size, patch_size_t)
+ noise = self._pack_latents(noise, patch_size, patch_size_t)
+ noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
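+ # Expand sigmas to one entry per packed token ([B, seq_len, 1]); the timesteps passed to the transformer
+ # follow the same per-token shape.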
+ sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
+ timesteps = (sigmas * 1000.0).long()
+
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+
+ # TODO(aryan): make this configurable
+ frame_rate = 25
+ temporal_compression_ratio = 8
+ vae_spatial_compression_ratio = 32
+ latent_frame_rate = frame_rate / temporal_compression_ratio
+
+ rope_interpolation_scale = [
+ 1 / latent_frame_rate,
+ vae_spatial_compression_ratio,
+ vae_spatial_compression_ratio,
+ ]
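+ # Rotary positions are rescaled by 1 / latent_frame_rate along time and by the 32x VAE spatial compression
+ # along height and width.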
+
+ pred = transformer(
+ **latent_model_conditions,
+ **condition_model_conditions,
+ timestep=timesteps,
+ rope_interpolation_scale=rope_interpolation_scale,
+ return_dict=False,
+ )[0]
+ target = FF.flow_match_target(noise, latents)
+
+ return pred, target, sigmas
+
+ def validation(
+ self,
+ pipeline: LTXPipeline,
+ prompt: str,
+ image: Optional[Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ frame_rate: int = 25,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ **kwargs,
+ ) -> List[ArtifactType]:
+ if image is not None:
+ pipeline = LTXImageToVideoPipeline.from_pipe(pipeline)
+
+ generation_kwargs = {
+ "prompt": prompt,
+ "image": image,
+ "height": height,
+ "width": width,
+ "num_frames": num_frames,
+ "frame_rate": frame_rate,
+ "num_inference_steps": num_inference_steps,
+ "generator": generator,
+ "return_dict": True,
+ "output_type": "pil",
+ }
+ generation_kwargs = get_non_null_items(generation_kwargs)
+ video = pipeline(**generation_kwargs).frames[0]
+ return [VideoArtifact(value=video)]
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ LTXPipeline.save_lora_weights(
+ directory,
+ transformer_state_dict,
+ save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+ safe_serialization=True,
+ )
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: LTXVideoTransformer3DModel,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ with init_empty_weights():
+ transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config)
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def apply_tensor_parallel(
+ self,
+ backend: ParallelBackendEnum,
+ device_mesh: torch.distributed.DeviceMesh,
+ transformer: LTXVideoTransformer3DModel,
+ **kwargs,
+ ) -> None:
+ if backend == ParallelBackendEnum.PTD:
+ _apply_tensor_parallel_ptd(device_mesh, transformer)
+ else:
+ raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification")
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ batch_size = latents.shape[0]
+ latents_mean = latents_mean.view(batch_size, -1, 1, 1, 1).to(device=latents.device)
+ latents_std = latents_std.view(batch_size, -1, 1, 1, 1).to(device=latents.device)
+ latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents)
+ return latents
+
+ @staticmethod
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+ # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+ # The patch dimensions are then permuted and collapsed into the channel dimension, giving shape
+ # [B, (F // p_t) * (H // p) * (W // p), C * p_t * p * p] (an ndim=3 tensor).
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features.
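+ # With patch_size=1 and patch_size_t=1 (the usual LTX-Video configuration), this reduces to a plain
+ # flatten: [B, C, F, H, W] -> [B, F * H * W, C].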
+ batch_size, num_channels, num_frames, height, width = latents.shape
+ post_patch_num_frames = num_frames // patch_size_t
+ post_patch_height = height // patch_size
+ post_patch_width = width // patch_size
+ latents = latents.reshape(
+ batch_size,
+ -1,
+ post_patch_num_frames,
+ patch_size_t,
+ post_patch_height,
+ patch_size,
+ post_patch_width,
+ patch_size,
+ )
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+ return latents
+
+
+def _apply_tensor_parallel_ptd(
+ device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel
+) -> None:
+ from torch.distributed.tensor.parallel import parallelize_module
+ from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
+
+ transformer_plan = {
+ # ===== Condition embeddings =====
+ # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
+ # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
+ # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
+ # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
+ # "caption_projection.linear_1": ColwiseParallel(),
+ # "caption_projection.linear_2": RowwiseParallel(),
+ # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
+ # ===== =====
+ }
+
+ for block in transformer.transformer_blocks:
+ block_plan = {}
+
+ # ===== Attention =====
+ # 8 all-to-all, 3 all-reduce
+ # block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
+ # block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
+ # block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
+ # block_plan["attn1.norm_q"] = SequenceParallel()
+ # block_plan["attn1.norm_k"] = SequenceParallel()
+ # block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
+ # block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
+ # block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
+ # block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
+ # block_plan["attn2.norm_q"] = SequenceParallel()
+ # block_plan["attn2.norm_k"] = SequenceParallel()
+ # block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
+ # ===== =====
+
+ block_plan["ff.net.0.proj"] = ColwiseParallel()
+ block_plan["ff.net.2"] = RowwiseParallel()
+
+ parallelize_module(block, device_mesh, block_plan)
+
+ parallelize_module(transformer, device_mesh, transformer_plan)
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/modeling_utils.py b/docs/finetrainers-src-codebase/finetrainers/models/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b96599801feb3abe16acbada78dd0b4dfb9b9c7
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/modeling_utils.py
@@ -0,0 +1,388 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from PIL.Image import Image
+
+from finetrainers.logging import get_logger
+from finetrainers.parallel import ParallelBackendEnum
+from finetrainers.processors import ProcessorMixin
+from finetrainers.typing import ArtifactType, SchedulerType, TokenizerType
+from finetrainers.utils import resolve_component_cls
+
+
+if TYPE_CHECKING:
+ from finetrainers.trainer.control_trainer.config import FrameConditioningType
+
+logger = get_logger()
+
+# TODO(aryan): we most likely don't need this. take a look after refactoring more
+# fmt: off
+IGNORE_KEYS_FOR_COLLATION = {"height", "width", "num_frames", "frame_rate", "rope_interpolation_scale", "return_dict", "attention_kwargs", "cross_attention_kwargs", "joint_attention_kwargs", "latents_mean", "latents_std"}
+# fmt: on
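+# Keys listed here are treated as shared metadata: collate_conditions/collate_latents keep only the first
+# sample's value for them instead of concatenating across the batch.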
+
+
+class ModelSpecification:
+ r"""
+ The ModelSpecification class is an interface for implementing diffusion training recipes. It provides
+ a loose structure for organizing training code. The trainer implementations use this interface to
+ load models, prepare conditions, prepare latents, run the forward pass, and so on.
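+
+ Subclasses typically implement `load_condition_models`, `load_latent_models`, `load_diffusion_models`,
+ `load_pipeline`, `forward` and `validation`, and declare their condition/latent preprocessing via
+ `ProcessorMixin` instances (see `FluxModelSpecification` or `HunyuanVideoModelSpecification` for
+ concrete implementations).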
+ """
+
+ def __init__(
+ self,
+ pretrained_model_name_or_path: Optional[str] = None,
+ tokenizer_id: Optional[str] = None,
+ tokenizer_2_id: Optional[str] = None,
+ tokenizer_3_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ text_encoder_2_id: Optional[str] = None,
+ text_encoder_3_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ text_encoder_2_dtype: torch.dtype = torch.bfloat16,
+ text_encoder_3_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ ) -> None:
+ self.pretrained_model_name_or_path = pretrained_model_name_or_path
+ self.tokenizer_id = tokenizer_id
+ self.tokenizer_2_id = tokenizer_2_id
+ self.tokenizer_3_id = tokenizer_3_id
+ self.text_encoder_id = text_encoder_id
+ self.text_encoder_2_id = text_encoder_2_id
+ self.text_encoder_3_id = text_encoder_3_id
+ self.transformer_id = transformer_id
+ self.vae_id = vae_id
+ self.text_encoder_dtype = text_encoder_dtype
+ self.text_encoder_2_dtype = text_encoder_2_dtype
+ self.text_encoder_3_dtype = text_encoder_3_dtype
+ self.transformer_dtype = transformer_dtype
+ self.vae_dtype = vae_dtype
+ self.revision = revision
+ self.cache_dir = cache_dir
+ self.condition_model_processors = condition_model_processors or []
+ self.latent_model_processors = latent_model_processors or []
+
+ self.transformer_config: Dict[str, Any] = None
+ self.vae_config: Dict[str, Any] = None
+
+ self._load_configs()
+
+ def _trainer_init(self, *args, **kwargs):
+ pass
+
+ # TODO(aryan): revisit how to do this better without user having to worry about it
+ @property
+ def _resolution_dim_keys(self) -> Dict[str, Tuple[int, ...]]:
+ raise NotImplementedError(
+ f"ModelSpecification::_resolution_dim_keys is not implemented for {self.__class__.__name__}"
+ )
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ raise NotImplementedError(
+ f"ModelSpecification::load_condition_models is not implemented for {self.__class__.__name__}"
+ )
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ raise NotImplementedError(
+ f"ModelSpecification::load_latent_models is not implemented for {self.__class__.__name__}"
+ )
+
+ def load_diffusion_models(self) -> Dict[str, Union[torch.nn.Module]]:
+ raise NotImplementedError(
+ f"ModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}"
+ )
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[TokenizerType] = None,
+ tokenizer_2: Optional[TokenizerType] = None,
+ tokenizer_3: Optional[TokenizerType] = None,
+ text_encoder: Optional[torch.nn.Module] = None,
+ text_encoder_2: Optional[torch.nn.Module] = None,
+ text_encoder_3: Optional[torch.nn.Module] = None,
+ transformer: Optional[torch.nn.Module] = None,
+ vae: Optional[torch.nn.Module] = None,
+ scheduler: Optional[SchedulerType] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> DiffusionPipeline:
+ raise NotImplementedError(
+ f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
+ )
+
+ def prepare_conditions(self, processors: Optional[ProcessorMixin] = None, **kwargs) -> Dict[str, Any]:
+ if processors is None:
+ processors = self.condition_model_processors
+ for processor in processors:
+ result = processor(**kwargs)
+ result_keys = set(result.keys())
+ repeat_keys = result_keys.intersection(kwargs.keys())
+ if repeat_keys:
+ logger.warning(
+ f"Processor {processor.__class__.__name__} returned keys that already exist in "
+ f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
+ f"be intended. Please rename the keys in the processor to avoid conflicts."
+ )
+ kwargs.update(result)
+ return kwargs
+
+ def prepare_latents(self, processors: Optional[ProcessorMixin] = None, **kwargs) -> Dict[str, Any]:
+ if processors is None:
+ processors = self.latent_model_processors
+ for processor in processors:
+ result = processor(**kwargs)
+ result_keys = set(result.keys())
+ repeat_keys = result_keys.intersection(kwargs.keys())
+ if repeat_keys:
+ logger.warning(
+ f"Processor {processor.__class__.__name__} returned keys that already exist in "
+ f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
+ f"be intended. Please rename the keys in the processor to avoid conflicts."
+ )
+ kwargs.update(result)
+ return kwargs
+
+ def collate_conditions(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ keys = list(data[0].keys())
+ collated_data = {}
+ for key in keys:
+ if key in IGNORE_KEYS_FOR_COLLATION:
+ collated_data[key] = data[0][key]
+ continue
+ collated_d = [d[key] for d in data]
+ if isinstance(collated_d[0], torch.Tensor):
+ collated_d = torch.cat(collated_d)
+ collated_data[key] = collated_d
+ return collated_data
+
+ def collate_latents(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ keys = list(data[0].keys())
+ collated_data = {}
+ for key in keys:
+ if key in IGNORE_KEYS_FOR_COLLATION:
+ collated_data[key] = data[0][key]
+ continue
+ collated_d = [d[key] for d in data]
+ # TODO(aryan): Support multi-resolution collation
+ if isinstance(collated_d[0], torch.Tensor):
+ collated_d = torch.cat(collated_d)
+ collated_data[key] = collated_d
+ return collated_data
+
+ def forward(
+ self, transformer: torch.nn.Module, generator: Optional[torch.Generator] = None, **kwargs
+ ) -> Dict[str, torch.Tensor]:
+ raise NotImplementedError(f"ModelSpecification::forward is not implemented for {self.__class__.__name__}")
+
+ def validation(
+ self,
+ pipeline: DiffusionPipeline,
+ prompt: Optional[str] = None,
+ image: Optional[Image] = None,
+ video: Optional[List[Image]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ frame_rate: Optional[int] = None,
+ generator: Optional[torch.Generator] = None,
+ ) -> List[ArtifactType]:
+ raise NotImplementedError(f"ModelSpecification::validation is not implemented for {self.__class__.__name__}")
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer: torch.nn.Module,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ ) -> None:
+ r"""
+ Save the LoRA state dicts of the model to the given directory.
+
+ This API is not backwards compatible and will be changed in the near future.
+ """
+ raise NotImplementedError(
+ f"ModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}"
+ )
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: torch.nn.Module,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ r"""
+ Save the state dicts to the given directory.
+
+ This API is not backwards compatible and will be changed in the near future.
+ """
+ raise NotImplementedError(f"ModelSpecification::save_model is not implemented for {self.__class__.__name__}")
+
+ def apply_tensor_parallel(
+ self,
+ backend: ParallelBackendEnum,
+ device_mesh: torch.distributed.DeviceMesh,
+ text_encoder: torch.nn.Module,
+ text_encoder_2: torch.nn.Module,
+ text_encoder_3: torch.nn.Module,
+ transformer: torch.nn.Module,
+ vae: torch.nn.Module,
+ ) -> None:
+ raise NotImplementedError(
+ f"ModelSpecification::apply_tensor_parallel is not implemented for {self.__class__.__name__}"
+ )
+
+ def _load_configs(self) -> None:
+ self._load_transformer_config()
+ self._load_vae_config()
+
+ def _load_transformer_config(self) -> None:
+ if self.transformer_id is not None:
+ transformer_cls = resolve_component_cls(
+ self.transformer_id,
+ component_name="_class_name",
+ filename="config.json",
+ revision=self.revision,
+ cache_dir=self.cache_dir,
+ )
+ self.transformer_config = transformer_cls.load_config(
+ self.transformer_id, revision=self.revision, cache_dir=self.cache_dir
+ )
+ else:
+ transformer_cls = resolve_component_cls(
+ self.pretrained_model_name_or_path,
+ component_name="transformer",
+ filename="model_index.json",
+ revision=self.revision,
+ cache_dir=self.cache_dir,
+ )
+ self.transformer_config = transformer_cls.load_config(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=self.revision,
+ cache_dir=self.cache_dir,
+ )
+ self.transformer_config = FrozenDict(**self.transformer_config)
+
+ def _load_vae_config(self) -> None:
+ if self.vae_id is not None:
+ vae_cls = resolve_component_cls(
+ self.vae_id,
+ component_name="_class_name",
+ filename="config.json",
+ revision=self.revision,
+ cache_dir=self.cache_dir,
+ )
+ self.vae_config = vae_cls.load_config(self.vae_id, revision=self.revision, cache_dir=self.cache_dir)
+ else:
+ vae_cls = resolve_component_cls(
+ self.pretrained_model_name_or_path,
+ component_name="vae",
+ filename="model_index.json",
+ revision=self.revision,
+ cache_dir=self.cache_dir,
+ )
+ self.vae_config = vae_cls.load_config(
+ self.pretrained_model_name_or_path, subfolder="vae", revision=self.revision, cache_dir=self.cache_dir
+ )
+ self.vae_config = FrozenDict(**self.vae_config)
+
+
+class ControlModelSpecification(ModelSpecification):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ self.frame_conditioning_type: "FrameConditioningType" = None
+ self.frame_conditioning_index: int = None
+ self.frame_conditioning_concatenate_mask: bool = False
+
+ def _trainer_init(
+ self, frame_conditioning_type: "FrameConditioningType", frame_conditioning_index: int, concatenate_mask: bool
+ ) -> None:
+ self.frame_conditioning_type = frame_conditioning_type
+ self.frame_conditioning_index = frame_conditioning_index
+ self.frame_conditioning_concatenate_mask = concatenate_mask
+
+ @property
+ def control_injection_layer_name(self):
+ r"""Must return the FQN (fully-qualified name) of the control injection layer."""
+ raise NotImplementedError(
+ f"ControlModelSpecification::control_injection_layer_name is not implemented for {self.__class__.__name__}"
+ )
+
+ def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
+ raise NotImplementedError(
+ f"ControlModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}"
+ )
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer: torch.nn.Module,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ norm_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ ) -> None:
+ r"""
+ Save the LoRA state dicts of the model to the given directory.
+
+ This API is not backwards compatible and will be changed in the near future.
+ """
+ raise NotImplementedError(
+ f"ControlModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}"
+ )
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: torch.nn.Module,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ r"""
+ Save the state dicts to the given directory.
+
+ This API is not backwards compatible and will be changed in the near future.
+ """
+ raise NotImplementedError(
+ f"ControlModelSpecification::save_model is not implemented for {self.__class__.__name__}"
+ )
+
+ @property
+ def _original_control_layer_in_features(self):
+ """
+ Original in_features of the input projection layer where control is injected.
+ """
+ raise NotImplementedError(
+ f"ControlModelSpecification::_original_control_layer_in_features is not implemented for {self.__class__.__name__}"
+ )
+
+ @property
+ def _original_control_layer_out_features(self):
+ """
+ Original out_features of the input projection layer where control is injected.
+
+ This is used as the rank of the control injection layer when performing low-rank training, and is unused otherwise.
+ """
+ raise NotImplementedError(
+ f"ControlModelSpecification::_original_control_layer_out_features is not implemented for {self.__class__.__name__}"
+ )
+
+ @property
+ def _qk_norm_identifiers(self):
+ raise NotImplementedError(
+ f"ControlModelSpecification::_qk_norm_identifiers is not implemented for {self.__class__.__name__}"
+ )
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/utils.py b/docs/finetrainers-src-codebase/finetrainers/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ea9370146c26b5fa3a9de95e4ed4bc9805ce318
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/utils.py
@@ -0,0 +1,109 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from diffusers.utils.torch_utils import randn_tensor
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False, _dim: int = 1):
+ # Note: _dim is the new argument added here after copying from diffusers
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=_dim)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+ )
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+ # make sure the sample is on the same device and has the same dtype as the parameters
+ sample = randn_tensor(
+ self.mean.shape,
+ generator=generator,
+ device=self.parameters.device,
+ dtype=self.parameters.dtype,
+ )
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self) -> torch.Tensor:
+ return self.mean
+
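+ # Illustrative sketch, not part of the library: the Wan specifications build this distribution
+ # from raw VAE moments (the channel-wise concatenation of mean and logvar) and then either sample
+ # it for training targets or take its mode for conditioning latents. Shapes below are hypothetical.
+ #
+ # moments = torch.randn(1, 32, 4, 16, 16) # [B, 2 * z_dim, F, H, W]
+ # posterior = DiagonalGaussianDistribution(moments, _dim=1)
+ # z = posterior.sample(generator=torch.Generator().manual_seed(0)) # [B, z_dim, F, H, W]
+ # z_mode = posterior.mode() # deterministic alternative used for conditioning latents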
+
+@torch.no_grad()
+def _expand_linear_with_zeroed_weights(
+ module: torch.nn.Linear, new_in_features: Optional[int] = None, new_out_features: Optional[int] = None
+) -> torch.nn.Linear:
+ if new_in_features is None:
+ new_in_features = module.in_features
+ if new_out_features is None:
+ new_out_features = module.out_features
+ bias = getattr(module, "bias", None)
+ new_module = torch.nn.Linear(new_in_features, new_out_features, bias=bias is not None)
+ new_module.to(device=module.weight.device, dtype=module.weight.dtype)
+ new_module.weight.zero_()
+ new_module.weight.data[: module.weight.data.shape[0], : module.weight.data.shape[1]].copy_(module.weight.data)
+ if bias is not None:
+ new_module.bias.zero_()
+ new_module.bias.data[: bias.data.shape[0]].copy_(bias.data)
+ return new_module
+
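+ # Illustrative sketch, not part of the library: the expanded layer reproduces the original mapping
+ # on the first `in_features` inputs, while the newly added inputs contribute nothing until their
+ # zero-initialized weights are trained.
+ #
+ # linear = torch.nn.Linear(4, 8)
+ # expanded = _expand_linear_with_zeroed_weights(linear, new_in_features=6)
+ # x = torch.randn(2, 4)
+ # x_padded = torch.cat([x, torch.zeros(2, 2)], dim=1)
+ # torch.testing.assert_close(expanded(x_padded), linear(x))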
+
+@torch.no_grad()
+def _expand_conv3d_with_zeroed_weights(
+ module: torch.nn.Conv3d, new_in_channels: Optional[int] = None, new_out_channels: Optional[int] = None
+) -> torch.nn.Conv3d:
+ if new_in_channels is None:
+ new_in_channels = module.in_channels
+ if new_out_channels is None:
+ new_out_channels = module.out_channels
+ bias = getattr(module, "bias", None)
+ new_module = torch.nn.Conv3d(
+ new_in_channels,
+ new_out_channels,
+ kernel_size=module.kernel_size,
+ stride=module.stride,
+ padding=module.padding,
+ dilation=module.dilation,
+ groups=module.groups,
+ bias=bias is not None,
+ )
+ new_module.to(device=module.weight.device, dtype=module.weight.dtype)
+ new_module.weight.zero_()
+ new_module.weight.data[: module.weight.data.shape[0], : module.weight.data.shape[1]].copy_(module.weight.data)
+ if bias is not None:
+ new_module.bias.zero_()
+ new_module.bias.data[: bias.data.shape[0]].copy_(bias.data)
+ return new_module
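+ # Illustrative sketch, not part of the library (channel counts are hypothetical):
+ # WanControlModelSpecification uses this helper to widen `transformer.patch_embedding` so that
+ # control latents can be concatenated along the channel dimension without disturbing the
+ # pretrained weights.
+ #
+ # conv = torch.nn.Conv3d(16, 1536, kernel_size=(1, 2, 2), stride=(1, 2, 2))
+ # widened = _expand_conv3d_with_zeroed_weights(conv, new_in_channels=32)
+ # assert widened.in_channels == 32 and widened.out_channels == 1536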
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/wan/__init__.py b/docs/finetrainers-src-codebase/finetrainers/models/wan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb19b5b4f0440a4c620e35efbcda0a3d18b1cbab
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/wan/__init__.py
@@ -0,0 +1,2 @@
+from .base_specification import WanModelSpecification
+from .control_specification import WanControlModelSpecification
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/wan/base_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/wan/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..633d532f8ab0f94357ecd33990e91530702544a6
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/wan/base_specification.py
@@ -0,0 +1,577 @@
+import functools
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import PIL.Image
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ WanImageToVideoPipeline,
+ WanPipeline,
+ WanTransformer3DModel,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+import finetrainers.functional as FF
+from finetrainers.data import VideoArtifact
+from finetrainers.logging import get_logger
+from finetrainers.models.modeling_utils import ModelSpecification
+from finetrainers.processors import ProcessorMixin, T5Processor
+from finetrainers.typing import ArtifactType, SchedulerType
+from finetrainers.utils import get_non_null_items, safetensors_torch_save_function
+
+
+logger = get_logger()
+
+
+class WanLatentEncodeProcessor(ProcessorMixin):
+ r"""
+ Processor to encode image/video into latents using the Wan VAE.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor returns. The outputs are in the following order:
+ - latents: The latents of the input image/video.
+ - latents_mean: The channel-wise mean of the latent space.
+ - latents_std: The channel-wise standard deviation of the latent space.
+ """
+
+ def __init__(self, output_names: List[str]):
+ super().__init__()
+ self.output_names = output_names
+ assert len(self.output_names) == 3
+
+ def forward(
+ self,
+ vae: AutoencoderKLWan,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ device = vae.device
+ dtype = vae.dtype
+
+ if image is not None:
+ video = image.unsqueeze(1)
+
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+ video = video.to(device=device, dtype=dtype)
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
+
+ if compute_posterior:
+ latents = vae.encode(video).latent_dist.sample(generator=generator)
+ latents = latents.to(dtype=dtype)
+ else:
+ # TODO(aryan): refactor in diffusers to have use_slicing attribute
+ # if vae.use_slicing and video.shape[0] > 1:
+ # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+ # moments = torch.cat(encoded_slices)
+ # else:
+ # moments = vae._encode(video)
+ moments = vae._encode(video)
+ latents = moments.to(dtype=dtype)
+
+ latents_mean = torch.tensor(vae.config.latents_mean)
+ latents_std = 1.0 / torch.tensor(vae.config.latents_std)
+
+ return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std}
+
+
+class WanImageConditioningLatentEncodeProcessor(ProcessorMixin):
+ r"""
+ Processor to encode image/video into latents using the Wan VAE.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor returns. The outputs are in the following order:
+ - latents: The latents of the input image/video.
+ - latents_mean: The channel-wise mean of the latent space.
+ - latents_std: The channel-wise standard deviation of the latent space.
+ - mask: The conditioning frame mask for the input image/video.
+ """
+
+ def __init__(self, output_names: List[str], *, use_last_frame: bool = False):
+ super().__init__()
+ self.output_names = output_names
+ self.use_last_frame = use_last_frame
+ assert len(self.output_names) == 4
+
+ def forward(
+ self,
+ vae: AutoencoderKLWan,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ compute_posterior: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ device = vae.device
+ dtype = vae.dtype
+
+ if image is not None:
+ video = image.unsqueeze(1)
+
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+ video = video.to(device=device, dtype=dtype)
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
+
+ num_frames = video.size(2)
+ if not self.use_last_frame:
+ first_frame, remaining_frames = video[:, :, :1], video[:, :, 1:]
+ video = torch.cat([first_frame, torch.zeros_like(remaining_frames)], dim=2)
+ else:
+ first_frame, remaining_frames, last_frame = video[:, :, :1], video[:, :, 1:-1], video[:, :, -1:]
+ video = torch.cat([first_frame, torch.zeros_like(remaining_frames), last_frame], dim=2)
+
+ # Image conditioning uses argmax sampling, so we use "mode" here
+ if compute_posterior:
+ latents = vae.encode(video).latent_dist.mode()
+ latents = latents.to(dtype=dtype)
+ else:
+ # TODO(aryan): refactor in diffusers to have use_slicing attribute
+ # if vae.use_slicing and video.shape[0] > 1:
+ # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+ # moments = torch.cat(encoded_slices)
+ # else:
+ # moments = vae._encode(video)
+ moments = vae._encode(video)
+ latents = moments.to(dtype=dtype)
+
+ latents_mean = torch.tensor(vae.config.latents_mean)
+ latents_std = 1.0 / torch.tensor(vae.config.latents_std)
+
+ temporal_downsample = 2 ** sum(vae.temperal_downsample) if getattr(vae, "temperal_downsample", None) else 4
+ mask = latents.new_ones(latents.shape[0], 1, num_frames, latents.shape[3], latents.shape[4])
+ if not self.use_last_frame:
+ mask[:, :, 1:] = 0
+ else:
+ mask[:, :, 1:-1] = 0
+ first_frame_mask = mask[:, :, :1]
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=temporal_downsample)
+ mask = torch.cat([first_frame_mask, mask[:, :, 1:]], dim=2)
+ mask = mask.view(latents.shape[0], -1, temporal_downsample, latents.shape[3], latents.shape[4])
+ mask = mask.transpose(1, 2)
+
+ return {
+ self.output_names[0]: latents,
+ self.output_names[1]: latents_mean,
+ self.output_names[2]: latents_std,
+ self.output_names[3]: mask,
+ }
+
+
+class WanImageEncodeProcessor(ProcessorMixin):
+ r"""
+ Processor to encode the conditioning image for Wan I2V training.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor returns. The outputs are in the following order:
+ - image_embeds: The CLIP vision model image embeddings of the input image.
+ """
+
+ def __init__(self, output_names: List[str], *, use_last_frame: bool = False):
+ super().__init__()
+ self.output_names = output_names
+ self.use_last_frame = use_last_frame
+ assert len(self.output_names) == 1
+
+ def forward(
+ self,
+ image_encoder: CLIPVisionModel,
+ image_processor: CLIPImageProcessor,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ ) -> Dict[str, torch.Tensor]:
+ device = image_encoder.device
+ dtype = image_encoder.dtype
+ last_image = None
+
+ # We know the image here is in the range [-1, 1] (probably a little overshot if using bilinear interpolation), but
+ # the processor expects it to be in the range [0, 1].
+ image = image if video is None else video[:, 0] # [B, F, C, H, W] -> [B, C, H, W] (take first frame)
+ image = FF.normalize(image, min=0.0, max=1.0, dim=1)
+ assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
+
+ if self.use_last_frame:
+ last_image = image if video is None else video[:, -1]
+ last_image = FF.normalize(last_image, min=0.0, max=1.0, dim=1)
+ image = torch.stack([image, last_image], dim=0)
+
+ image = image_processor(images=image.float(), do_rescale=False, do_convert_rgb=False, return_tensors="pt")
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = image_encoder(**image, output_hidden_states=True)
+ image_embeds = image_embeds.hidden_states[-2]
+ return {self.output_names[0]: image_embeds}
+
+
+class WanModelSpecification(ModelSpecification):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ tokenizer_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ tokenizer_id=tokenizer_id,
+ text_encoder_id=text_encoder_id,
+ transformer_id=transformer_id,
+ vae_id=vae_id,
+ text_encoder_dtype=text_encoder_dtype,
+ transformer_dtype=transformer_dtype,
+ vae_dtype=vae_dtype,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ use_last_frame = self.transformer_config.get("pos_embed_seq_len", None) is not None
+
+ if condition_model_processors is None:
+ condition_model_processors = [T5Processor(["encoder_hidden_states", "__drop__"])]
+ if latent_model_processors is None:
+ latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
+
+ if self.transformer_config.get("image_dim", None) is not None:
+ latent_model_processors.append(
+ WanImageConditioningLatentEncodeProcessor(
+ ["latent_condition", "__drop__", "__drop__", "latent_condition_mask"],
+ use_last_frame=use_last_frame,
+ )
+ )
+ latent_model_processors.append(
+ WanImageEncodeProcessor(["encoder_hidden_states_image"], use_last_frame=use_last_frame)
+ )
+
+ self.condition_model_processors = condition_model_processors
+ self.latent_model_processors = latent_model_processors
+
+ @property
+ def _resolution_dim_keys(self):
+ return {"latents": (2, 3, 4)}
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.tokenizer_id is not None:
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+ )
+
+ if self.text_encoder_id is not None:
+ text_encoder = AutoModel.from_pretrained(
+ self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+ )
+ else:
+ text_encoder = UMT5EncoderModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ torch_dtype=self.text_encoder_dtype,
+ **common_kwargs,
+ )
+
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.vae_id is not None:
+ vae = AutoencoderKLWan.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+ else:
+ vae = AutoencoderKLWan.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+ )
+
+ models = {"vae": vae}
+ if self.transformer_config.get("image_dim", None) is not None:
+ # TODO(aryan): refactor the trainer to be able to support these extra models from CLI args more easily
+ image_encoder = CLIPVisionModel.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="image_encoder", torch_dtype=torch.bfloat16
+ )
+ image_processor = CLIPImageProcessor.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="image_processor"
+ )
+ models["image_encoder"] = image_encoder
+ models["image_processor"] = image_processor
+
+ return models
+
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.transformer_id is not None:
+ transformer = WanTransformer3DModel.from_pretrained(
+ self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+ )
+ else:
+ transformer = WanTransformer3DModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=self.transformer_dtype,
+ **common_kwargs,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {"transformer": transformer, "scheduler": scheduler}
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[AutoTokenizer] = None,
+ text_encoder: Optional[UMT5EncoderModel] = None,
+ transformer: Optional[WanTransformer3DModel] = None,
+ vae: Optional[AutoencoderKLWan] = None,
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+ image_encoder: Optional[CLIPVisionModel] = None,
+ image_processor: Optional[CLIPImageProcessor] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> Union[WanPipeline, WanImageToVideoPipeline]:
+ components = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ components = get_non_null_items(components)
+
+ if self.transformer_config.get("image_dim", None) is not None:
+ pipe = WanImageToVideoPipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ else:
+ pipe = WanPipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ pipe.text_encoder.to(self.text_encoder_dtype)
+ pipe.vae.to(self.vae_dtype)
+
+ if not training:
+ pipe.transformer.to(self.transformer_dtype)
+
+ # TODO(aryan): add support in diffusers
+ # if enable_slicing:
+ # pipe.vae.enable_slicing()
+ # if enable_tiling:
+ # pipe.vae.enable_tiling()
+ if enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+
+ return pipe
+
+ @torch.no_grad()
+ def prepare_conditions(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ caption: str,
+ max_sequence_length: int = 512,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ conditions = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "caption": caption,
+ "max_sequence_length": max_sequence_length,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_conditions(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ vae: AutoencoderKLWan,
+ image_encoder: Optional[CLIPVisionModel] = None,
+ image_processor: Optional[CLIPImageProcessor] = None,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ conditions = {
+ "vae": vae,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ "image": image,
+ "video": video,
+ "generator": generator,
+ # We must force this to False because the latent normalization should be done before
+ # the posterior is computed. The VAE does not handle this any more:
+ # https://github.com/huggingface/diffusers/pull/10998
+ "compute_posterior": False,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_latents(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ def forward(
+ self,
+ transformer: WanTransformer3DModel,
+ condition_model_conditions: Dict[str, torch.Tensor],
+ latent_model_conditions: Dict[str, torch.Tensor],
+ sigmas: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ compute_posterior = False # See explanation in prepare_latents
+ latent_condition = latent_condition_mask = None
+
+ if compute_posterior:
+ latents = latent_model_conditions.pop("latents")
+ latent_condition = latent_model_conditions.pop("latent_condition", None)
+ latent_condition_mask = latent_model_conditions.pop("latent_condition_mask", None)
+ else:
+ latents = latent_model_conditions.pop("latents")
+ latents_mean = latent_model_conditions.pop("latents_mean")
+ latents_std = latent_model_conditions.pop("latents_std")
+ latent_condition = latent_model_conditions.pop("latent_condition", None)
+ latent_condition_mask = latent_model_conditions.pop("latent_condition_mask", None)
+
+ mu, logvar = torch.chunk(latents, 2, dim=1)
+ mu = self._normalize_latents(mu, latents_mean, latents_std)
+ logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+ latents = torch.cat([mu, logvar], dim=1)
+
+ posterior = DiagonalGaussianDistribution(latents)
+ latents = posterior.sample(generator=generator)
+
+ if latent_condition is not None:
+ mu, logvar = torch.chunk(latent_condition, 2, dim=1)
+ mu = self._normalize_latents(mu, latents_mean, latents_std)
+ logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+ latent_condition = torch.cat([mu, logvar], dim=1)
+
+ posterior = DiagonalGaussianDistribution(latent_condition)
+ latent_condition = posterior.mode()
+
+ del posterior
+
+ noise = torch.zeros_like(latents).normal_(generator=generator)
+ noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+ timesteps = (sigmas.flatten() * 1000.0).long()
+
+ if self.transformer_config.get("image_dim", None) is not None:
+ noisy_latents = torch.cat([noisy_latents, latent_condition_mask, latent_condition], dim=1)
+
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+
+ pred = transformer(
+ **latent_model_conditions,
+ **condition_model_conditions,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+ target = FF.flow_match_target(noise, latents)
+
+ return pred, target, sigmas
+
+ def validation(
+ self,
+ pipeline: Union[WanPipeline, WanImageToVideoPipeline],
+ prompt: str,
+ image: Optional[PIL.Image.Image] = None,
+ last_image: Optional[PIL.Image.Image] = None,
+ video: Optional[List[PIL.Image.Image]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ **kwargs,
+ ) -> List[ArtifactType]:
+ generation_kwargs = {
+ "prompt": prompt,
+ "height": height,
+ "width": width,
+ "num_frames": num_frames,
+ "num_inference_steps": num_inference_steps,
+ "generator": generator,
+ "return_dict": True,
+ "output_type": "pil",
+ }
+ if self.transformer_config.get("image_dim", None) is not None:
+ if image is None and video is None:
+ raise ValueError("Either image or video must be provided for Wan I2V validation.")
+ image = image if image is not None else video[0]
+ generation_kwargs["image"] = image
+ if self.transformer_config.get("pos_embed_seq_len", None) is not None:
+ last_image = last_image if last_image is not None else image if video is None else video[-1]
+ generation_kwargs["last_image"] = last_image
+ generation_kwargs = get_non_null_items(generation_kwargs)
+ video = pipeline(**generation_kwargs).frames[0]
+ return [VideoArtifact(value=video)]
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ pipeline_cls = (
+ WanImageToVideoPipeline if self.transformer_config.get("image_dim", None) is not None else WanPipeline
+ )
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ pipeline_cls.save_lora_weights(
+ directory,
+ transformer_state_dict,
+ save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+ safe_serialization=True,
+ )
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: WanTransformer3DModel,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ with init_empty_weights():
+ transformer_copy = WanTransformer3DModel.from_config(transformer.config)
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+ latents = ((latents.float() - latents_mean) * latents_std).to(latents)
+ return latents
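+ # Illustrative usage sketch, not part of the library (argument values are assumptions):
+ #
+ # spec = WanModelSpecification(
+ # pretrained_model_name_or_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ # transformer_dtype=torch.bfloat16,
+ # )
+ # components = {**spec.load_condition_models(), **spec.load_latent_models(), **spec.load_diffusion_models()}
+ # pipe = spec.load_pipeline(**components, training=False)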
diff --git a/docs/finetrainers-src-codebase/finetrainers/models/wan/control_specification.py b/docs/finetrainers-src-codebase/finetrainers/models/wan/control_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..78a4de9b1cd92e9414227da992731f1facc1c137
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/models/wan/control_specification.py
@@ -0,0 +1,437 @@
+import functools
+import os
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import safetensors
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ WanPipeline,
+ WanTransformer3DModel,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel
+
+import finetrainers.functional as FF
+from finetrainers.data import VideoArtifact
+from finetrainers.logging import get_logger
+from finetrainers.models.modeling_utils import ControlModelSpecification
+from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights
+from finetrainers.patches.dependencies.diffusers.control import control_channel_concat
+from finetrainers.processors import ProcessorMixin, T5Processor
+from finetrainers.typing import ArtifactType, SchedulerType
+from finetrainers.utils import get_non_null_items, safetensors_torch_save_function
+
+from .base_specification import WanLatentEncodeProcessor
+
+
+if TYPE_CHECKING:
+ from finetrainers.trainer.control_trainer.config import FrameConditioningType
+
+logger = get_logger()
+
+
+class WanControlModelSpecification(ControlModelSpecification):
+ def __init__(
+ self,
+ pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ tokenizer_id: Optional[str] = None,
+ text_encoder_id: Optional[str] = None,
+ transformer_id: Optional[str] = None,
+ vae_id: Optional[str] = None,
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
+ transformer_dtype: torch.dtype = torch.bfloat16,
+ vae_dtype: torch.dtype = torch.bfloat16,
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ condition_model_processors: List[ProcessorMixin] = None,
+ latent_model_processors: List[ProcessorMixin] = None,
+ control_model_processors: List[ProcessorMixin] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ tokenizer_id=tokenizer_id,
+ text_encoder_id=text_encoder_id,
+ transformer_id=transformer_id,
+ vae_id=vae_id,
+ text_encoder_dtype=text_encoder_dtype,
+ transformer_dtype=transformer_dtype,
+ vae_dtype=vae_dtype,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ if condition_model_processors is None:
+ condition_model_processors = [T5Processor(["encoder_hidden_states", "__drop__"])]
+ if latent_model_processors is None:
+ latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
+ if control_model_processors is None:
+ control_model_processors = [WanLatentEncodeProcessor(["control_latents", "__drop__", "__drop__"])]
+
+ self.condition_model_processors = condition_model_processors
+ self.latent_model_processors = latent_model_processors
+ self.control_model_processors = control_model_processors
+
+ @property
+ def control_injection_layer_name(self) -> str:
+ return "patch_embedding"
+
+ @property
+ def _resolution_dim_keys(self):
+ return {"latents": (2, 3, 4)}
+
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.tokenizer_id is not None:
+ tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+ )
+
+ if self.text_encoder_id is not None:
+ text_encoder = AutoModel.from_pretrained(
+ self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+ )
+ else:
+ text_encoder = UMT5EncoderModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ torch_dtype=self.text_encoder_dtype,
+ **common_kwargs,
+ )
+
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.vae_id is not None:
+ vae = AutoencoderKLWan.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+ else:
+ vae = AutoencoderKLWan.from_pretrained(
+ self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+ )
+
+ return {"vae": vae}
+
+ def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
+ common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+ if self.transformer_id is not None:
+ transformer = WanTransformer3DModel.from_pretrained(
+ self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+ )
+ else:
+ transformer = WanTransformer3DModel.from_pretrained(
+ self.pretrained_model_name_or_path,
+ subfolder="transformer",
+ torch_dtype=self.transformer_dtype,
+ **common_kwargs,
+ )
+
+ transformer.patch_embedding = _expand_conv3d_with_zeroed_weights(
+ transformer.patch_embedding, new_in_channels=new_in_features
+ )
+ transformer.register_to_config(in_channels=new_in_features)
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {"transformer": transformer, "scheduler": scheduler}
+
+ def load_pipeline(
+ self,
+ tokenizer: Optional[AutoTokenizer] = None,
+ text_encoder: Optional[UMT5EncoderModel] = None,
+ transformer: Optional[WanTransformer3DModel] = None,
+ vae: Optional[AutoencoderKLWan] = None,
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+ enable_slicing: bool = False,
+ enable_tiling: bool = False,
+ enable_model_cpu_offload: bool = False,
+ training: bool = False,
+ **kwargs,
+ ) -> WanPipeline:
+ components = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ }
+ components = get_non_null_items(components)
+
+ pipe = WanPipeline.from_pretrained(
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+ )
+ pipe.text_encoder.to(self.text_encoder_dtype)
+ pipe.vae.to(self.vae_dtype)
+
+ if not training:
+ pipe.transformer.to(self.transformer_dtype)
+
+ # TODO(aryan): add support in diffusers
+ # if enable_slicing:
+ # pipe.vae.enable_slicing()
+ # if enable_tiling:
+ # pipe.vae.enable_tiling()
+ if enable_model_cpu_offload:
+ pipe.enable_model_cpu_offload()
+
+ return pipe
+
+ @torch.no_grad()
+ def prepare_conditions(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ caption: str,
+ max_sequence_length: int = 512,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ conditions = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "caption": caption,
+ "max_sequence_length": max_sequence_length,
+ **kwargs,
+ }
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_conditions(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+ return conditions
+
+ @torch.no_grad()
+ def prepare_latents(
+ self,
+ vae: AutoencoderKLWan,
+ image: Optional[torch.Tensor] = None,
+ video: Optional[torch.Tensor] = None,
+ control_image: Optional[torch.Tensor] = None,
+ control_video: Optional[torch.Tensor] = None,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ common_kwargs = {
+ "vae": vae,
+ "generator": generator,
+ # We must force this to False because the latent normalization should be done before
+ # the posterior is computed. The VAE does not handle this any more:
+ # https://github.com/huggingface/diffusers/pull/10998
+ "compute_posterior": False,
+ **kwargs,
+ }
+ conditions = {"image": image, "video": video, **common_kwargs}
+ input_keys = set(conditions.keys())
+ conditions = super().prepare_latents(**conditions)
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+
+ control_conditions = {"image": control_image, "video": control_video, **common_kwargs}
+ input_keys = set(control_conditions.keys())
+ control_conditions = ControlModelSpecification.prepare_latents(
+ self, self.control_model_processors, **control_conditions
+ )
+ control_conditions = {k: v for k, v in control_conditions.items() if k not in input_keys}
+
+ return {**control_conditions, **conditions}
+
+ def forward(
+ self,
+ transformer: WanTransformer3DModel,
+ condition_model_conditions: Dict[str, torch.Tensor],
+ latent_model_conditions: Dict[str, torch.Tensor],
+ sigmas: torch.Tensor,
+ generator: Optional[torch.Generator] = None,
+ compute_posterior: bool = True,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, ...]:
+ from finetrainers.trainer.control_trainer.data import apply_frame_conditioning_on_latents
+
+ compute_posterior = False # See explanation in prepare_latents
+ if compute_posterior:
+ latents = latent_model_conditions.pop("latents")
+ control_latents = latent_model_conditions.pop("control_latents")
+ else:
+ latents = latent_model_conditions.pop("latents")
+ control_latents = latent_model_conditions.pop("control_latents")
+ latents_mean = latent_model_conditions.pop("latents_mean")
+ latents_std = latent_model_conditions.pop("latents_std")
+
+ mu, logvar = torch.chunk(latents, 2, dim=1)
+ mu = self._normalize_latents(mu, latents_mean, latents_std)
+ logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+ latents = torch.cat([mu, logvar], dim=1)
+
+ mu, logvar = torch.chunk(control_latents, 2, dim=1)
+ mu = self._normalize_latents(mu, latents_mean, latents_std)
+ logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+ control_latents = torch.cat([mu, logvar], dim=1)
+
+ posterior = DiagonalGaussianDistribution(latents)
+ latents = posterior.mode()
+ del posterior
+
+ control_posterior = DiagonalGaussianDistribution(control_latents)
+ control_latents = control_posterior.mode()
+ del control_posterior
+
+ noise = torch.zeros_like(latents).normal_(generator=generator)
+ timesteps = (sigmas.flatten() * 1000.0).long()
+
+ noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+ control_latents = apply_frame_conditioning_on_latents(
+ control_latents,
+ noisy_latents.shape[2],
+ channel_dim=1,
+ frame_dim=2,
+ frame_conditioning_type=self.frame_conditioning_type,
+ frame_conditioning_index=self.frame_conditioning_index,
+ concatenate_mask=self.frame_conditioning_concatenate_mask,
+ )
+ noisy_latents = torch.cat([noisy_latents, control_latents], dim=1)
+
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+
+ pred = transformer(
+ **latent_model_conditions,
+ **condition_model_conditions,
+ timestep=timesteps,
+ return_dict=False,
+ )[0]
+ target = FF.flow_match_target(noise, latents)
+
+ return pred, target, sigmas
+
+ def validation(
+ self,
+ pipeline: WanPipeline,
+ prompt: str,
+ control_image: Optional[torch.Tensor] = None,
+ control_video: Optional[torch.Tensor] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ frame_conditioning_type: "FrameConditioningType" = "full",
+ frame_conditioning_index: int = 0,
+ **kwargs,
+ ) -> List[ArtifactType]:
+ from finetrainers.trainer.control_trainer.data import apply_frame_conditioning_on_latents
+
+ with torch.no_grad():
+ dtype = pipeline.vae.dtype
+ device = pipeline._execution_device
+ in_channels = self.transformer_config.in_channels # We need to use the original in_channels
+ latents = pipeline.prepare_latents(1, in_channels, height, width, num_frames, dtype, device, generator)
+ latents_mean = (
+ torch.tensor(self.vae_config.latents_mean)
+ .view(1, self.vae_config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae_config.latents_std).view(1, self.vae_config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if control_image is not None:
+ control_video = pipeline.video_processor.preprocess(
+ control_image, height=height, width=width
+ ).unsqueeze(2)
+ else:
+ control_video = pipeline.video_processor.preprocess_video(control_video, height=height, width=width)
+
+ control_video = control_video.to(device=device, dtype=dtype)
+ control_latents = pipeline.vae.encode(control_video).latent_dist.mode()
+ control_latents = self._normalize_latents(control_latents, latents_mean, latents_std)
+ control_latents = apply_frame_conditioning_on_latents(
+ control_latents,
+ latents.shape[2],
+ channel_dim=1,
+ frame_dim=2,
+ frame_conditioning_type=frame_conditioning_type,
+ frame_conditioning_index=frame_conditioning_index,
+ concatenate_mask=self.frame_conditioning_concatenate_mask,
+ )
+
+ generation_kwargs = {
+ "latents": latents,
+ "prompt": prompt,
+ "height": height,
+ "width": width,
+ "num_frames": num_frames,
+ "num_inference_steps": num_inference_steps,
+ "generator": generator,
+ "return_dict": True,
+ "output_type": "pil",
+ }
+ generation_kwargs = get_non_null_items(generation_kwargs)
+
+ with control_channel_concat(pipeline.transformer, ["hidden_states"], [control_latents], dims=[1]):
+ video = pipeline(**generation_kwargs).frames[0]
+
+ return [VideoArtifact(value=video)]
+
+ def _save_lora_weights(
+ self,
+ directory: str,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ norm_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ WanPipeline.save_lora_weights(
+ directory,
+ transformer_state_dict,
+ save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+ safe_serialization=True,
+ )
+ if norm_state_dict is not None:
+ safetensors.torch.save_file(norm_state_dict, os.path.join(directory, "norm_state_dict.safetensors"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ def _save_model(
+ self,
+ directory: str,
+ transformer: WanTransformer3DModel,
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+ scheduler: Optional[SchedulerType] = None,
+ ) -> None:
+ # TODO(aryan): this needs refactoring
+ if transformer_state_dict is not None:
+ with init_empty_weights():
+ transformer_copy = WanTransformer3DModel.from_config(transformer.config)
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+ if scheduler is not None:
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+ @staticmethod
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+ ) -> torch.Tensor:
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+ latents = ((latents.float() - latents_mean) * latents_std).to(latents)
+ return latents
+
+ @property
+ def _original_control_layer_in_features(self):
+ return self.transformer_config.in_channels
+
+ @property
+ def _original_control_layer_out_features(self):
+ return self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim
+
+ @property
+ def _qk_norm_identifiers(self):
+ return ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
diff --git a/docs/finetrainers-src-codebase/finetrainers/optimizer.py b/docs/finetrainers-src-codebase/finetrainers/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..57da28e9377f2bf82b5307fae83338ad0b9ec385
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/optimizer.py
@@ -0,0 +1,449 @@
+import functools
+import math
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+import torch
+from torch.distributed.checkpoint.state_dict import (
+ StateDictOptions,
+ get_optimizer_state_dict,
+ set_optimizer_state_dict,
+)
+from torch.distributed.checkpoint.stateful import Stateful
+
+from .parallel import ParallelBackendEnum
+from .utils.import_utils import is_bitsandbytes_available
+
+
+class OptimizerWrapper(Stateful):
+ r"""
+ Optimizer wrapper that:
+ - stepping/zeroing gradients across the multiple optimizers needed for virtual pipeline stages
+ - saving/loading the optimizer state_dict at checkpointing time
+ """
+
+ def __init__(
+ self,
+ model_parts: List[torch.nn.Module],
+ optimizer_cls: Type[torch.optim.Optimizer],
+ optimizer_kwargs: Dict[str, Any],
+ ) -> None:
+ self.optimizer_cls = optimizer_cls
+ self.optimizer_kwargs = optimizer_kwargs
+
+ self.optimizers = []
+ self.model_parts = model_parts
+
+ for model in self.model_parts:
+ optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
+ self.optimizers.append(optimizer)
+
+ def step(self) -> None:
+ for optimizer in self.optimizers:
+ optimizer.step()
+
+ def zero_grad(self) -> None:
+ for optimizer in self.optimizers:
+ optimizer.zero_grad()
+
+ def state_dict(self) -> Dict[str, Any]:
+ func = functools.partial(
+ get_optimizer_state_dict,
+ options=StateDictOptions(flatten_optimizer_state_dict=True),
+ )
+ return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ func = functools.partial(
+ set_optimizer_state_dict,
+ optim_state_dict=state_dict,
+ options=StateDictOptions(flatten_optimizer_state_dict=True),
+ )
+ list(map(func, self.model_parts, self.optimizers))
+
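+ # Illustrative sketch, not part of the library: with pipeline parallelism each virtual stage owns a
+ # model part, and the wrapper steps/zeroes every underlying optimizer in lockstep. `model_parts` and
+ # `loss` are hypothetical names.
+ #
+ # wrapper = OptimizerWrapper(model_parts, torch.optim.AdamW, {"lr": 1e-4, "weight_decay": 1e-4})
+ # wrapper.zero_grad()
+ # loss.backward()
+ # wrapper.step()
+ # checkpoint_state = wrapper.state_dict() # flattened state, suitable for distributed checkpointing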
+
+class SchedulerWrapper:
+ def __init__(
+ self, optimizers, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int
+ ) -> None:
+ self.schedulers = []
+ for optimizer in optimizers:
+ self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch))
+
+ def step(self) -> None:
+ for scheduler in self.schedulers:
+ scheduler.step()
+
+ def get_last_lr(self) -> List[float]:
+ # TODO(aryan): look into this later. Currently calling it leads to NCCL hang?????
+ return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}
+
+ def get_lr_scheduler_state(self) -> Dict[str, Any]:
+ state_dict = {}
+ if len(self.schedulers) == 1:
+ state_dict["lr_scheduler"] = self.schedulers[0]
+ else:
+ # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
+ # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
+ for idx, lr_scheduler in enumerate(self.schedulers):
+ state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
+ return state_dict
+
+
+def get_optimizer(
+ parallel_backend: ParallelBackendEnum,
+ name: str,
+ model_parts: List[torch.nn.Module],
+ learning_rate: float = 1e-3,
+ beta1: float = 0.9,
+ beta2: float = 0.95,
+ beta3: float = 0.999,
+ epsilon: float = 1e-8,
+ weight_decay: float = 1e-4,
+ fused: bool = False,
+) -> Union[torch.optim.Optimizer, OptimizerWrapper]:
+ name = name.lower()
+
+ _raise_errors_if_packages_not_available(name)
+
+ if name == "adam":
+ optimizer_cls = torch.optim.Adam
+ optimizer_kwargs = {
+ "lr": learning_rate,
+ "betas": (beta1, beta2),
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ "fused": fused,
+ }
+ elif name == "adamw":
+ optimizer_cls = torch.optim.AdamW
+ optimizer_kwargs = {
+ "lr": learning_rate,
+ "betas": (beta1, beta2),
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ "fused": fused,
+ }
+ elif name == "adam-bnb":
+ from bitsandbytes.optim import Adam
+
+ optimizer_cls = Adam
+ optimizer_kwargs = {
+ "lr": learning_rate,
+ "betas": (beta1, beta2),
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ }
+ elif name == "adamw-bnb":
+ from bitsandbytes.optim import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs = {
+ "lr": learning_rate,
+ "betas": (beta1, beta2),
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ }
+ elif name == "adam-bnb-8bit":
+ from bitsandbytes.optim import Adam8bit
+
+ optimizer_cls = Adam8bit
+ optimizer_kwargs = {
+ "lr": learning_rate,
+ "betas": (beta1, beta2),
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ }
+ elif name == "adamw-bnb-8bit":
+ from bitsandbytes.optim import AdamW8bit
+
+ optimizer_cls = AdamW8bit
+ optimizer_kwargs = {
+ "lr": learning_rate,
+ "betas": (beta1, beta2),
+ "eps": epsilon,
+ "weight_decay": weight_decay,
+ }
+
+ # TODO(aryan): handle bitsandbytes and torchao
+ else:
+ raise ValueError(f"Unsupported optimizer: {name}")
+
+ if parallel_backend == ParallelBackendEnum.ACCELERATE:
+ return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs)
+ elif parallel_backend == ParallelBackendEnum.PTD:
+ return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs)
+
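+ # Illustrative sketch, not part of the library (`transformer` and the values are assumptions): with
+ # the accelerate backend this returns a plain torch.optim.Optimizer over the trainable parameters;
+ # with the ptd backend it returns an OptimizerWrapper over the given model parts.
+ #
+ # optimizer = get_optimizer(
+ # parallel_backend=ParallelBackendEnum.PTD,
+ # name="adamw",
+ # model_parts=[transformer],
+ # learning_rate=1e-4,
+ # weight_decay=1e-4,
+ # )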
+
+def get_optimizer_accelerate(
+ model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
+) -> torch.optim.Optimizer:
+ params = [param for model in model_parts for param in model.parameters() if param.requires_grad]
+ optimizer = optimizer_cls(params, **optimizer_kwargs)
+ return optimizer
+
+
+def get_optimizer_ptd(
+ model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
+) -> OptimizerWrapper:
+ return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs)
+
+
+def get_lr_scheduler(
+ parallel_backend: ParallelBackendEnum,
+ name: str,
+ optimizer: Union[torch.optim.Optimizer, OptimizerWrapper],
+ step_rules: Optional[str] = None,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+ num_cycles: int = 1,
+ power: float = 1.0,
+ lr_init: float = 1e-3,
+ lr_end: float = 1e-7,
+ last_epoch: int = -1,
+) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]:
+ name = name.lower()
+ if name == "constant":
+ scheduler_lambda_fn = get_constant_schedule()
+ elif name == "constant_with_warmup":
+ scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps)
+ elif name == "piecewise_constant":
+ scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
+ elif name == "linear":
+ scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
+ elif name == "cosine":
+ scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
+ elif name == "cosine_with_restarts":
+ scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
+ num_warmup_steps, num_training_steps, num_cycles
+ )
+ elif name == "polynomial":
+ scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
+ num_warmup_steps, num_training_steps, lr_init, lr_end, power
+ )
+ else:
+ raise ValueError(f"Unsupported scheduler: {name}")
+
+ if parallel_backend == ParallelBackendEnum.ACCELERATE:
+ return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
+ elif parallel_backend == ParallelBackendEnum.PTD:
+ return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)
+
+
+def get_lr_scheduler_accelerate(
+ optimizer: torch.optim.Optimizer,
+ scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler],
+ last_epoch: int = -1,
+) -> torch.optim.lr_scheduler.LambdaLR:
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
+ return scheduler
+
+
+def get_lr_scheduler_ptd(
+ optimizer: OptimizerWrapper, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int = -1
+) -> SchedulerWrapper:
+ return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)
+
+
+# ==============================
+# Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py
+# ==============================
+
+
+def get_constant_schedule() -> Callable[[int], float]:
+ r"""
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+ """
+
+ def lr_lambda(current_step: int):
+ return 1.0
+
+ return lr_lambda
+
+
+def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
+ r"""
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1.0, num_warmup_steps))
+ return 1.0
+
+ return lr_lambda
+
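+ # Worked example (illustrative): with num_warmup_steps=100, the multiplier returned by lr_lambda is
+ # 0.0 at step 0, 0.5 at step 50, and 1.0 for every step >= 100.
+ #
+ # warmup = get_constant_schedule_with_warmup(100)
+ # assert warmup(0) == 0.0 and warmup(50) == 0.5 and warmup(100) == 1.0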
+
+def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]:
+ r"""
+ Create a schedule with a piecewise constant learning rate, scaling the learning rate set in the optimizer.
+
+ Args:
+ step_rules (`str`):
+ The rules for the learning rate multiplier, given as comma-separated `value:step` pairs followed by a
+ final default value. For example, `step_rules="1:10,0.1:20,0.01:30,0.005"` multiplies the learning rate
+ by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30, and by 0.005 for all later steps.
+ """
+
+ rules_dict = {}
+ rule_list = step_rules.split(",")
+ for rule_str in rule_list[:-1]:
+ value_str, steps_str = rule_str.split(":")
+ steps = int(steps_str)
+ value = float(value_str)
+ rules_dict[steps] = value
+ last_lr_multiple = float(rule_list[-1])
+
+ def create_rules_function(rules_dict, last_lr_multiple):
+ def rule_func(steps: int) -> float:
+ sorted_steps = sorted(rules_dict.keys())
+ for i, sorted_step in enumerate(sorted_steps):
+ if steps < sorted_step:
+ return rules_dict[sorted_steps[i]]
+ return last_lr_multiple
+
+ return rule_func
+
+ rules_func = create_rules_function(rules_dict, last_lr_multiple)
+ return rules_func
+
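+ # Worked example (illustrative): step_rules="1:10,0.1:20,0.01:30,0.005" builds
+ # rules_dict={10: 1.0, 20: 0.1, 30: 0.01} with a trailing multiplier of 0.005, so steps 0-9 use 1.0,
+ # steps 10-19 use 0.1, steps 20-29 use 0.01, and steps >= 30 use 0.005.
+ #
+ # rule = get_piecewise_constant_schedule("1:10,0.1:20,0.01:30,0.005")
+ # assert rule(5) == 1.0 and rule(15) == 0.1 and rule(25) == 0.01 and rule(40) == 0.005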
+
+def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]:
+ r"""
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+ )
+
+ return lr_lambda
+
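+ # Worked example (illustrative): with num_warmup_steps=10 and num_training_steps=110, the multiplier
+ # rises linearly to 1.0 at step 10 and then decays linearly back to 0.0 at step 110.
+ #
+ # linear_fn = get_linear_schedule_with_warmup(10, 110)
+ # assert linear_fn(0) == 0.0 and linear_fn(10) == 1.0 and linear_fn(60) == 0.5 and linear_fn(110) == 0.0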
+
+def get_cosine_schedule_with_warmup(
+ num_warmup_steps: int,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+) -> Callable[[int], float]:
+ r"""
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of periods of the cosine function in a schedule (the default is to just decrease from the max
+ value to 0 following a half-cosine).
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+ return lr_lambda
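A quick sanity check of the default half-cosine shape (a minimal sketch, assuming `math` and the function above are in scope): the multiplier is exactly 1.0 right after warmup and reaches 0.0 at the final training step.

cosine = get_cosine_schedule_with_warmup(num_warmup_steps=10, num_training_steps=110)
assert cosine(10) == 1.0        # warmup just finished, full learning rate
assert abs(cosine(110)) < 1e-9  # the half-cosine reaches zero at the last step
assert 0.0 < cosine(60) < 1.0   # mid-training sits somewhere in between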
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+ num_warmup_steps: int,
+ num_training_steps: int,
+ num_cycles: int = 1,
+) -> Callable[[int], float]:
+ r"""
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+ linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`int`, *optional*, defaults to 1):
+ The number of hard restarts to use.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ if progress >= 1.0:
+ return 0.0
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+ return lr_lambda
+
+
+def get_polynomial_decay_schedule_with_warmup(
+ num_warmup_steps: int,
+ num_training_steps: int,
+ lr_init: float,
+ lr_end: float = 1e-7,
+ power: float = 1.0,
+) -> Callable[[int], float]:
+ r"""
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+ initial lr set in the optimizer.
+
+ Args:
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+        lr_init (`float`):
+            The initial learning rate. It should match the learning rate set in the optimizer, since `LambdaLR`
+            multiplies that learning rate by the factor returned here.
+        lr_end (`float`, *optional*, defaults to 1e-7):
+            The end LR.
+ power (`float`, *optional*, defaults to 1.0):
+ Power factor.
+
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+ """
+
+ if not (lr_init > lr_end):
+ raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ elif current_step > num_training_steps:
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
+ else:
+ lr_range = lr_init - lr_end
+ decay_steps = num_training_steps - num_warmup_steps
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+ decay = lr_range * pct_remaining**power + lr_end
+ return decay / lr_init # as LambdaLR multiplies by lr_init
+
+ return lr_lambda
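Because `LambdaLR` multiplies the optimizer's base learning rate by the returned factor, the schedule divides by `lr_init` internally. A minimal wiring sketch with `get_lr_scheduler_accelerate` (assuming `torch` is in scope and the optimizer's learning rate matches `lr_init`):

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_lambda = get_polynomial_decay_schedule_with_warmup(
    num_warmup_steps=100, num_training_steps=1000, lr_init=1e-4, lr_end=1e-7, power=1.0
)
scheduler = get_lr_scheduler_accelerate(optimizer, lr_lambda)
# With power=1.0 the effective learning rate warms up for 100 steps, decays linearly
# from 1e-4 towards 1e-7, and then stays at lr_end after step 1000.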
+
+
+def _raise_errors_if_packages_not_available(name: str) -> None:
+ name_split = name.split("-")
+ if len(name_split) < 2:
+ return
+ package_name = name_split[1]
+ if package_name == "bnb":
+ if not is_bitsandbytes_available():
+ raise ImportError(
+ f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer."
+ )
diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/__init__.py b/docs/finetrainers-src-codebase/finetrainers/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..878bd31bbe550bab7dae31e98e98d8a30038a6cc
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/parallel/__init__.py
@@ -0,0 +1,22 @@
+from enum import Enum
+from typing import Union
+
+from .accelerate import AccelerateParallelBackend
+from .ptd import PytorchDTensorParallelBackend
+from .utils import dist_max, dist_mean
+
+
+ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend]
+
+
+class ParallelBackendEnum(str, Enum):
+ ACCELERATE = "accelerate"
+ PTD = "ptd"
+
+
+def get_parallel_backend_cls(backend: ParallelBackendEnum) -> ParallelBackendType:
+ if backend == ParallelBackendEnum.ACCELERATE:
+ return AccelerateParallelBackend
+ if backend == ParallelBackendEnum.PTD:
+ return PytorchDTensorParallelBackend
+ raise ValueError(f"Unknown parallel backend: {backend}")
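A minimal usage sketch; the constructor arguments shown below are illustrative, not the backend's full signature:

backend_cls = get_parallel_backend_cls(ParallelBackendEnum("ptd"))
# e.g. parallel_backend = backend_cls(world_size=8, dp_degree=2, dp_shards=4)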
diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/accelerate.py b/docs/finetrainers-src-codebase/finetrainers/parallel/accelerate.py
new file mode 100644
index 0000000000000000000000000000000000000000..59c1b5e4155528c55a2a4c968a6420c1b621d570
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/parallel/accelerate.py
@@ -0,0 +1,383 @@
+import datetime
+import os
+import pathlib
+import shutil
+import time
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import torch
+from diffusers.utils import is_accelerate_available
+
+from finetrainers.logging import get_logger
+from finetrainers.utils import get_device_info
+
+from .base import BaseCheckpointer, BaseParallelBackend
+
+
+if not is_accelerate_available():
+ raise ImportError(
+ "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend."
+ )
+
+from accelerate import Accelerator
+from accelerate.data_loader import DataLoader
+from accelerate.utils import (
+ DataLoaderConfiguration,
+ DistributedDataParallelKwargs,
+ InitProcessGroupKwargs,
+ ProjectConfiguration,
+ set_seed,
+)
+
+
+logger = get_logger()
+_device_type, _device_module = get_device_info()
+
+
+class AccelerateParallelBackend(BaseParallelBackend):
+ def __init__(
+ self,
+ world_size: int,
+ pp_degree: int = 1,
+ dp_degree: int = 1,
+ dp_shards: int = -1,
+ cp_degree: int = 1,
+ tp_degree: int = 1,
+ backend: str = "nccl",
+ timeout: int = 180,
+ logging_dir: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ gradient_accumulation_steps: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ self._world_size = world_size
+ self._pp_degree = pp_degree
+ self._dp_degree = dp_degree
+ self._dp_shards = dp_shards
+ self._cp_degree = cp_degree
+ self._tp_degree = tp_degree
+ self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
+ self._logging_dir = (
+ self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
+ )
+ self._backend = backend
+ self._timeout = timeout
+ self._gradient_accumulation_steps = gradient_accumulation_steps
+
+ if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1:
+ raise ValueError(
+ "AccelerateParallelBackend does not support anything but Distributed Data Parallelism at the moment."
+ )
+ if dp_degree != world_size:
+ raise ValueError("Data parallel degree must be equal to world size.")
+
+ self._accelerator = None
+ if world_size == 1:
+ # Needs special handling for single GPU training
+ project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
+ dataloader_config = DataLoaderConfiguration(
+ split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
+ )
+ init_process_group_kwargs = InitProcessGroupKwargs(
+ backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
+ )
+ self._accelerator = Accelerator(
+ project_config=project_config,
+ dataloader_config=dataloader_config,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ log_with=None,
+ kwargs_handlers=[init_process_group_kwargs],
+ )
+ if torch.backends.mps.is_available():
+ self._accelerator.native_amp = False
+
+ self._mesh: torch.distributed.DeviceMesh = None
+
+ def enable_determinism(self, seed: int) -> None:
+ set_seed(seed)
+
+ def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+ project_config = None
+ ddp_kwargs = None
+ init_process_group_kwargs = None
+ if self._accelerator is None:
+ project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
+ dataloader_config = DataLoaderConfiguration(
+ split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
+ )
+ init_process_group_kwargs = InitProcessGroupKwargs(
+ backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
+ )
+ self._accelerator, model = apply_ddp(
+ model,
+ project_config,
+ ddp_kwargs,
+ init_process_group_kwargs,
+ dataloader_config,
+ self._gradient_accumulation_steps,
+ accelerator=self._accelerator,
+ )
+ logger.debug("Applied AccelerateParallel::apply_ddp to model.")
+ return model
+
+ def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
+ return self._accelerator.prepare_model(model)
+
+ def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
+ logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
+ return dataset
+
+ def prepare_dataloader(
+ self,
+ dataset: torch.utils.data.IterableDataset,
+ batch_size: int = 1,
+ num_workers: int = 0,
+ pin_memory: bool = False,
+ ) -> DataLoader:
+ dataloader = torch.utils.data.DataLoader(
+ dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
+ )
+ dataloader = self._accelerator.prepare_data_loader(dataloader)
+ logger.debug("AccelerateParallelBackend::prepare_dataloader completed!")
+ return dataloader
+
+ def prepare_optimizer(self, optimizer, lr_scheduler):
+ optimizer = self._accelerator.prepare_optimizer(optimizer)
+ lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler)
+ return optimizer, lr_scheduler
+
+ def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
+ def _get_mesh():
+ if name is None:
+ return self._mesh
+ try:
+ return self._mesh[name]
+ except (KeyError, RuntimeError):
+ return self._mesh
+
+ if self._mesh is not None:
+ return _get_mesh()
+
+ mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)]
+ mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
+ names = [x[0] for x in mesh_list]
+ degrees = [x[1] for x in mesh_list]
+ mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
+
+ dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
+
+ if self.data_replication_enabled:
+ dp_mesh_names.append("dp_replicate")
+ dp_cp_mesh_names.append("dp_replicate")
+ if self.data_sharding_enabled:
+ dp_mesh_names.append("dp_shard")
+ dp_cp_mesh_names.append("dp_shard")
+ dp_shard_cp_mesh_names.append("dp_shard")
+ if self.context_parallel_enabled:
+ dp_cp_mesh_names.append("cp")
+ dp_shard_cp_mesh_names.append("cp")
+
+ if len(dp_mesh_names) > 0:
+ mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
+ if len(dp_cp_mesh_names) > 0:
+ mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
+ if len(dp_shard_cp_mesh_names) > 0:
+ mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
+
+ logger.debug(f"Device mesh: {mesh}")
+ self._mesh = mesh
+ return _get_mesh()
+
+ def get_checkpointer(self, *args, **kwargs):
+ return AccelerateCheckpointer(self._accelerator, *args, **kwargs)
+
+ @property
+ def world_size(self):
+ return self._accelerator.num_processes
+
+ @property
+ def rank(self):
+ return self._accelerator.process_index
+
+ @property
+ def local_rank(self):
+ return self._accelerator.local_process_index
+
+ @property
+ def is_main_process(self):
+ r"""Returns `True` if the current process is the main process on the master node."""
+ return self._accelerator.is_main_process
+
+ @property
+ def is_local_main_process(self):
+ r"""Returns `True` if the current process is the main process on local node."""
+ return self._accelerator.is_local_main_process
+
+ @property
+ def device(self):
+ return self._accelerator.device
+
+ def wait_for_everyone(self):
+ self._accelerator.wait_for_everyone()
+
+ def destroy(self):
+ if self.is_main_process and self.tracker is not None:
+ self.tracker.finish()
+ self._accelerator.end_training()
+
+ @property
+ def pipeline_parallel_enabled(self):
+ return self._pp_degree > 1
+
+ @property
+ def data_parallel_enabled(self):
+ return self._dp_degree > 1 or self._dp_shards > 1
+
+ @property
+ def data_replication_enabled(self):
+ return self._dp_degree > 1
+
+ @property
+ def data_sharding_enabled(self):
+ return self._dp_shards > 1
+
+ @property
+ def context_parallel_enabled(self):
+ return self._cp_degree > 1
+
+ @property
+ def tensor_parallel_enabled(self):
+ return self._tp_degree > 1
+
+
+class AccelerateCheckpointer(BaseCheckpointer):
+ def __init__(
+ self,
+ accelerator: Accelerator,
+ states: Dict[str, Any],
+ checkpointing_steps: int,
+ checkpointing_limit: int,
+ output_dir: str,
+ enable: bool = True,
+ _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+ _prefix: str = "finetrainers_step",
+ *args,
+ **kwargs,
+ ) -> None:
+ self.accelerator = accelerator
+ self.states = states
+
+ self.checkpointing_steps = checkpointing_steps
+ self.checkpointing_limit = checkpointing_limit
+ self.output_dir = pathlib.Path(output_dir)
+ self.enable = enable
+ self._callback_fn = _callback_fn
+ self._prefix = _prefix
+
+ def save_model_hook(models, weights, output_dir: str) -> None:
+ if not self.accelerator.is_main_process:
+ return
+
+ # TODO(aryan): this is a temporary assertion since we only support training transformer at the moment.
+ # Remove it when adding support for training text encoders/vae and more.
+ assert len(models) == 1
+
+            if self._callback_fn is not None:
+                self._callback_fn(weights[0])
+ torch.save(self.states, os.path.join(output_dir, "states.pt"))
+
+ def load_model_hook(models, input_dir) -> None:
+ self.states = torch.load(os.path.join(input_dir, "states.pt"))
+
+ self.accelerator.register_save_state_pre_hook(save_model_hook)
+ self.accelerator.register_load_state_pre_hook(load_model_hook)
+
+ logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")
+
+ def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
+ if not self._should_checkpoint(step, force):
+ return None
+
+ checkpoint_dir = self._get_checkpoint_dir(step)
+ begin_time = time.monotonic()
+ self.accelerator.save_state(checkpoint_dir.as_posix(), safe_serialization=True)
+ end_time = time.monotonic()
+ logger.info(
+ f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
+ )
+ self._purge_stale_checkpoints()
+
+ return checkpoint_dir.as_posix()
+
+ def load(self, step: int = -1) -> bool:
+ if not self.enable:
+ return False
+ if not self.output_dir.exists():
+ return False
+ if step != -1 and not self._get_checkpoint_dir(step).exists():
+ return False
+
+ if step == -1:
+ latest_checkpoint_dir = self._find_latest_checkpoint_dir()
+ if latest_checkpoint_dir is None:
+ return False
+ step = int(latest_checkpoint_dir.name.split("_")[-1])
+
+ checkpoint_dir = self._get_checkpoint_dir(step)
+ logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")
+
+ begin_time = time.monotonic()
+ self.accelerator.load_state(checkpoint_dir.as_posix())
+ end_time = time.monotonic()
+ logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")
+
+ return True
+
+ def _should_checkpoint(self, step: int, force: bool) -> bool:
+ if not self.enable:
+ return False
+ if not force:
+ if step % self.checkpointing_steps != 0:
+ return False
+ return True
+
+ def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
+ return self.output_dir / f"{self._prefix}_{step}"
+
+ def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
+ checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
+ return checkpoints[-1] if len(checkpoints) > 0 else None
+
+ def _purge_stale_checkpoints(self) -> None:
+ if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
+ return
+ checkpoints = sorted(
+ self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
+ )
+ for checkpoint in checkpoints[self.checkpointing_limit :]:
+ logger.info(f"Deleting stale checkpoint: {checkpoint}")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+
+def apply_ddp(
+ model: torch.nn.Module,
+ project_config: Optional[ProjectConfiguration] = None,
+ ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
+ init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
+ dataloader_config: Optional[DataLoaderConfiguration] = None,
+ gradient_accumulation_steps: Optional[int] = None,
+ accelerator: Optional[Accelerator] = None,
+) -> Tuple[Accelerator, torch.nn.Module]:
+ if accelerator is None:
+ accelerator = Accelerator(
+ project_config=project_config,
+ dataloader_config=dataloader_config,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ log_with=None,
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
+ )
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+ accelerator.prepare_model(model)
+ return accelerator, model
diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/base.py b/docs/finetrainers-src-codebase/finetrainers/parallel/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab04aeb71ed82db34f4154e7ba23b51ce2737579
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/parallel/base.py
@@ -0,0 +1,145 @@
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+
+from finetrainers.trackers import DummyTracker, TrackerType, initialize_trackers
+
+
+class BaseParallelBackend:
+ r"""
+ Base class that contains properties and methods that should be implemented by different parallel backends.
+ """
+
+ def __init__(self):
+ self.tracker = None
+
+ def enable_determinism(self, seed: int) -> None:
+ raise NotImplementedError("Method `enable_determinism` must be implemented by subclass.")
+
+ def apply_ddp(self, *args, **kwargs) -> torch.nn.Module:
+ raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.")
+
+ def apply_fsdp2(self, *args, **kwargs) -> torch.nn.Module:
+ raise NotImplementedError("Method `apply_fsdp2` must be implemented by subclass.")
+
+ def apply_context_parallel(self, *args, **kwargs) -> torch.nn.Module:
+ raise NotImplementedError("Method `apply_context_parallel` must be implemented by subclass.")
+
+ def prepare_model(self, *args, **kwargs) -> Any:
+ raise NotImplementedError("Method `prepare_model` must be implemented by subclass.")
+
+ def prepare_dataset(self, *args, **kwargs) -> Any:
+ raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.")
+
+ def prepare_dataloader(self, *args, **kwargs) -> Any:
+ raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.")
+
+ def prepare_optimizer(self, *args, **kwargs) -> Any:
+ raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.")
+
+ def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
+ raise NotImplementedError("Method `get_mesh` must be implemented by subclass.")
+
+ def get_checkpointer(self, *args, **kwargs) -> None:
+ raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.")
+
+ def initialize_trackers(
+ self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
+ ) -> TrackerType:
+ if self.is_main_process:
+ self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir)
+ else:
+ self.tracker = DummyTracker()
+
+ def log(self, metrics: Dict[str, Any], step: int) -> None:
+ if self.is_main_process:
+ self.tracker.log(metrics, step)
+
+ def wait_for_everyone(self):
+ raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.")
+
+ @contextmanager
+ def main_process_first(self):
+ raise NotImplementedError("Method `main_process_first` must be implemented by subclass.")
+
+ def destroy(self):
+ raise NotImplementedError("Method `destroy` must be implemented by subclass.")
+
+ @property
+ def world_size(self):
+ raise NotImplementedError("Method `world_size` must be implemented by subclass.")
+
+ @property
+ def rank(self):
+ raise NotImplementedError("Method `rank` must be implemented by subclass.")
+
+ @property
+ def local_rank(self):
+ raise NotImplementedError("Method `local_rank` must be implemented by subclass.")
+
+ @property
+ def is_main_process(self):
+ raise NotImplementedError("Method `is_main_process` must be implemented by subclass.")
+
+ @property
+ def is_local_main_process(self):
+ raise NotImplementedError("Method `is_local_main_process` must be implemented by subclass.")
+
+ @property
+ def device(self):
+ raise NotImplementedError("Method `device` must be implemented by subclass.")
+
+ @property
+ def pipeline_parallel_enabled(self):
+ raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.")
+
+ @property
+ def data_parallel_enabled(self):
+ raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.")
+
+ @property
+ def data_replication_enabled(self):
+ raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.")
+
+ @property
+ def data_sharding_enabled(self):
+ raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.")
+
+ @property
+ def context_parallel_enabled(self):
+ raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.")
+
+ @property
+ def tensor_parallel_enabled(self):
+ raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.")
+
+
+class BaseCheckpointer:
+ r"""
+    Base class that contains properties and methods that should be implemented by the checkpointers of the different parallel backends.
+ """
+
+ def __init__(
+ self,
+ dataloader: torch.utils.data.DataLoader,
+ model_parts: List[torch.nn.Module],
+ optimizers: Any,
+ schedulers: Any,
+ states: Dict[str, Any],
+ checkpointing_steps: int,
+ checkpointing_limit: int,
+ output_dir: str,
+ enable: bool = True,
+ _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+ _prefix: str = "finetrainers_step",
+ *args,
+ **kwargs,
+ ) -> None:
+ raise NotImplementedError("Method `__init__` must be implemented by subclass.")
+
+ def save(self, step: int, force: bool, *, _device: Optional[torch.device] = None, _is_main_process: bool) -> str:
+ raise NotImplementedError("Method `save` must be implemented by subclass.")
+
+ def load(self, step: int = -1) -> bool:
+ raise NotImplementedError("Method `load` must be implemented by subclass.")
diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/deepspeed.py b/docs/finetrainers-src-codebase/finetrainers/parallel/deepspeed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9f54d66ec1941ffc44d6239b305cc397ce61d4
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/parallel/deepspeed.py
@@ -0,0 +1,7 @@
+from .base import BaseParallelBackend
+
+
+class DeepspeedParallelBackend(BaseParallelBackend):
+ def __init__(self):
+ # TODO(aryan)
+ raise NotImplementedError("DeepspeedParallelBackend is not implemented yet.")
diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/ptd.py b/docs/finetrainers-src-codebase/finetrainers/parallel/ptd.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a95b1a95781913a4b682b2a57159416d8d7443b
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/parallel/ptd.py
@@ -0,0 +1,709 @@
+import datetime
+import functools
+import os
+import pathlib
+import shutil
+import time
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
+import datasets.distributed
+import torch
+import torch.distributed._functional_collectives
+import torch.distributed.checkpoint
+import torch.distributed.checkpoint.stateful
+from diffusers.hooks import HookRegistry, ModelHook
+from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
+from torch.distributed._composable.replicate import replicate
+from torch.distributed.checkpoint.state_dict import (
+ StateDictOptions,
+ get_model_state_dict,
+ set_model_state_dict,
+)
+from torch.distributed.tensor import DTensor, Shard
+
+from finetrainers._metadata import ContextParallelModelPlan, CPInput, CPOutput, TransformerRegistry
+from finetrainers.data import DPDataLoader
+from finetrainers.logging import get_logger
+from finetrainers.utils import enable_determinism, get_device_info, get_submodule_by_name, unwrap_module
+from finetrainers.utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES
+
+from .base import BaseCheckpointer, BaseParallelBackend
+
+
+if TYPE_CHECKING:
+ from finetrainers import optimizer
+
+
+_device_type, _device_module = get_device_info()
+logger = get_logger()
+
+
+class PytorchDTensorParallelBackend(BaseParallelBackend):
+ def __init__(
+ self,
+ world_size: int,
+ pp_degree: int = 1,
+ dp_degree: int = 1,
+ dp_shards: int = -1,
+ cp_degree: int = 1,
+ tp_degree: int = 1,
+ backend: str = "nccl",
+ timeout: int = 180,
+ logging_dir: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ gradient_accumulation_steps: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ self._world_size = world_size
+ self._pp_degree = pp_degree
+ self._dp_degree = dp_degree
+ self._dp_shards = dp_shards
+ self._cp_degree = cp_degree
+ self._tp_degree = tp_degree
+ self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
+ self._logging_dir = (
+ self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
+ )
+ self._backend = backend
+ self._timeout = timeout
+
+ for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]:
+ if degree < 1:
+ raise ValueError(f"Parallel degree must be at least 1, got {degree}.")
+
+ if dp_shards * pp_degree * dp_degree * cp_degree * tp_degree != world_size:
+ raise ValueError(
+                f"World size {world_size} must equal the product of all parallel degrees and data parallel shards."
+ )
+
+ torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout))
+ _device_module.set_device(self.local_rank)
+
+ logger.info(
+ f"Initialized parallel state with:\n"
+ f" - World size: {world_size}\n"
+ f" - Pipeline parallel degree: {pp_degree}\n"
+ f" - Data parallel degree: {dp_degree}\n"
+ f" - Context parallel degree: {cp_degree}\n"
+ f" - Tensor parallel degree: {tp_degree}\n"
+ f" - Data parallel shards: {dp_shards}\n"
+ )
+
+ self._mesh: torch.distributed.DeviceMesh = None
+
+ def enable_determinism(self, seed):
+ world_mesh = self.get_mesh()
+ enable_determinism(seed, world_mesh)
+
+ def apply_ddp(
+ self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
+ ) -> torch.nn.Module:
+ if device_mesh is None:
+ device_mesh = self.get_mesh()
+ apply_ddp(model, device_mesh)
+ logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.")
+ return model
+
+ def apply_fsdp2(
+ self,
+ model: torch.nn.Module,
+ param_dtype: torch.dtype,
+ reduce_dtype: torch.dtype,
+ output_dtype: torch.dtype,
+ pp_enabled: bool = False,
+ cpu_offload: bool = False,
+ device_mesh: Optional[torch.distributed.DeviceMesh] = None,
+ ) -> torch.nn.Module:
+ if device_mesh is None:
+ device_mesh = self.get_mesh()
+ apply_fsdp2(model, device_mesh, param_dtype, reduce_dtype, output_dtype, pp_enabled, cpu_offload)
+ logger.debug("Applied PytorchDTensorParallel::apply_fsdp2 to model.")
+ return model
+
+ def apply_context_parallel(
+ self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
+ ) -> torch.nn.Module:
+ if device_mesh is None:
+ device_mesh = self.get_mesh()
+ apply_context_parallel(model, device_mesh)
+ logger.debug("Applied PytorchDTensorParallel::apply_context_parallel to model.")
+ return model
+
+ def prepare_model(self, model: torch.nn.Module) -> torch.nn.Module:
+ return model
+
+ def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
+ if self._dp_degree == 1:
+ return dataset
+ dp_mesh = self.get_mesh()["dp_replicate"]
+ dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
+ dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
+ logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!")
+ return dataset
+
+ def prepare_dataloader(
+ self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool
+ ) -> DPDataLoader:
+ if self._dp_degree == 1:
+ dp_local_rank = 0
+ else:
+ dp_mesh = self.get_mesh()["dp_replicate"]
+ dp_local_rank = dp_mesh.get_local_rank()
+ dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
+ logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
+ return dataloader
+
+ def prepare_optimizer(self, optimizer, lr_scheduler):
+ logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!")
+ return optimizer, lr_scheduler
+
+ def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
+ def _get_mesh():
+ if name is None:
+ return self._mesh
+ try:
+ return self._mesh[name]
+ except (KeyError, RuntimeError):
+ if self._mesh.ndim == 0:
+ return None
+ return self._mesh
+
+ if self._mesh is not None:
+ return _get_mesh()
+
+ mesh_list = [
+ ("pp", self._pp_degree),
+ ("dp_replicate", self._dp_degree),
+ ("dp_shard", self._dp_shards),
+ ("cp", self._cp_degree),
+ ("tp", self._tp_degree),
+ ]
+ mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
+ names = [x[0] for x in mesh_list]
+ degrees = [x[1] for x in mesh_list]
+ mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
+
+ dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
+
+ if self.data_replication_enabled:
+ dp_mesh_names.append("dp_replicate")
+ dp_cp_mesh_names.append("dp_replicate")
+ if self.data_sharding_enabled:
+ dp_mesh_names.append("dp_shard")
+ dp_cp_mesh_names.append("dp_shard")
+ dp_shard_cp_mesh_names.append("dp_shard")
+ if self.context_parallel_enabled:
+ dp_cp_mesh_names.append("cp")
+ dp_shard_cp_mesh_names.append("cp")
+
+ if len(dp_mesh_names) > 0:
+ mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
+ if len(dp_cp_mesh_names) > 0:
+ mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
+ if len(dp_shard_cp_mesh_names) > 0:
+ mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
+
+ logger.debug(f"Device mesh: {mesh}")
+ self._mesh = mesh
+ return _get_mesh()
+
+ def get_checkpointer(self, *args, **kwargs):
+ return PTDCheckpointer(*args, **kwargs)
+
+ @property
+ def world_size(self):
+ return torch.distributed.get_world_size()
+
+ @property
+ def rank(self):
+ return torch.distributed.get_rank()
+
+ @property
+ def local_rank(self):
+ return int(os.environ.get("LOCAL_RANK", 0))
+
+ @property
+ def is_main_process(self):
+ r"""Returns `True` if the current process is the main process on the master node."""
+ return self.rank == 0
+
+ @property
+ def is_local_main_process(self):
+ r"""Returns `True` if the current process is the main process on local node."""
+ return self.local_rank == 0
+
+ @property
+ def device(self):
+ return torch.device(_device_type, self.local_rank)
+
+ def wait_for_everyone(self):
+ return torch.distributed.barrier()
+
+ # @contextmanager
+ # def main_process_first(self):
+ # if self.is_main_process:
+ # yield
+ # self.wait_for_everyone()
+ # else:
+ # self.wait_for_everyone()
+ # yield
+
+ def destroy(self):
+ if self.is_main_process and self.tracker is not None:
+ self.tracker.finish()
+ return torch.distributed.destroy_process_group()
+
+ @property
+ def pipeline_parallel_enabled(self):
+ return self._pp_degree > 1
+
+ @property
+ def data_parallel_enabled(self):
+ return self._dp_degree > 1 or self._dp_shards > 1
+
+ @property
+ def data_replication_enabled(self):
+ return self._dp_degree > 1
+
+ @property
+ def data_sharding_enabled(self):
+ return self._dp_shards > 1
+
+ @property
+ def context_parallel_enabled(self):
+ return self._cp_degree > 1
+
+ @property
+ def tensor_parallel_enabled(self):
+ return self._tp_degree > 1
+
+
+class ModelWrapper(torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
+ self.model = [model] if isinstance(model, torch.nn.Module) else model
+
+ def state_dict(self) -> Dict[str, Any]:
+ return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ func = functools.partial(
+ set_model_state_dict,
+ model_state_dict=state_dict,
+ options=StateDictOptions(strict=False),
+ )
+ list(map(func, self.model))
+
+
+class PTDCheckpointer(BaseCheckpointer):
+ def __init__(
+ self,
+ dataloader: torch.utils.data.DataLoader,
+ model_parts: List[torch.nn.Module],
+ optimizers: "optimizer.OptimizerWrapper",
+ schedulers: "optimizer.SchedulerWrapper",
+ states: Dict[str, Any],
+ checkpointing_steps: int,
+ checkpointing_limit: int,
+ output_dir: str,
+ enable: bool = True,
+ _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+ _prefix: str = "finetrainers_step",
+ ) -> None:
+ self.states = states
+ self.states.update(
+ {
+ "model": ModelWrapper(model_parts),
+ "optimizer": optimizers,
+ "dataloader": dataloader,
+ }
+ )
+ self.states.update(schedulers.get_lr_scheduler_state())
+
+ self.checkpointing_steps = checkpointing_steps
+ self.checkpointing_limit = checkpointing_limit
+ self.output_dir = pathlib.Path(output_dir)
+ self.enable = enable
+ self._callback_fn = _callback_fn
+ self._prefix = _prefix
+
+ logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")
+
+ def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
+ if not self._should_checkpoint(step, force):
+ return None
+
+ checkpoint_dir = self._get_checkpoint_dir(step)
+ begin_time = time.monotonic()
+ torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix())
+ end_time = time.monotonic()
+ logger.info(
+ f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
+ )
+ self._purge_stale_checkpoints()
+
+ state_dicts = [
+ gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process)
+ for model in self.states["model"].model
+ ]
+ if self._callback_fn is not None:
+ list(map(self._callback_fn, state_dicts))
+
+ return checkpoint_dir.as_posix()
+
+ def load(self, step: int = -1) -> bool:
+ if not self.enable:
+ return False
+ if not self.output_dir.exists():
+ return False
+ if step != -1 and not self._get_checkpoint_dir(step).exists():
+ return False
+
+ if step == -1:
+ latest_checkpoint_dir = self._find_latest_checkpoint_dir()
+ if latest_checkpoint_dir is None:
+ return False
+ step = int(latest_checkpoint_dir.name.split("_")[-1])
+
+ checkpoint_dir = self._get_checkpoint_dir(step)
+ logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")
+
+ # For step 0, optimizers/schedulers are not available as they are created during training after first step
+ states = {"model": self.states["model"]} if step == 0 else self.states
+
+ # See bug: https://github.com/pytorch/pytorch/pull/138575
+ original_stateful_states = {
+ k: v for k, v in states.items() if isinstance(v, torch.distributed.checkpoint.stateful.Stateful)
+ }
+ begin_time = time.monotonic()
+ torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix())
+ end_time = time.monotonic()
+ logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")
+
+ # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load()
+ states.update(original_stateful_states)
+
+ return True
+
+ def _should_checkpoint(self, step: int, force: bool) -> bool:
+ if not self.enable:
+ return False
+ if not force:
+ if step % self.checkpointing_steps != 0:
+ return False
+ return True
+
+ def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
+ return self.output_dir / f"{self._prefix}_{step}"
+
+ def _find_latest_checkpoint_dir(self) -> Optional[pathlib.Path]:
+ checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
+ return checkpoints[-1] if len(checkpoints) > 0 else None
+
+ def _purge_stale_checkpoints(self) -> None:
+ if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
+ return
+ checkpoints = sorted(
+ self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
+ )
+ for checkpoint in checkpoints[self.checkpointing_limit :]:
+ logger.info(f"Deleting stale checkpoint: {checkpoint}")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+
+def gather_state_dict_on_cpu_rank0(
+ model, device: Optional[torch.device] = None, *, is_main_process: bool
+) -> Dict[str, Any]:
+ cpu_state_dict = {}
+ sharded_sd = model.state_dict()
+ for param_name, param in sharded_sd.items():
+ if param.is_cpu:
+ # Move back to device if offloaded to CPU
+ param = param.to(device)
+ if hasattr(param, "_local_tensor"):
+ # Gather DTensor
+ param = param.full_tensor()
+ if is_main_process:
+ cpu_state_dict[param_name] = param.cpu()
+ torch.distributed.barrier()
+ return cpu_state_dict
+
+
+# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict
+# def dcp_to_torch_save(
+# dcp_checkpoint_dir: Union[str, os.PathLike],
+# torch_save_path: Union[str, os.PathLike],
+# callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+# ):
+# """
+# Given a directory containing a DCP checkpoint, this function will convert it into a
+# Torch save file.
+
+# Args:
+# dcp_checkpoint_dir: Directory containing the DCP checkpoint.
+# torch_save_path: Filename to store the converted Torch save file.
+# callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict.
+
+# .. warning::
+# To avoid OOM, it's recommended to only run this function on a single rank.
+# """
+# state_dict = {}
+# _load_state_dict(
+# state_dict,
+# storage_reader=FileSystemReader(dcp_checkpoint_dir),
+# planner=_EmptyStateDictLoadPlanner(),
+# no_dist=True,
+# )
+# if callback_fn is not None:
+# state_dict = callback_fn(state_dict)
+# torch.save(state_dict, torch_save_path)
+
+
+def apply_ddp(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+
+
+def apply_fsdp2(
+ model: torch.nn.Module,
+ dp_mesh: torch.distributed.device_mesh.DeviceMesh,
+ param_dtype: torch.dtype,
+ reduce_dtype: torch.dtype,
+ output_dtype: torch.dtype,
+ pp_enabled: bool = False,
+ cpu_offload: bool = False,
+) -> None:
+ """Apply FSDP2 on a model."""
+ mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+
+ if cpu_offload:
+ fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)
+
+ def apply_fully_shard(blocks):
+ for layer_index, block in enumerate(blocks):
+ if pp_enabled:
+ # For PP, do not reshard after forward to avoid per-microbatch
+ # all-gathers, which can be expensive and non-overlapped
+ reshard_after_forward = False
+ else:
+ # As an optimization, do not reshard after forward for the last
+ # transformer block since FSDP would prefetch it immediately
+ reshard_after_forward = layer_index < len(blocks) - 1
+ fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)
+
+ for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
+ blocks = getattr(model, transformer_block_name, None)
+ if blocks is not None:
+ apply_fully_shard(blocks)
+
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
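A usage sketch through the backend wrapper (`parallel_backend` and `transformer` are hypothetical stand-ins; the dtypes are illustrative, not required values):

transformer = parallel_backend.apply_fsdp2(
    transformer,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
)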
+
+
+def apply_context_parallel(
+ model: torch.nn.Module,
+ mesh: torch.distributed.device_mesh.DeviceMesh,
+ plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
+) -> None:
+ """Apply context parallel on a model."""
+ logger.debug(f"Applying context parallel with CP mesh: {mesh}")
+ model_cls = unwrap_module(model).__class__
+
+ if plan is None:
+ plan = TransformerRegistry.get(model_cls).cp_plan
+
+ for module_id, cp_model_plan in plan.items():
+ module = get_submodule_by_name(model, module_id)
+ if not isinstance(module, list):
+ module = [module]
+ logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(module)} modules")
+ for m in module:
+ registry = HookRegistry.check_if_exists_or_initialize(m)
+ if isinstance(cp_model_plan, list):
+ # Metadata can only be a list when it is a list of CPOutput
+ assert all(isinstance(x, CPOutput) for x in cp_model_plan)
+ hook = ContextParallelGatherHook(cp_model_plan, mesh)
+ hook_name = f"cp_output---{module_id}"
+ else:
+ hook = ContextParallelSplitHook(cp_model_plan, mesh)
+ hook_name = f"cp_input---{module_id}"
+ registry.register_hook(hook, hook_name)
+
+
+class ContextParallelSplitHook(ModelHook):
+ def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
+ super().__init__()
+ self.metadata = metadata
+ self.mesh = mesh
+
+ def pre_forward(self, module, *args, **kwargs):
+ args_list = list(args)
+
+ for param_identifier, cpm in self.metadata.items():
+ name = param_identifier.name
+ index = param_identifier.index
+
+ if isinstance(cpm, CPInput) and cpm.split_output:
+ continue
+
+ # Maybe the parameter was passed as a keyword argument
+ is_kwarg = True
+ input_val = kwargs.get(name, None)
+
+ # If not, maybe it was passed as a positional argument
+ if input_val is None and index is not None:
+ if index < len(args_list): # Ensure index is within bounds
+ input_val = args_list[index]
+ is_kwarg = False
+ else:
+ logger.warning(f"Index {index} out of bounds for args of length {len(args_list)}.")
+ continue # Skip if index is invalid
+
+            # Either the input is genuinely None, or it was passed positionally but the index was not
+            # specified when the metadata was registered.
+ if input_val is None:
+ continue
+
+ # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
+ # the output instead of input for a particular layer by setting split_output=True
+ if torch.is_tensor(input_val):
+ input_val = self._prepare_cp_input(input_val, cpm)
+
+ elif isinstance(input_val, (list, tuple)):
+ if len(input_val) != len(cpm):
+ raise ValueError(
+                        f"Expected {len(cpm)} tensors to match the context parallel plan for this input, but got {len(input_val)}."
+ )
+ sharded_input_val = []
+ for i, x in enumerate(input_val):
+ if torch.is_tensor(x) and not cpm[i].split_output:
+ x = self._prepare_cp_input(x, cpm[i])
+ sharded_input_val.append(x)
+ input_val = sharded_input_val
+
+ else:
+ raise ValueError(f"Unsupported input type: {type(input_val)}")
+
+ if is_kwarg:
+ kwargs[name] = input_val
+ elif index is not None and index < len(args_list):
+ args_list[index] = input_val
+
+ return tuple(args_list), kwargs
+
+ def post_forward(self, module, output):
+ is_tensor = torch.is_tensor(output)
+ is_tensor_list = isinstance(output, (list, tuple)) and all(torch.is_tensor(x) for x in output)
+ if not is_tensor and not is_tensor_list:
+ raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
+ output = [output] if is_tensor else list(output)
+ for param_identifier, cpm in self.metadata.items():
+ if not isinstance(cpm, CPInput) or not cpm.split_output:
+ continue
+ index = param_identifier.index
+ if index >= len(output):
+ raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
+ current_output = output[index]
+ current_output = self._prepare_cp_input(current_output, cpm)
+ output[index] = current_output
+ return output[0] if is_tensor else tuple(output)
+
+ def _prepare_cp_input(self, x: torch.Tensor, cp_input: CPInput) -> torch.Tensor:
+ if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
+ raise ValueError(
+ f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
+ )
+ return _EquipartitionSharder.shard(x, cp_input.split_dim, self.mesh)
+
+
+class ContextParallelGatherHook(ModelHook):
+ def __init__(self, metadata: ContextParallelModelPlan, mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
+ super().__init__()
+ self.metadata = metadata
+ self.mesh = mesh
+
+ def post_forward(self, module, output):
+ is_tensor = torch.is_tensor(output)
+ if is_tensor:
+ output = [output]
+ output = list(output)
+ assert len(output) == len(self.metadata), f"Expected {len(self.metadata)} outputs, but got {len(output)}."
+ for i, cpm in enumerate(self.metadata):
+ if cpm is None:
+ continue
+ output[i] = _EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.mesh)
+ return output[0] if is_tensor else tuple(output)
+
+
+class _ContextParallelSharder:
+ @classmethod
+ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ raise NotImplementedError("_ContextParallelSharder::shard should be implemented in subclasses")
+
+ @classmethod
+ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ raise NotImplementedError("_ContextParallelSharder::unshard should be implemented in subclasses")
+
+
+class _EquipartitionSharder(_ContextParallelSharder):
+ """
+ Shards the input tensor along the specified dimension into cp_mesh's world size chunks.
+ Essentially, rank_i gets the i-th chunk.
+
+    This sharding strategy should only be used when performing full attention. Otherwise, it incurs a performance
+    penalty from uneven load across ranks. If using causal attention, please use _CausalSharder instead.
+ """
+
+ @classmethod
+ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ assert tensor.size()[dim] % mesh.size() == 0
+ return tensor.chunk(mesh.size(), dim=dim)[mesh.get_local_rank()]
+
+ @classmethod
+ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ tensor = tensor.contiguous()
+ # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
+ result = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
+ return result
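A single-process sketch of the layout `shard` produces, assuming a hypothetical context-parallel world size of 4: rank `i` simply keeps the `i`-th contiguous chunk of the sharded dimension.

import torch

sequence = torch.arange(8).unsqueeze(0)  # (batch=1, seq=8)
chunks = sequence.chunk(4, dim=1)        # what `shard` indexes with the local rank
assert chunks[2].tolist() == [[4, 5]]    # rank 2 would hold tokens 4 and 5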
+
+
+# TODO(aryan): this class is untested
+class _CausalSharder(_ContextParallelSharder):
+ """
+ Shards the input tensor along the specified dimension into 2x cp_mesh's world size chunks.
+ Essentially, rank_i gets the i-th chunk and (2 * cp_world_size - 1 - i)-th chunk.
+
+ This sharding strategy improves the performance for causal attention, as it allows
+ equal distribution of computation across all ranks.
+
+ Causal attention mask:
+ ```
+ 1 0 0 0 <--- Group 0
+ 1 1 0 0 <--- Group 1
+ 1 1 1 0 <--- Group 1
+ 1 1 1 1 <--- Group 0
+ ```
+ """
+
+ @classmethod
+ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ world_size = mesh.size()
+ rank = mesh.get_local_rank()
+ assert tensor.size()[dim] % (2 * world_size) == 0
+ chunks = tensor.chunk(2 * world_size, dim=dim)
+ i, j = rank, 2 * world_size - 1 - rank
+ return torch.cat((chunks[i], chunks[j]), dim=dim)
+
+ @classmethod
+ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ tensor = tensor.contiguous()
+ world_size = mesh.size()
+ # TODO(aryan): pass a shape here so that we can allow uneven sharding across seq dim
+ all_tensors = DTensor.from_local(tensor, mesh, placements=[Shard(dim)]).full_tensor()
+ sliced_tensors = [st for t in all_tensors for st in t.chunk(2, dim=dim)]
+ ordered_tensors = list(sliced_tensors)
+ for i, t in enumerate(sliced_tensors):
+ if i % 2 == 0:
+ ordered_tensors[i // 2] = t
+ else:
+ ordered_tensors[world_size * 2 - (i // 2) - 1] = t
+ return torch.cat(ordered_tensors, dim=dim)
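A single-process sketch of the pairing described above, assuming a hypothetical context-parallel world size of 2: the sequence is split into `2 * world_size` chunks and rank `i` keeps chunks `i` and `2 * world_size - 1 - i`, which balances causal-attention work across ranks.

import torch

sequence = torch.arange(8).unsqueeze(0)         # (batch=1, seq=8)
parts = sequence.chunk(4, dim=1)                # 2 * world_size chunks
rank0 = torch.cat((parts[0], parts[3]), dim=1)  # rank 0: first and last chunk
assert rank0.tolist() == [[0, 1, 6, 7]]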
diff --git a/docs/finetrainers-src-codebase/finetrainers/parallel/utils.py b/docs/finetrainers-src-codebase/finetrainers/parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a13ef10bd679d4443bea447eaba90a883b763c7e
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/parallel/utils.py
@@ -0,0 +1,19 @@
+import torch
+import torch.distributed._functional_collectives as funcol
+import torch.distributed.tensor
+
+
+def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
+ if isinstance(x, torch.distributed.tensor.DTensor):
+ # functional collectives do not support DTensor inputs
+ x = x.full_tensor()
+ assert x.numel() == 1 # required by `.item()`
+ return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
+
+
+def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
+ return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh)
+
+
+def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
+ return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh)
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c81ba8da8f19b818ea21979a4ec237f9ee56aeb9
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/__init__.py
@@ -0,0 +1,58 @@
+from typing import TYPE_CHECKING
+
+import torch
+
+from .dependencies.diffusers.peft import load_lora_weights
+
+
+if TYPE_CHECKING:
+ from finetrainers.args import BaseArgsType
+ from finetrainers.parallel import ParallelBackendType
+
+
+def perform_patches_for_training(args: "BaseArgsType", parallel_backend: "ParallelBackendType") -> None:
+ # To avoid circular imports
+ from finetrainers.config import ModelType, TrainingType
+
+ from .dependencies.diffusers import patch
+
+ # Modeling patches
+ patch_scaled_dot_product_attention()
+
+ patch.patch_diffusers_rms_norm_forward()
+
+ # LTX Video patches
+ if args.model_name == ModelType.LTX_VIDEO:
+ from .models.ltx_video import patch
+
+ patch.patch_transformer_forward()
+ if parallel_backend.tensor_parallel_enabled:
+ patch.patch_apply_rotary_emb_for_tp_compatibility()
+
+ # Wan patches
+ if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules:
+ from .models.wan import patch
+
+ patch.patch_time_text_image_embedding_forward()
+
+ # LoRA patches
+ if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
+ from .dependencies.peft import patch
+
+ patch.patch_peft_move_adapter_to_device_of_base_layer()
+
+
+def perform_patches_for_inference(args: "BaseArgsType", parallel_backend: "ParallelBackendType") -> None:
+ # To avoid circular imports
+ from .dependencies.diffusers import patch
+
+ # Modeling patches
+ patch_scaled_dot_product_attention()
+
+ patch.patch_diffusers_rms_norm_forward()
+
+
+def patch_scaled_dot_product_attention():
+ from finetrainers.models.attention_dispatch import attention_dispatch
+
+ torch.nn.functional.scaled_dot_product_attention = attention_dispatch
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/control.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/control.py
new file mode 100644
index 0000000000000000000000000000000000000000..baa45910659f6a79ad8d133cf76671482284b44a
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/control.py
@@ -0,0 +1,36 @@
+from contextlib import contextmanager
+from typing import List, Union
+
+import torch
+from diffusers.hooks import HookRegistry, ModelHook
+
+
+_CONTROL_CHANNEL_CONCATENATE_HOOK = "FINETRAINERS_CONTROL_CHANNEL_CONCATENATE_HOOK"
+
+
+class ControlChannelConcatenateHook(ModelHook):
+ def __init__(self, input_names: List[str], inputs: List[torch.Tensor], dims: List[int]):
+ self.input_names = input_names
+ self.inputs = inputs
+ self.dims = dims
+
+    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
+        # `args` arrives as a tuple; convert to a list so positional inputs can be replaced in place.
+        args = list(args)
+        for input_name, input_tensor, dim in zip(self.input_names, self.inputs, self.dims):
+            original_tensor = args[input_name] if isinstance(input_name, int) else kwargs[input_name]
+            control_tensor = torch.cat([original_tensor, input_tensor], dim=dim)
+            if isinstance(input_name, int):
+                args[input_name] = control_tensor
+            else:
+                kwargs[input_name] = control_tensor
+        return tuple(args), kwargs
+
+
+@contextmanager
+def control_channel_concat(
+ module: torch.nn.Module, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int]
+):
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ hook = ControlChannelConcatenateHook(input_names, inputs, dims)
+ registry.register_hook(hook, _CONTROL_CHANNEL_CONCATENATE_HOOK)
+ yield
+ registry.remove_hook(_CONTROL_CHANNEL_CONCATENATE_HOOK, recurse=False)
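A usage sketch with hypothetical names (`transformer`, `latents`, `control_latents`, `timestep`): the control latents are concatenated onto the `hidden_states` keyword argument along the channel dimension only for forward passes executed inside the `with` block.

with control_channel_concat(transformer, ["hidden_states"], [control_latents], dims=[1]):
    noise_pred = transformer(hidden_states=latents, timestep=timestep, return_dict=False)[0]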
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/patch.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a0c7952574b034039a0082caec50d4253a343ab
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/patch.py
@@ -0,0 +1,6 @@
+def patch_diffusers_rms_norm_forward() -> None:
+ import diffusers.models.normalization
+
+ from .rms_norm import _patched_rms_norm_forward
+
+ diffusers.models.normalization.RMSNorm.forward = _patched_rms_norm_forward
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/peft.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/peft.py
new file mode 100644
index 0000000000000000000000000000000000000000..f625323b548e159717598f5b5990d6626f5fc3b0
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/peft.py
@@ -0,0 +1,61 @@
+import json
+from pathlib import Path
+from typing import Optional
+
+import safetensors.torch
+from diffusers import DiffusionPipeline
+from diffusers.loaders.lora_pipeline import _LOW_CPU_MEM_USAGE_DEFAULT_LORA
+from huggingface_hub import repo_exists, snapshot_download
+from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+from finetrainers.logging import get_logger
+from finetrainers.utils import find_files
+
+
+logger = get_logger()
+
+
+def load_lora_weights(
+ pipeline: DiffusionPipeline, pretrained_model_name_or_path: str, adapter_name: Optional[str] = None, **kwargs
+) -> None:
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+
+ is_local_file_path = Path(pretrained_model_name_or_path).is_dir()
+ if not is_local_file_path:
+ does_repo_exist = repo_exists(pretrained_model_name_or_path, repo_type="model")
+ if not does_repo_exist:
+ raise ValueError(f"Model repo {pretrained_model_name_or_path} does not exist on the Hub or locally.")
+ else:
+ pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
+
+ prefix = "transformer"
+ state_dict = pipeline.lora_state_dict(pretrained_model_name_or_path)
+ state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+
+ file_list = find_files(pretrained_model_name_or_path, "*.safetensors", depth=1)
+ if len(file_list) == 0:
+ raise ValueError(f"No .safetensors files found in {pretrained_model_name_or_path}.")
+ if len(file_list) > 1:
+ logger.warning(
+ f"Multiple .safetensors files found in {pretrained_model_name_or_path}. Using the first one: {file_list[0]}."
+ )
+ with safetensors.torch.safe_open(file_list[0], framework="pt") as f:
+ metadata = f.metadata()
+ metadata = json.loads(metadata["lora_config"])
+
+ transformer = pipeline.transformer
+ if adapter_name is None:
+ adapter_name = "default"
+
+ lora_config = LoraConfig(**metadata)
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
+ result = set_peft_model_state_dict(
+ transformer,
+ state_dict,
+ adapter_name=adapter_name,
+ ignore_mismatched_sizes=False,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+ logger.debug(
+ f"Loaded LoRA weights from {pretrained_model_name_or_path} into {pipeline.__class__.__name__}. Result: {result}"
+ )
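+
+
+# Editorial usage sketch (not part of the upstream finetrainers module): loading LoRA
+# weights that were saved with a `lora_config` entry in their safetensors metadata into a
+# pipeline's transformer. The base model id and checkpoint path are hypothetical placeholders.
+if __name__ == "__main__":
+    import torch
+
+    pipeline = DiffusionPipeline.from_pretrained("path/or/id-of-base-model", torch_dtype=torch.bfloat16)
+    load_lora_weights(pipeline, "path/to/finetrainers-lora-checkpoint", adapter_name="finetrainers-lora")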
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/rms_norm.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/rms_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6e9a4ef14590665a44104f9cdf1651b26fe81b0
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/diffusers/rms_norm.py
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+from diffusers.utils import is_torch_npu_available, is_torch_version
+
+
+def _patched_rms_norm_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if is_torch_npu_available():
+ import torch_npu
+
+ if self.weight is not None:
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
+ if self.bias is not None:
+ hidden_states = hidden_states + self.bias
+ elif is_torch_version(">=", "2.4"):
+        # ===== Patched branch: torch >= 2.4 uses the fused nn.functional.rms_norm =====
+ input_dtype = hidden_states.dtype
+ if self.weight is not None:
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ hidden_states = nn.functional.rms_norm(
+ hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
+ )
+ if self.bias is not None:
+ hidden_states = hidden_states + self.bias
+ hidden_states = hidden_states.to(input_dtype)
+        # ==============================================================================
+ else:
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+ if self.weight is not None:
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ hidden_states = hidden_states * self.weight
+ if self.bias is not None:
+ hidden_states = hidden_states + self.bias
+ else:
+ hidden_states = hidden_states.to(input_dtype)
+
+ return hidden_states
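+
+
+# Editorial sketch (not part of the upstream finetrainers module): a quick numerical check
+# that the float32 fallback branch above matches torch's fused `rms_norm` for a simple
+# weighted case. Assumes torch >= 2.4 so that `nn.functional.rms_norm` is available.
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    x = torch.randn(2, 8)
+    weight = torch.randn(8)
+    eps = 1e-6
+
+    fused = nn.functional.rms_norm(x, normalized_shape=(8,), weight=weight, eps=eps)
+
+    variance = x.pow(2).mean(-1, keepdim=True)
+    manual = x * torch.rsqrt(variance + eps) * weight
+
+    print(torch.allclose(fused, manual, atol=1e-5))  # expected: True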
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/patch.py b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4de4b1a965fa6c33ebc9acad81fb1dddc1ba8de2
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/dependencies/peft/patch.py
@@ -0,0 +1,25 @@
+import functools
+
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+from finetrainers.patches.utils import DisableTensorToDtype
+
+
+def patch_peft_move_adapter_to_device_of_base_layer() -> None:
+ _perform_patch_move_adapter_to_device_of_base_layer()
+
+
+def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
+ BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
+ BaseTunerLayer._move_adapter_to_device_of_base_layer
+ )
+
+
+def _patched_move_adapter_to_device_of_base_layer(func):
+ # TODO(aryan): This is really unsafe probably and may break things. It works for now, but revisit and refactor.
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ with DisableTensorToDtype():
+ return func(self, *args, **kwargs)
+
+ return wrapper
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/patch.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8caa803f0716280ff066d6e7865746344fb8e9
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/models/ltx_video/patch.py
@@ -0,0 +1,127 @@
+from typing import Any, Dict, Optional, Tuple, Union
+
+import diffusers
+import torch
+from diffusers import LTXVideoTransformer3DModel
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.utils.import_utils import is_torch_version
+
+
+def patch_transformer_forward() -> None:
+ _perform_ltx_transformer_forward_patch()
+
+
+def patch_apply_rotary_emb_for_tp_compatibility() -> None:
+ _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch()
+
+
+def _perform_ltx_transformer_forward_patch() -> None:
+ LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward
+
+
+def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
+ def apply_rotary_emb(x, freqs):
+ cos, sin = freqs
+ # ======== THIS IS CHANGED FROM THE ORIGINAL IMPLEMENTATION ========
+ # The change is made due to unsupported DTensor operation aten.ops.unbind
+ # FIXME: Once aten.ops.unbind support lands, this will no longer be required
+ # x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
+ x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1) # [B, S, H, D // 2]
+ # ==================================================================
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
+
+ diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
+
+
+def _patched_LTXVideoTransformer3D_forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_attention_mask: torch.Tensor,
+ num_frames: int,
+ height: int,
+ width: int,
+ rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
+ return_dict: bool = True,
+ *args,
+ **kwargs,
+) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ batch_size = hidden_states.size(0)
+
+ # ===== This is modified compared to Diffusers =====
+ # This is done because the Diffusers pipeline will pass in a 1D tensor for timestep
+ if timestep.ndim == 1:
+ timestep = timestep.view(-1, 1, 1).expand(-1, *hidden_states.shape[1:-1], -1)
+ # ==================================================
+
+ temb, embedded_timestep = self.time_embed(
+ timestep.flatten(),
+ batch_size=batch_size,
+ hidden_dtype=hidden_states.dtype,
+ )
+
+    # ===== This is modified compared to Diffusers =====
+    # The upstream Diffusers implementation reshapes as:
+    #   temb = temb.view(batch_size, -1, temb.size(-1))
+    #   embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
+    # Here the view follows the token dimensions of `hidden_states` so that per-token timestep
+    # embeddings can be used.
+    temb = temb.view(batch_size, *hidden_states.shape[1:-1], temb.size(-1))
+    embedded_timestep = embedded_timestep.view(batch_size, *hidden_states.shape[1:-1], embedded_timestep.size(-1))
+    # ==================================================
+
+ hidden_states = self.proj_in(hidden_states)
+
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
+
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+ shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = hidden_states * (1 + scale) + shift
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/__init__.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/patch.py b/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5c44ae42637fb9c6fc0a9803930f1728a92b693
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/models/wan/patch.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+import diffusers
+import torch
+
+
+def patch_time_text_image_embedding_forward() -> None:
+ _patch_time_text_image_embedding_forward()
+
+
+def _patch_time_text_image_embedding_forward() -> None:
+ diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
+ _patched_WanTimeTextImageEmbedding_forward
+ )
+
+
+def _patched_WanTimeTextImageEmbedding_forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+):
+ # Some code has been removed compared to original implementation in Diffusers
+    # Also, `timestep` is cast to the same dtype as `encoder_hidden_states` (via `type_as`)
+ timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
diff --git a/docs/finetrainers-src-codebase/finetrainers/patches/utils.py b/docs/finetrainers-src-codebase/finetrainers/patches/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d7f4726cc8183a461310570762ee95b5c4e6187
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/patches/utils.py
@@ -0,0 +1,18 @@
+import torch
+
+
+class DisableTensorToDtype:
+ def __enter__(self):
+ self.original_to = torch.Tensor.to
+
+ def modified_to(tensor, *args, **kwargs):
+ # remove dtype from args if present
+ args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
+ if "dtype" in kwargs:
+ kwargs.pop("dtype")
+ return self.original_to(tensor, *args, **kwargs)
+
+ torch.Tensor.to = modified_to
+
+ def __exit__(self, *args, **kwargs):
+ torch.Tensor.to = self.original_to
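+
+
+# Editorial usage sketch (not part of the upstream finetrainers module): inside the context
+# manager the dtype argument of `Tensor.to` is stripped, so casts become no-ops while other
+# arguments (such as a device) are still forwarded.
+if __name__ == "__main__":
+    x = torch.ones(2, dtype=torch.float32)
+    with DisableTensorToDtype():
+        y = x.to(dtype=torch.bfloat16)  # dtype is dropped -> y stays float32
+    z = x.to(dtype=torch.bfloat16)      # outside the context the cast works again
+    print(y.dtype, z.dtype)             # torch.float32 torch.bfloat16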
diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/__init__.py b/docs/finetrainers-src-codebase/finetrainers/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82a99170fa50b01e2767f21591db37e8e3046883
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/processors/__init__.py
@@ -0,0 +1,23 @@
+from typing import Any, Dict, List, Optional
+
+from .base import ProcessorMixin
+from .canny import CannyProcessor
+from .clip import CLIPPooledProcessor
+from .glm import CogView4GLMProcessor
+from .llama import LlamaProcessor
+from .t5 import T5Processor
+from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor
+
+
+class CopyProcessor(ProcessorMixin):
+ r"""Processor that copies the input data unconditionally to the output."""
+
+ def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None):
+ super().__init__()
+
+ self.output_names = output_names
+ self.input_names = input_names
+ assert len(output_names) == 1
+
+ def forward(self, input: Any) -> Any:
+ return {self.output_names[0]: input}
diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/base.py b/docs/finetrainers-src-codebase/finetrainers/processors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8989fd70f359268620a16d1cca885983eed02d
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/processors/base.py
@@ -0,0 +1,24 @@
+import inspect
+from typing import Any, Dict, List
+
+
+class ProcessorMixin:
+ def __init__(self) -> None:
+ self._forward_parameter_names = inspect.signature(self.forward).parameters.keys()
+ self.output_names: List[str] = None
+ self.input_names: Dict[str, Any] = None
+
+ def __call__(self, *args, **kwargs) -> Any:
+ shallow_copy_kwargs = dict(kwargs.items())
+ if self.input_names is not None:
+ for k, v in self.input_names.items():
+ if k in shallow_copy_kwargs:
+ shallow_copy_kwargs[v] = shallow_copy_kwargs.pop(k)
+ acceptable_kwargs = {k: v for k, v in shallow_copy_kwargs.items() if k in self._forward_parameter_names}
+ output = self.forward(*args, **acceptable_kwargs)
+ if "__drop__" in output:
+ output.pop("__drop__")
+ return output
+
+ def forward(self, *args, **kwargs) -> Dict[str, Any]:
+ raise NotImplementedError("ProcessorMixin::forward method should be implemented by the subclass.")
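+
+
+# Editorial usage sketch (not part of the upstream finetrainers module): a minimal subclass
+# showing how `input_names` remaps incoming keyword names and how keyword arguments that the
+# `forward` signature does not accept are silently dropped. The processor and key names are
+# hypothetical.
+if __name__ == "__main__":
+    class _UppercaseCaptionProcessor(ProcessorMixin):
+        def __init__(self):
+            super().__init__()
+            self.output_names = ["caption_upper"]
+            self.input_names = {"prompt": "caption"}  # rename incoming 'prompt' to 'caption'
+
+        def forward(self, caption: str) -> Dict[str, Any]:
+            return {self.output_names[0]: caption.upper()}
+
+    processor = _UppercaseCaptionProcessor()
+    print(processor(prompt="a cat playing piano", unused_key=123))
+    # {'caption_upper': 'A CAT PLAYING PIANO'}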
diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/canny.py b/docs/finetrainers-src-codebase/finetrainers/processors/canny.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4bf95e8c753e7fb539e8b0fde15788445fafd8d
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/processors/canny.py
@@ -0,0 +1,78 @@
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from ..utils.import_utils import is_kornia_available
+from .base import ProcessorMixin
+
+
+if is_kornia_available():
+ import kornia
+
+
+class CannyProcessor(ProcessorMixin):
+ r"""
+ Processor for obtaining the Canny edge detection of an image.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor should return. The first output is the Canny edge detection of
+ the input image.
+ """
+
+ def __init__(
+ self,
+ output_names: List[str] = None,
+ input_names: Optional[Dict[str, Any]] = None,
+ device: Optional[torch.device] = None,
+ ):
+ super().__init__()
+
+ self.output_names = output_names
+ self.input_names = input_names
+ self.device = device
+ assert len(output_names) == 1
+
+    def forward(self, input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]]) -> Dict[str, torch.Tensor]:
+ r"""
+ Obtain the Canny edge detection of the input image.
+
+ Args:
+ input (`torch.Tensor`, `PIL.Image.Image`, or `List[PIL.Image.Image]`):
+ The input tensor, image or list of images for which the Canny edge detection should be obtained.
+ If a tensor, must be a 3D (CHW) or 4D (BCHW) or 5D (BTCHW) tensor. The input tensor should have
+ values in the range [0, 1].
+
+ Returns:
+ torch.Tensor:
+                The Canny edge map of the input. For tensor inputs, the output has the same number of dimensions as
+                the input and three channels. A single `PIL.Image.Image` yields a 3D tensor and a list of images
+                yields a 4D tensor. The output tensor has values in the range [0, 1].
+ """
+ if isinstance(input, PIL.Image.Image):
+ input = kornia.utils.image.image_to_tensor(np.array(input)).unsqueeze(0) / 255.0
+ input = input.to(self.device)
+ output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1).squeeze(0)
+ elif isinstance(input, list):
+ input = kornia.utils.image.image_list_to_tensor([np.array(img) for img in input]) / 255.0
+ output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1)
+ else:
+ ndim = input.ndim
+ assert ndim in [3, 4, 5]
+
+ batch_size = 1 if ndim == 3 else input.size(0)
+
+ if ndim == 3:
+ input = input.unsqueeze(0) # [C, H, W] -> [1, C, H, W]
+ elif ndim == 5:
+ input = input.flatten(0, 1) # [B, F, C, H, W] -> [B*F, C, H, W]
+
+ output = kornia.filters.canny(input)[1].repeat(1, 3, 1, 1)
+ output = output[0] if ndim == 3 else output.unflatten(0, (batch_size, -1)) if ndim == 5 else output
+
+ # TODO(aryan): think about how one can pass parameters to the underlying function from
+ # a UI perspective. It's important to think about ProcessorMixin in terms of a Graph-based
+ # data processing pipeline.
+ return {self.output_names[0]: output}
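+
+
+# Editorial usage sketch (not part of the upstream finetrainers module): running the
+# processor on a random video tensor in the [0, 1] range. Requires kornia to be installed;
+# the shapes are arbitrary.
+if __name__ == "__main__":
+    processor = CannyProcessor(output_names=["control_output"])
+    video = torch.rand(1, 4, 3, 64, 64)  # [B, F, C, H, W]
+    edges = processor(input=video)["control_output"]
+    print(edges.shape)  # torch.Size([1, 4, 3, 64, 64])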
diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/clip.py b/docs/finetrainers-src-codebase/finetrainers/processors/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..e58282b69fb4845f079b103122a732acf7348d14
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/processors/clip.py
@@ -0,0 +1,63 @@
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast
+
+from .base import ProcessorMixin
+
+
+class CLIPPooledProcessor(ProcessorMixin):
+ r"""
+    Processor for the CLIP family of models. This processor is used to encode text inputs and return the pooled
+    embedding for the input text.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor should return. The only output is the pooled embedding of
+            the input text.
+ """
+
+ def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None) -> None:
+ super().__init__()
+
+ self.output_names = output_names
+ self.input_names = input_names
+
+ assert len(output_names) == 1
+
+ def forward(
+ self,
+ tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
+ text_encoder: CLIPTextModel,
+ caption: Union[str, List[str]],
+    ) -> Dict[str, torch.Tensor]:
+ r"""
+        Encode the input text and return the pooled text embedding.
+
+        Args:
+            tokenizer (`Union[CLIPTokenizer, CLIPTokenizerFast]`):
+                The tokenizer used to tokenize the input text.
+            text_encoder (`CLIPTextModel`):
+                The text encoder used to encode the input text.
+ caption (`Union[str, List[str]]`):
+ The input text to be encoded.
+ """
+ if isinstance(caption, str):
+ caption = [caption]
+
+ device = text_encoder.device
+ dtype = text_encoder.dtype
+
+ text_inputs = tokenizer(
+ caption,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids.to(device)
+
+ prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False).pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return {self.output_names[0]: prompt_embeds}
diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/glm.py b/docs/finetrainers-src-codebase/finetrainers/processors/glm.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf742130bb7da8808710ec562c85d9c64a535cb6
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/processors/glm.py
@@ -0,0 +1,74 @@
+from typing import Dict, List, Union
+
+import torch
+from transformers import AutoTokenizer, GlmModel
+
+from .base import ProcessorMixin
+
+
+class CogView4GLMProcessor(ProcessorMixin):
+ r"""
+    Processor for the GLM family of models. This processor is used to encode text inputs and return the text
+    embeddings for the input text.
+
+    This processor is written for CogView4 but can be reused with other models that rely on a GLM text encoder.
+
+ Args:
+ output_names (`List[str]`):
+            The names of the outputs that the processor should return. The only output is the sequence of text
+            embeddings for the input text.
+ """
+
+ def __init__(self, output_names: List[str]):
+ super().__init__()
+
+ self.output_names = output_names
+
+ assert len(self.output_names) == 1
+
+ def forward(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: GlmModel,
+ caption: Union[str, List[str]],
+ max_sequence_length: int,
+    ) -> Dict[str, torch.Tensor]:
+ r"""
+        Encode the input text and return the text embeddings for the input text.
+
+ Args:
+ tokenizer (`AutoTokenizer`):
+ The tokenizer used to tokenize the input text.
+ text_encoder (`GlmModel`):
+ The text encoder used to encode the input text.
+ caption (`Union[str, List[str]]`):
+ The input text to be encoded.
+ max_sequence_length (`int`):
+ The maximum sequence length of the input text.
+ """
+ if isinstance(caption, str):
+ caption = [caption]
+
+ device = text_encoder.device
+ dtype = text_encoder.dtype
+
+ text_inputs = tokenizer(
+ caption,
+ padding="longest",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids.to(device)
+
+ current_length = text_input_ids.size(1)
+        pad_length = (16 - current_length % 16) % 16  # pad to a multiple of 16; no padding when already aligned
+ if pad_length > 0:
+ pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=tokenizer.pad_token_id)
+ text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
+
+ prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return {self.output_names[0]: prompt_embeds}
diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/llama.py b/docs/finetrainers-src-codebase/finetrainers/processors/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..749e5f313541b92317279669faf915edeb9129c4
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/processors/llama.py
@@ -0,0 +1,118 @@
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast
+
+from .base import ProcessorMixin
+
+
+DEFAULT_PROMPT_TEMPLATE = {
+ "template": (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ ),
+ "crop_start": 95,
+}
+
+
+class LlamaProcessor(ProcessorMixin):
+ r"""
+ Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings
+ and attention masks for the input text.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor should return. The first output is the embeddings of the input
+ text and the second output is the attention mask for the input text.
+ """
+
+ def __init__(self, output_names: List[str] = None):
+ super().__init__()
+
+ self.output_names = output_names
+
+ assert len(output_names) == 2
+
+ def forward(
+ self,
+ tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast],
+ text_encoder: LlamaModel,
+ caption: Union[str, List[str]],
+ max_sequence_length: int,
+ prompt_template: Optional[Dict[str, Any]] = None,
+ num_layers_to_skip: int = 2,
+    ) -> Dict[str, torch.Tensor]:
+ r"""
+ Encode the input text and return the embeddings and attention mask for the input text.
+
+ Args:
+ tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`):
+ The tokenizer used to tokenize the input text.
+ text_encoder (`LlamaModel`):
+ The text encoder used to encode the input text.
+ caption (`Union[str, List[str]]`):
+ The input text to be encoded.
+ max_sequence_length (`int`):
+ The maximum sequence length of the input text.
+            prompt_template (`Optional[Dict[str, Any]]`):
+                The prompt template to be used to encode the input text.
+            num_layers_to_skip (`int`, defaults to `2`):
+                The number of final hidden layers to skip; the hidden state at index `-(num_layers_to_skip + 1)` is
+                used as the prompt embedding.
+ """
+ if prompt_template is None:
+ prompt_template = DEFAULT_PROMPT_TEMPLATE
+ if isinstance(caption, str):
+ caption = [caption]
+
+ device = text_encoder.device
+ dtype = text_encoder.dtype
+
+ batch_size = len(caption)
+ caption = [prompt_template["template"].format(c) for c in caption]
+
+ crop_start = prompt_template.get("crop_start", None)
+ if crop_start is None:
+ prompt_template_input = tokenizer(
+ prompt_template["template"],
+ padding="max_length",
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=False,
+ )
+ crop_start = prompt_template_input["input_ids"].shape[-1]
+ # Remove <|eot_id|> token and placeholder {}
+ crop_start -= 2
+
+ max_sequence_length += crop_start
+ text_inputs = tokenizer(
+ caption,
+ max_length=max_sequence_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ )
+ text_input_ids = text_inputs.input_ids.to(device)
+ prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
+
+ prompt_embeds = text_encoder(
+ text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+ ).hidden_states[-(num_layers_to_skip + 1)]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ if crop_start is not None and crop_start > 0:
+ prompt_embeds = prompt_embeds[:, crop_start:]
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+
+ return {
+ self.output_names[0]: prompt_embeds,
+ self.output_names[1]: prompt_attention_mask,
+ }
diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/t5.py b/docs/finetrainers-src-codebase/finetrainers/processors/t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..006aed2c18376c3ff1509bd6fadd57e48ea39350
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/processors/t5.py
@@ -0,0 +1,87 @@
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast
+
+from .base import ProcessorMixin
+
+
+class T5Processor(ProcessorMixin):
+ r"""
+ Processor for the T5 family of models. This processor is used to encode text inputs and return the embeddings
+ and attention masks for the input text.
+
+ Args:
+ output_names (`List[str]`):
+ The names of the outputs that the processor should return. The first output is the embeddings of the input
+ text and the second output is the attention mask for the input text.
+ """
+
+ def __init__(
+ self,
+ output_names: List[str],
+ input_names: Optional[Dict[str, Any]] = None,
+ *,
+ use_attention_mask: bool = False,
+ ):
+ super().__init__()
+
+ self.output_names = output_names
+ self.input_names = input_names
+ self.use_attention_mask = use_attention_mask
+
+ if input_names is not None:
+ assert len(input_names) <= 4
+ assert len(self.output_names) == 2
+
+ def forward(
+ self,
+ tokenizer: Union[T5Tokenizer, T5TokenizerFast],
+ text_encoder: T5EncoderModel,
+ caption: Union[str, List[str]],
+ max_sequence_length: int,
+    ) -> Dict[str, torch.Tensor]:
+ r"""
+ Encode the input text and return the embeddings and attention mask for the input text.
+
+ Args:
+ tokenizer (`Union[T5Tokenizer, T5TokenizerFast]`):
+ The tokenizer used to tokenize the input text.
+ text_encoder (`T5EncoderModel`):
+ The text encoder used to encode the input text.
+ caption (`Union[str, List[str]]`):
+ The input text to be encoded.
+ max_sequence_length (`int`):
+ The maximum sequence length of the input text.
+ """
+ if isinstance(caption, str):
+ caption = [caption]
+
+ device = text_encoder.device
+ dtype = text_encoder.dtype
+
+ batch_size = len(caption)
+ text_inputs = tokenizer(
+ caption,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+ te_mask = None
+ if self.use_attention_mask:
+ te_mask = prompt_attention_mask
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), te_mask)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+
+ return {
+ self.output_names[0]: prompt_embeds,
+ self.output_names[1]: prompt_attention_mask,
+ }
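+
+
+# Editorial usage sketch (not part of the upstream finetrainers module): encoding a caption
+# with a T5-style text encoder. The checkpoint id is a placeholder; any T5 tokenizer/encoder
+# pair should work (a smaller checkpoint is fine for a quick test).
+if __name__ == "__main__":
+    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
+    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
+
+    processor = T5Processor(output_names=["encoder_hidden_states", "encoder_attention_mask"])
+    outputs = processor(
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        caption="a cat playing the piano",
+        max_sequence_length=128,
+    )
+    print(outputs["encoder_hidden_states"].shape, outputs["encoder_attention_mask"].shape)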
diff --git a/docs/finetrainers-src-codebase/finetrainers/processors/text.py b/docs/finetrainers-src-codebase/finetrainers/processors/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..884284725004e285a817424ef4561b3aefeb466a
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/processors/text.py
@@ -0,0 +1,23 @@
+from typing import List, Union
+
+import torch
+
+import finetrainers.functional as FF
+
+from .base import ProcessorMixin
+
+
+class CaptionTextDropoutProcessor(ProcessorMixin):
+ def __init__(self, dropout_p: float = 0.0) -> None:
+ self.dropout_p = dropout_p
+
+ def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]:
+ return FF.dropout_caption(caption, self.dropout_p)
+
+
+class CaptionEmbeddingDropoutProcessor(ProcessorMixin):
+ def __init__(self, dropout_p: float = 0.0) -> None:
+ self.dropout_p = dropout_p
+
+ def forward(self, embedding: torch.Tensor) -> torch.Tensor:
+ return FF.dropout_embeddings_to_zero(embedding, self.dropout_p)
diff --git a/docs/finetrainers-src-codebase/finetrainers/state.py b/docs/finetrainers-src-codebase/finetrainers/state.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a44b6d6df74139b5ee405cc90288ec58abda3bd
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/state.py
@@ -0,0 +1,66 @@
+import io
+from dataclasses import dataclass, field
+from typing import Any, Dict, List
+
+import torch
+import torch.distributed.checkpoint.stateful
+
+from .parallel import ParallelBackendType
+from .utils import get_device_info
+
+
+_device_type, _ = get_device_info()
+
+
+@dataclass
+class TrainState(torch.distributed.checkpoint.stateful.Stateful):
+ step: int = 0
+ observed_data_samples: int = 0
+ global_avg_losses: List[float] = field(default_factory=list)
+ global_max_losses: List[float] = field(default_factory=list)
+ log_steps: List[int] = field(default_factory=list)
+
+ def state_dict(self) -> Dict[str, Any]:
+ # Only checkpoint global_avg_losses and global_max_losses per log frequency
+ # to avoid sync overhead in every iteration.
+ global_avg_losses_bytes = io.BytesIO()
+ torch.save(self.global_avg_losses, global_avg_losses_bytes)
+ global_max_losses_bytes = io.BytesIO()
+ torch.save(self.global_max_losses, global_max_losses_bytes)
+ log_steps_bytes = io.BytesIO()
+ torch.save(self.log_steps, log_steps_bytes)
+ return {
+ "step": torch.tensor(self.step, dtype=torch.int32),
+ "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32),
+ "global_avg_losses": global_avg_losses_bytes,
+ "global_max_losses": global_max_losses_bytes,
+ "log_steps": log_steps_bytes,
+ }
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ state_dict["global_avg_losses"].seek(0)
+ state_dict["global_max_losses"].seek(0)
+ state_dict["log_steps"].seek(0)
+
+ self.step = state_dict["step"].item()
+ self.observed_data_samples = state_dict["observed_data_samples"].item()
+ self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False)
+ self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False)
+ self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
+
+
+@dataclass
+class State:
+ # Parallel state
+ parallel_backend: ParallelBackendType = None
+
+ # Training state
+ train_state: TrainState = None
+ num_trainable_parameters: int = 0
+ generator: torch.Generator = None
+
+ # Hub state
+ repo_id: str = None
+
+ # Artifacts state
+ output_dir: str = None
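+
+
+# Editorial usage sketch (not part of the upstream finetrainers module): `TrainState`
+# serializes its Python lists through in-memory byte buffers so that the `Stateful`
+# protocol only has to deal with tensors and streams. A simple save/restore round trip:
+if __name__ == "__main__":
+    train_state = TrainState(step=10, observed_data_samples=640)
+    train_state.global_avg_losses = [0.5, 0.4]
+    train_state.log_steps = [5, 10]
+
+    payload = train_state.state_dict()
+
+    restored = TrainState()
+    restored.load_state_dict(payload)
+    print(restored.step, restored.global_avg_losses, restored.log_steps)
+    # 10 [0.5, 0.4] [5, 10]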
diff --git a/docs/finetrainers-src-codebase/finetrainers/trackers.py b/docs/finetrainers-src-codebase/finetrainers/trackers.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a53c5adc5934b8a1a802a1e48d2e5c5323b240
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trackers.py
@@ -0,0 +1,145 @@
+import contextlib
+import copy
+import pathlib
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from .logging import get_logger
+from .utils import Timer, TimerDevice
+
+
+logger = get_logger()
+
+
+class BaseTracker:
+ r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging."""
+
+ def __init__(self):
+ self._timed_metrics = {}
+
+ @contextlib.contextmanager
+ def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
+ r"""Context manager to track time for a specific operation."""
+ timer = Timer(name, device, device_sync)
+ timer.start()
+ yield timer
+ timer.end()
+ elapsed_time = timer.elapsed_time
+ if name in self._timed_metrics:
+ # If the timer name already exists, add the elapsed time to the existing value since a log has not been invoked yet
+ self._timed_metrics[name] += elapsed_time
+ else:
+ self._timed_metrics[name] = elapsed_time
+
+ def log(self, metrics: Dict[str, Any], step: int) -> None:
+ pass
+
+ def finish(self) -> None:
+ pass
+
+
+class DummyTracker(BaseTracker):
+ def __init__(self):
+ super().__init__()
+
+ def log(self, *args, **kwargs):
+ pass
+
+ def finish(self) -> None:
+ pass
+
+
+class WandbTracker(BaseTracker):
+ r"""Logger implementation for Weights & Biases."""
+
+ def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None:
+ super().__init__()
+
+ import wandb
+
+ self.wandb = wandb
+
+ # WandB does not create a directory if it does not exist and instead starts using the system temp directory.
+ pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
+
+ self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
+ logger.info("WandB logging enabled")
+
+ def log(self, metrics: Dict[str, Any], step: int) -> None:
+ metrics = {**self._timed_metrics, **metrics}
+ self.run.log(metrics, step=step)
+ self._timed_metrics = {}
+
+ def finish(self) -> None:
+ self.run.finish()
+
+
+class SequentialTracker(BaseTracker):
+ r"""Sequential tracker that logs to multiple trackers in sequence."""
+
+ def __init__(self, trackers: List[BaseTracker]) -> None:
+ super().__init__()
+ self.trackers = trackers
+
+ @contextlib.contextmanager
+ def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
+ r"""Context manager to track time for a specific operation."""
+ timer = Timer(name, device, device_sync)
+ timer.start()
+ yield timer
+ timer.end()
+ elapsed_time = timer.elapsed_time
+ if name in self._timed_metrics:
+ # If the timer name already exists, add the elapsed time to the existing value since a log has not been invoked yet
+ self._timed_metrics[name] += elapsed_time
+ else:
+ self._timed_metrics[name] = elapsed_time
+ for tracker in self.trackers:
+ tracker._timed_metrics = copy.deepcopy(self._timed_metrics)
+
+ def log(self, metrics: Dict[str, Any], step: int) -> None:
+ for tracker in self.trackers:
+ tracker.log(metrics, step)
+ self._timed_metrics = {}
+
+ def finish(self) -> None:
+ for tracker in self.trackers:
+ tracker.finish()
+
+
+class Trackers(str, Enum):
+ r"""Enum for supported trackers."""
+
+ NONE = "none"
+ WANDB = "wandb"
+
+
+_SUPPORTED_TRACKERS = [tracker.value for tracker in Trackers.__members__.values()]
+
+
+def initialize_trackers(
+ trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
+) -> Union[BaseTracker, SequentialTracker]:
+ r"""Initialize loggers based on the provided configuration."""
+
+ logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}")
+
+ if len(trackers) == 0:
+ return BaseTracker()
+
+ if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in set(trackers)):
+ raise ValueError(f"Unsupported tracker(s) provided. Supported trackers: {_SUPPORTED_TRACKERS}")
+
+ tracker_instances = []
+ for tracker_name in set(trackers):
+ if tracker_name == Trackers.NONE:
+ tracker = BaseTracker()
+ elif tracker_name == Trackers.WANDB:
+ tracker = WandbTracker(experiment_name, log_dir, config)
+ tracker_instances.append(tracker)
+
+ tracker = SequentialTracker(tracker_instances)
+ return tracker
+
+
+TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker]
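+
+
+# Editorial usage sketch (not part of the upstream finetrainers module): building a tracker
+# and timing a block of work. The "wandb" entry assumes the `wandb` package is installed and
+# configured; passing an empty list returns a no-op `BaseTracker`.
+if __name__ == "__main__":
+    tracker = initialize_trackers(
+        trackers=["wandb"],
+        experiment_name="finetrainers-demo",
+        config={"learning_rate": 1e-4},
+        log_dir="./logs",
+    )
+    with tracker.timed("data_loading"):
+        pass  # ... fetch a batch here ...
+    tracker.log({"loss": 0.123}, step=1)
+    tracker.finish()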
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/__init__.py b/docs/finetrainers-src-codebase/finetrainers/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..30a243509e26f71cc7deab8cd3aa03f6fa779e98
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/__init__.py
@@ -0,0 +1,2 @@
+from .control_trainer import ControlTrainer
+from .sft_trainer import SFTTrainer
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/base.py b/docs/finetrainers-src-codebase/finetrainers/trainer/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..445fc89ee43283992b2d2f4263fd12ef9c0d2d46
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/base.py
@@ -0,0 +1,188 @@
+import contextlib
+import functools
+import os
+from typing import Callable, List, Tuple
+
+import torch
+import torch.backends
+from diffusers.hooks import HookRegistry, ModelHook
+
+from finetrainers import logging, parallel, patches
+from finetrainers.args import BaseArgsType
+from finetrainers.logging import get_logger
+from finetrainers.models.attention_dispatch import AttentionProvider, _AttentionProviderRegistry
+from finetrainers.state import State
+
+
+logger = get_logger()
+
+_LATEST_ACTIVE_MODULE_HOOK = "latest_active_module_hook"
+
+
+class Trainer:
+ def __init__(self, args: BaseArgsType):
+ self.args = args
+
+ self.state = State()
+
+ self._module_name_providers_training = _parse_attention_providers(args.attn_provider_training)
+ self._module_name_providers_inference = _parse_attention_providers(args.attn_provider_inference)
+
+ self._init_distributed()
+ self._init_config_options()
+
+ # Perform any patches that might be necessary for training to work as expected
+ patches.perform_patches_for_training(self.args, self.state.parallel_backend)
+
+ @contextlib.contextmanager
+ def attention_provider_ctx(self, training: bool = True):
+ name_providers_active = (
+ self._module_name_providers_training if training else self._module_name_providers_inference
+ )
+ name_providers_dict = dict(name_providers_active)
+ default_provider = _AttentionProviderRegistry._active_provider
+
+ all_registered_module_names = [
+ attr for attr in dir(self) if isinstance(getattr(self, attr, None), torch.nn.Module)
+ ]
+ for module_name in all_registered_module_names:
+ if module_name in name_providers_dict:
+ continue
+ name_providers_dict[module_name] = default_provider
+
+ module_providers_dict = {}
+ for module_name, provider in name_providers_dict.items():
+ module = getattr(self, module_name, None)
+ if module is not None:
+ module_providers_dict[module] = (module_name, provider)
+
+ # We don't want to immediately unset the attention provider to default after forward because if the
+ # model is being trained, the backward pass must be invoked with the same attention provider
+ # So, we lazily switch attention providers only when the forward pass of a new module is called
+ def callback(m: torch.nn.Module):
+ module_name, provider = module_providers_dict[m]
+ # HACK: for CP on transformer. Need to support other modules too and improve overall experience for external usage
+ if module_name in ["transformer"] and self.state.parallel_backend.context_parallel_enabled:
+ if not _AttentionProviderRegistry.supports_context_parallel(provider):
+ raise ValueError(
+ f"Attention provider {provider} does not support context parallel. Please use a different provider."
+ )
+ _AttentionProviderRegistry._set_context_parallel(
+ mesh=self.state.parallel_backend.get_mesh()["cp"], convert_to_fp32=True, rotate_method="allgather"
+ )
+ _AttentionProviderRegistry._active_provider = provider
+
+ # HACK: for VAE
+ if "vae" in name_providers_dict:
+ _apply_forward_hooks_hack(self.vae, name_providers_dict["vae"])
+
+ for module in module_providers_dict.keys():
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ hook = LatestActiveModuleHook(callback)
+ registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK)
+
+ yield
+
+ _AttentionProviderRegistry._active_provider = default_provider
+ _AttentionProviderRegistry._set_context_parallel(reset=True)
+ for module in module_providers_dict.keys():
+ registry: HookRegistry = module._diffusers_hook
+ registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK)
+
+ def _init_distributed(self) -> None:
+ world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
+
+ # TODO(aryan): handle other backends
+ backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
+ self.state.parallel_backend = backend_cls(
+ world_size=world_size,
+ pp_degree=self.args.pp_degree,
+ dp_degree=self.args.dp_degree,
+ dp_shards=self.args.dp_shards,
+ cp_degree=self.args.cp_degree,
+ tp_degree=self.args.tp_degree,
+ backend="nccl",
+ timeout=self.args.init_timeout,
+ logging_dir=self.args.logging_dir,
+ output_dir=self.args.output_dir,
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ )
+
+ if self.args.seed is not None:
+ self.state.parallel_backend.enable_determinism(self.args.seed)
+
+ def _init_logging(self) -> None:
+ logging._set_parallel_backend(self.state.parallel_backend)
+ logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
+ logger.info("Initialized FineTrainers")
+
+ def _init_trackers(self) -> None:
+ # TODO(aryan): handle multiple trackers
+ trackers = [self.args.report_to]
+ experiment_name = self.args.tracker_name or "finetrainers-experiment"
+ self.state.parallel_backend.initialize_trackers(
+ trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
+ )
+
+ def _init_config_options(self) -> None:
+ # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if self.args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.set_float32_matmul_precision(self.args.float32_matmul_precision)
+
+ @property
+ def tracker(self):
+ return self.state.parallel_backend.tracker
+
+
+class LatestActiveModuleHook(ModelHook):
+ def __init__(self, callback: Callable[[torch.nn.Module], None] = None):
+ super().__init__()
+ self.callback = callback
+
+ def pre_forward(self, module, *args, **kwargs):
+ self.callback(module)
+ return args, kwargs
+
+
+def _parse_attention_providers(attn_providers: List[str] = None) -> List[Tuple[str, AttentionProvider]]:
+ parsed_providers = []
+ if attn_providers:
+ for provider_str in attn_providers:
+ parts = provider_str.split(":")
+ if len(parts) != 2:
+ raise ValueError(
+ f"Invalid attention provider format: '{provider_str}'. Expected 'module_name:provider_name'."
+ )
+ parts[1] = AttentionProvider(parts[1])
+ parsed_providers.append(tuple(parts))
+ return parsed_providers
+
+
+# TODO(aryan): instead of this, we could probably just apply the hook to vae.children() as we know their forward methods will be invoked
+def _apply_forward_hooks_hack(module: torch.nn.Module, provider: AttentionProvider):
+ if hasattr(module, "_finetrainers_wrapped_methods"):
+ return
+
+ def create_wrapper(old_method):
+ @functools.wraps(old_method)
+ def wrapper(*args, **kwargs):
+ _AttentionProviderRegistry._set_context_parallel(reset=True) # HACK: needs improvement
+ old_provider = _AttentionProviderRegistry._active_provider
+ _AttentionProviderRegistry._active_provider = provider
+ output = old_method(*args, **kwargs)
+ _AttentionProviderRegistry._active_provider = old_provider
+ return output
+
+ return wrapper
+
+ methods = ["encode", "decode", "_encode", "_decode", "tiled_encode", "tiled_decode"]
+ finetrainers_wrapped_methods = []
+ for method_name in methods:
+ if not hasattr(module, method_name):
+ continue
+ method = getattr(module, method_name)
+ wrapper = create_wrapper(method)
+ setattr(module, method_name, wrapper)
+ finetrainers_wrapped_methods.append(method_name)
+ module._finetrainers_wrapped_methods = finetrainers_wrapped_methods
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/__init__.py b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b72fc82a2c73cfbbd8e95aaca9a1f127d15774
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/__init__.py
@@ -0,0 +1,2 @@
+from .config import ControlFullRankConfig, ControlLowRankConfig
+from .trainer import ControlTrainer
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/config.py b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..14cfe715749fd74a622abd5afe5131efc0db130f
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/config.py
@@ -0,0 +1,185 @@
+import argparse
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from finetrainers.utils import ArgsConfigMixin
+
+
+if TYPE_CHECKING:
+ from finetrainers.args import BaseArgs
+
+
+class ControlType(str, Enum):
+ r"""
+ Enum class for the control types.
+ """
+
+ CANNY = "canny"
+ CUSTOM = "custom"
+ NONE = "none"
+
+
+class FrameConditioningType(str, Enum):
+ r"""
+ Enum class for the frame conditioning types.
+ """
+
+ INDEX = "index"
+ PREFIX = "prefix"
+ RANDOM = "random"
+ FIRST_AND_LAST = "first_and_last"
+ FULL = "full"
+
+
+class ControlLowRankConfig(ArgsConfigMixin):
+ r"""
+ Configuration class for SFT channel-concatenated Control low rank training.
+
+ Args:
+ control_type (`str`, defaults to `"canny"`):
+            The type of control conditioning to use. Can be "canny" or "custom".
+ rank (int, defaults to `64`):
+ Rank of the low rank approximation matrix.
+ lora_alpha (int, defaults to `64`):
+ The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices.
+ target_modules (`str` or `List[str]`, defaults to `"(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"`):
+ Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings.
+ train_qk_norm (`bool`, defaults to `False`):
+ Whether to train the QK normalization layers.
+ frame_conditioning_type (`str`, defaults to `"full"`):
+ Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full".
+ frame_conditioning_index (int, defaults to `0`):
+ Index of the frame conditioning. Only used if `frame_conditioning_type` is "index".
+ frame_conditioning_concatenate_mask (`bool`, defaults to `False`):
+ Whether to concatenate the frame mask with the latents across channel dim.
+ """
+
+ control_type: str = ControlType.CANNY
+ rank: int = 64
+ lora_alpha: int = 64
+ target_modules: Union[str, List[str]] = (
+ "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"
+ )
+ train_qk_norm: bool = False
+
+ # Specific to video models
+ frame_conditioning_type: str = FrameConditioningType.FULL
+ frame_conditioning_index: int = 0
+ frame_conditioning_concatenate_mask: bool = False
+
+ def add_args(self, parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--control_type",
+ type=str,
+ default=ControlType.CANNY.value,
+ choices=[x.value for x in ControlType.__members__.values()],
+ )
+ parser.add_argument("--rank", type=int, default=64)
+ parser.add_argument("--lora_alpha", type=int, default=64)
+ parser.add_argument(
+ "--target_modules",
+ type=str,
+ nargs="+",
+ default=[
+ "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"
+ ],
+ )
+ parser.add_argument("--train_qk_norm", action="store_true")
+ parser.add_argument(
+ "--frame_conditioning_type",
+ type=str,
+ default=FrameConditioningType.INDEX.value,
+ choices=[x.value for x in FrameConditioningType.__members__.values()],
+ )
+ parser.add_argument("--frame_conditioning_index", type=int, default=0)
+ parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true")
+
+ def validate_args(self, args: "BaseArgs"):
+ assert self.rank > 0, "Rank must be a positive integer."
+ assert self.lora_alpha > 0, "lora_alpha must be a positive integer."
+
+ def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+ mapped_args.control_type = argparse_args.control_type
+ mapped_args.rank = argparse_args.rank
+ mapped_args.lora_alpha = argparse_args.lora_alpha
+ mapped_args.target_modules = (
+ argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules
+ )
+ mapped_args.train_qk_norm = argparse_args.train_qk_norm
+ mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type
+ mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index
+ mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "control_type": self.control_type,
+ "rank": self.rank,
+ "lora_alpha": self.lora_alpha,
+ "target_modules": self.target_modules,
+ "train_qk_norm": self.train_qk_norm,
+ "frame_conditioning_type": self.frame_conditioning_type,
+ "frame_conditioning_index": self.frame_conditioning_index,
+ "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask,
+ }
+
+
+class ControlFullRankConfig(ArgsConfigMixin):
+ r"""
+ Configuration class for SFT channel-concatenated Control full rank training.
+
+ Args:
+ control_type (`str`, defaults to `"canny"`):
+            The type of control conditioning to use. Can be "canny" or "custom".
+ train_qk_norm (`bool`, defaults to `False`):
+ Whether to train the QK normalization layers.
+ frame_conditioning_type (`str`, defaults to `"index"`):
+ Type of frame conditioning. Can be "index", "prefix", "random", "first_and_last", or "full".
+ frame_conditioning_index (int, defaults to `0`):
+ Index of the frame conditioning. Only used if `frame_conditioning_type` is "index".
+ frame_conditioning_concatenate_mask (`bool`, defaults to `False`):
+ Whether to concatenate the frame mask with the latents across channel dim.
+ """
+
+ control_type: str = ControlType.CANNY
+ train_qk_norm: bool = False
+
+ # Specific to video models
+ frame_conditioning_type: str = FrameConditioningType.INDEX
+ frame_conditioning_index: int = 0
+ frame_conditioning_concatenate_mask: bool = False
+
+ def add_args(self, parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--control_type",
+ type=str,
+ default=ControlType.CANNY.value,
+ choices=[x.value for x in ControlType.__members__.values()],
+ )
+ parser.add_argument("--train_qk_norm", action="store_true")
+ parser.add_argument(
+ "--frame_conditioning_type",
+ type=str,
+ default=FrameConditioningType.INDEX.value,
+ choices=[x.value for x in FrameConditioningType.__members__.values()],
+ )
+ parser.add_argument("--frame_conditioning_index", type=int, default=0)
+ parser.add_argument("--frame_conditioning_concatenate_mask", action="store_true")
+
+ def validate_args(self, args: "BaseArgs"):
+ pass
+
+ def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+ mapped_args.control_type = argparse_args.control_type
+ mapped_args.train_qk_norm = argparse_args.train_qk_norm
+ mapped_args.frame_conditioning_type = argparse_args.frame_conditioning_type
+ mapped_args.frame_conditioning_index = argparse_args.frame_conditioning_index
+ mapped_args.frame_conditioning_concatenate_mask = argparse_args.frame_conditioning_concatenate_mask
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "control_type": self.control_type,
+ "train_qk_norm": self.train_qk_norm,
+ "frame_conditioning_type": self.frame_conditioning_type,
+ "frame_conditioning_index": self.frame_conditioning_index,
+ "frame_conditioning_concatenate_mask": self.frame_conditioning_concatenate_mask,
+ }
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/data.py b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c91fec06a62e22148415860c48c20e6ae2605d8
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/data.py
@@ -0,0 +1,268 @@
+import random
+from typing import Any, Dict, Optional
+
+import torch
+import torch.distributed.checkpoint.stateful
+from diffusers.video_processor import VideoProcessor
+
+import finetrainers.functional as FF
+from finetrainers.logging import get_logger
+from finetrainers.processors import CannyProcessor, CopyProcessor
+
+from .config import ControlType, FrameConditioningType
+
+
+logger = get_logger()
+
+
+class IterableControlDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+ def __init__(
+ self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
+ ):
+ super().__init__()
+
+ self.dataset = dataset
+ self.control_type = control_type
+
+ self.control_processors = []
+ if control_type == ControlType.CANNY:
+ self.control_processors.append(
+ CannyProcessor(
+ output_names=["control_output"], input_names={"image": "input", "video": "input"}, device=device
+ )
+ )
+ elif control_type == ControlType.NONE:
+ self.control_processors.append(
+ CopyProcessor(output_names=["control_output"], input_names={"image": "input", "video": "input"})
+ )
+
+ logger.info("Initialized IterableControlDataset")
+
+ def __iter__(self):
+ logger.info("Starting IterableControlDataset")
+ for data in iter(self.dataset):
+ control_augmented_data = self._run_control_processors(data)
+ yield control_augmented_data
+
+ def load_state_dict(self, state_dict):
+ self.dataset.load_state_dict(state_dict)
+
+ def state_dict(self):
+ return self.dataset.state_dict()
+
+ def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ if "control_image" in data:
+ if "image" in data:
+ data["control_image"] = FF.resize_to_nearest_bucket_image(
+ data["control_image"], [data["image"].shape[-2:]], resize_mode="bicubic"
+ )
+ if "video" in data:
+ batch_size, num_frames, num_channels, height, width = data["video"].shape
+ data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
+ data["control_video"], [[num_frames, height, width]], resize_mode="bicubic"
+ )
+ if _first_frame_only:
+ msg = (
+ "The number of frames in the control video is less than the minimum bucket size "
+ "specified. The first frame is being used as a single frame video. This "
+ "message is logged at the first occurence and for every 128th occurence "
+ "after that."
+ )
+ logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128)
+ data["control_video"] = data["control_video"][0]
+ return data
+
+ if "control_video" in data:
+ if "image" in data:
+ data["control_image"] = FF.resize_to_nearest_bucket_image(
+ data["control_video"][0], [data["image"].shape[-2:]], resize_mode="bicubic"
+ )
+ if "video" in data:
+ batch_size, num_frames, num_channels, height, width = data["video"].shape
+ data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
+ data["control_video"], [[num_frames, height, width]], resize_mode="bicubic"
+ )
+ if _first_frame_only:
+ msg = (
+ "The number of frames in the control video is less than the minimum bucket size "
+ "specified. The first frame is being used as a single frame video. This "
+ "message is logged at the first occurence and for every 128th occurence "
+ "after that."
+ )
+ logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128)
+ data["control_video"] = data["control_video"][0]
+ return data
+
+ if self.control_type == ControlType.CUSTOM:
+ return data
+
+ shallow_copy_data = dict(data.items())
+ is_image_control = "image" in shallow_copy_data
+ is_video_control = "video" in shallow_copy_data
+ if (is_image_control + is_video_control) != 1:
+ raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")
+ for processor in self.control_processors:
+ result = processor(**shallow_copy_data)
+ result_keys = set(result.keys())
+ repeat_keys = result_keys.intersection(shallow_copy_data.keys())
+ if repeat_keys:
+ logger.warning(
+ f"Processor {processor.__class__.__name__} returned keys that already exist in "
+ f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
+ f"be intended. Please rename the keys in the processor to avoid conflicts."
+ )
+ shallow_copy_data.update(result)
+ if "control_output" in shallow_copy_data:
+ # Normalize to [-1, 1] range
+ control_output = shallow_copy_data.pop("control_output")
+ # TODO(aryan): need to specify a dim for normalize here across channels
+ control_output = FF.normalize(control_output, min=-1.0, max=1.0)
+ key = "control_image" if is_image_control else "control_video"
+ shallow_copy_data[key] = control_output
+ return shallow_copy_data
+
+
+class ValidationControlDataset(torch.utils.data.IterableDataset):
+ def __init__(
+ self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
+ ):
+ super().__init__()
+
+ self.dataset = dataset
+ self.control_type = control_type
+ self.device = device
+ self._video_processor = VideoProcessor()
+
+ self.control_processors = []
+ if control_type == ControlType.CANNY:
+ self.control_processors.append(
+ CannyProcessor(["control_output"], input_names={"image": "input", "video": "input"}, device=device)
+ )
+ elif control_type == ControlType.NONE:
+ self.control_processors.append(
+ CopyProcessor(["control_output"], input_names={"image": "input", "video": "input"})
+ )
+
+ logger.info("Initialized ValidationControlDataset")
+
+ def __iter__(self):
+ logger.info("Starting ValidationControlDataset")
+ for data in iter(self.dataset):
+ control_augmented_data = self._run_control_processors(data)
+ yield control_augmented_data
+
+ def load_state_dict(self, state_dict):
+ self.dataset.load_state_dict(state_dict)
+
+ def state_dict(self):
+ return self.dataset.state_dict()
+
+ def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
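+ # Mirrors IterableControlDataset._run_control_processors, but validation data may not be a
+ # tensor, and tensor control outputs are converted back to PIL (via VideoProcessor) so they
+ # can be passed directly to the diffusion pipeline.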
+ if self.control_type == ControlType.CUSTOM:
+ return data
+ # These are already expected to be tensors
+ if "control_image" in data or "control_video" in data:
+ return data
+ shallow_copy_data = dict(data.items())
+ is_image_control = "image" in shallow_copy_data
+ is_video_control = "video" in shallow_copy_data
+ if (is_image_control + is_video_control) != 1:
+ raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")
+ for processor in self.control_processors:
+ result = processor(**shallow_copy_data)
+ result_keys = set(result.keys())
+ repeat_keys = result_keys.intersection(shallow_copy_data.keys())
+ if repeat_keys:
+ logger.warning(
+ f"Processor {processor.__class__.__name__} returned keys that already exist in "
+ f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
+ f"be intended. Please rename the keys in the processor to avoid conflicts."
+ )
+ shallow_copy_data.update(result)
+ if "control_output" in shallow_copy_data:
+ # Normalize to [-1, 1] range
+ control_output = shallow_copy_data.pop("control_output")
+ if torch.is_tensor(control_output):
+ # TODO(aryan): need to specify a dim for normalize here across channels
+ control_output = FF.normalize(control_output, min=-1.0, max=1.0)
+ ndim = control_output.ndim
+ assert 3 <= ndim <= 5, "Control output should be at least ndim=3 and less than or equal to ndim=5"
+ if ndim == 5:
+ control_output = self._video_processor.postprocess_video(control_output, output_type="pil")
+ else:
+ if ndim == 3:
+ control_output = control_output.unsqueeze(0)
+ control_output = self._video_processor.postprocess(control_output, output_type="pil")[0]
+ key = "control_image" if is_image_control else "control_video"
+ shallow_copy_data[key] = control_output
+ return shallow_copy_data
+
+
+# TODO(aryan): write a test for this function
+def apply_frame_conditioning_on_latents(
+ latents: torch.Tensor,
+ expected_num_frames: int,
+ channel_dim: int,
+ frame_dim: int,
+ frame_conditioning_type: FrameConditioningType,
+ frame_conditioning_index: Optional[int] = None,
+ concatenate_mask: bool = False,
+) -> torch.Tensor:
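+ # Builds a binary mask over the frame dimension and zeroes out the non-conditioning frames:
+ #   INDEX          -> keep a single frame at `frame_conditioning_index`
+ #   PREFIX         -> keep a random-length prefix of frames
+ #   RANDOM         -> keep a random subset of frames
+ #   FIRST_AND_LAST -> keep only the first and last frames
+ #   FULL           -> keep all frames (no masking applied)
+ # The result is then truncated or zero-padded along `frame_dim` to `expected_num_frames`, and a
+ # single mask channel is optionally concatenated along `channel_dim`.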
+ num_frames = latents.size(frame_dim)
+ mask = torch.zeros_like(latents)
+
+ if frame_conditioning_type == FrameConditioningType.INDEX:
+ frame_index = min(frame_conditioning_index, num_frames - 1)
+ indexing = [slice(None)] * latents.ndim
+ indexing[frame_dim] = frame_index
+ mask[tuple(indexing)] = 1
+ latents = latents * mask
+
+ elif frame_conditioning_type == FrameConditioningType.PREFIX:
+ frame_index = random.randint(1, num_frames)
+ indexing = [slice(None)] * latents.ndim
+ indexing[frame_dim] = slice(0, frame_index) # Keep frames 0 to frame_index-1
+ mask[tuple(indexing)] = 1
+ latents = latents * mask
+
+ elif frame_conditioning_type == FrameConditioningType.RANDOM:
+ # Keep between 1 and num_frames randomly chosen frames
+ num_frames_to_keep = random.randint(1, num_frames)
+ frame_indices = random.sample(range(num_frames), num_frames_to_keep)
+ indexing = [slice(None)] * latents.ndim
+ indexing[frame_dim] = frame_indices
+ mask[tuple(indexing)] = 1
+ latents = latents * mask
+
+ elif frame_conditioning_type == FrameConditioningType.FIRST_AND_LAST:
+ indexing = [slice(None)] * latents.ndim
+ indexing[frame_dim] = 0
+ mask[tuple(indexing)] = 1
+ indexing[frame_dim] = num_frames - 1
+ mask[tuple(indexing)] = 1
+ latents = latents * mask
+
+ elif frame_conditioning_type == FrameConditioningType.FULL:
+ indexing = [slice(None)] * latents.ndim
+ indexing[frame_dim] = slice(0, num_frames)
+ mask[tuple(indexing)] = 1
+
+ if latents.size(frame_dim) >= expected_num_frames:
+ slicing = [slice(None)] * latents.ndim
+ slicing[frame_dim] = slice(expected_num_frames)
+ latents = latents[tuple(slicing)]
+ mask = mask[tuple(slicing)]
+ else:
+ pad_size = expected_num_frames - num_frames
+ pad_shape = list(latents.shape)
+ pad_shape[frame_dim] = pad_size
+ padding = latents.new_zeros(pad_shape)
+ latents = torch.cat([latents, padding], dim=frame_dim)
+ mask = torch.cat([mask, padding], dim=frame_dim)
+
+ if concatenate_mask:
+ # Append a single mask channel so the model can distinguish conditioning frames from zeroed-out frames.
+ slicing = [slice(None)] * latents.ndim
+ slicing[channel_dim] = slice(0, 1)
+ latents = torch.cat([latents, mask[tuple(slicing)]], dim=channel_dim)
+
+ return latents
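+# Usage sketch (hypothetical latent layout [B, C, F, H, W] with channel_dim=1, frame_dim=2):
+#   conditioned = apply_frame_conditioning_on_latents(
+#       latents,
+#       expected_num_frames=latents.size(2),
+#       channel_dim=1,
+#       frame_dim=2,
+#       frame_conditioning_type=FrameConditioningType.INDEX,
+#       frame_conditioning_index=0,
+#   )
+# would zero every frame except the first and return a tensor with the same shape as `latents`
+# (plus one extra channel if `concatenate_mask=True`).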
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/trainer.py b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..576e17a0c0298f48b3a413ebc144586e4ce9e590
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/control_trainer/trainer.py
@@ -0,0 +1,1021 @@
+import functools
+import json
+import os
+import re
+import time
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+import datasets.distributed
+import safetensors.torch
+import torch
+import wandb
+from diffusers import DiffusionPipeline
+from diffusers.hooks import apply_layerwise_casting
+from diffusers.training_utils import cast_training_params
+from diffusers.utils import export_to_video
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, get_peft_model_state_dict
+from tqdm import tqdm
+
+from finetrainers import data, logging, models, optimizer, parallel, utils
+from finetrainers.args import BaseArgsType
+from finetrainers.config import TrainingType
+from finetrainers.patches import load_lora_weights
+from finetrainers.state import TrainState
+
+from ..base import Trainer
+from .config import ControlFullRankConfig, ControlLowRankConfig
+from .data import IterableControlDataset, ValidationControlDataset
+
+
+ArgsType = Union[BaseArgsType, ControlFullRankConfig, ControlLowRankConfig]
+
+logger = logging.get_logger()
+
+
+class ControlTrainer(Trainer):
+ def __init__(self, args: ArgsType, model_specification: models.ControlModelSpecification) -> None:
+ super().__init__(args)
+
+ # Tokenizers
+ self.tokenizer = None
+ self.tokenizer_2 = None
+ self.tokenizer_3 = None
+
+ # Text encoders
+ self.text_encoder = None
+ self.text_encoder_2 = None
+ self.text_encoder_3 = None
+
+ # Denoisers
+ self.transformer = None
+ self.unet = None
+
+ # Autoencoders
+ self.vae = None
+
+ # Scheduler
+ self.scheduler = None
+
+ # Optimizer & LR scheduler
+ self.optimizer = None
+ self.lr_scheduler = None
+
+ # Checkpoint manager
+ self.checkpointer = None
+
+ self.model_specification = model_specification
+ self._are_condition_models_loaded = False
+
+ model_specification._trainer_init(
+ args.frame_conditioning_type, args.frame_conditioning_index, args.frame_conditioning_concatenate_mask
+ )
+
+ def run(self) -> None:
+ try:
+ self._prepare_models()
+ self._prepare_trainable_parameters()
+ self._prepare_for_training()
+ self._prepare_dataset()
+ self._prepare_checkpointing()
+ self._train()
+ # self._evaluate()
+ except Exception as e:
+ logger.error(f"Error during training: {e}")
+ self.state.parallel_backend.destroy()
+ raise e
+
+ def _prepare_models(self) -> None:
+ logger.info("Initializing models")
+
+ # TODO(aryan): allow multiple control conditions instead of just one if there's a use case for it
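+ # The control latents are concatenated with the noisy latents along the channel dimension,
+ # so the model's control injection layer needs twice the original number of input features
+ # (plus one extra channel when the frame-conditioning mask is concatenated as well).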
+ new_in_features = self.model_specification._original_control_layer_in_features * 2
+ if self.args.frame_conditioning_concatenate_mask:
+ new_in_features += 1
+ diffusion_components = self.model_specification.load_diffusion_models(new_in_features)
+ self._set_components(diffusion_components)
+
+ if self.state.parallel_backend.pipeline_parallel_enabled:
+ raise NotImplementedError(
+ "Pipeline parallelism is not supported yet. This will be supported in the future."
+ )
+
+ def _prepare_trainable_parameters(self) -> None:
+ logger.info("Initializing trainable parameters")
+
+ parallel_backend = self.state.parallel_backend
+ model_spec = self.model_specification
+
+ if self.args.training_type == TrainingType.CONTROL_FULL_FINETUNE:
+ logger.info("Finetuning transformer with no additional parameters")
+ utils.set_requires_grad([self.transformer], True)
+ else:
+ logger.info("Finetuning transformer with PEFT parameters")
+ utils.set_requires_grad([self.transformer], False)
+
+ # Layerwise upcasting must be applied before adding the LoRA adapter.
+ # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
+ # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
+ if (
+ self.args.training_type == TrainingType.CONTROL_LORA
+ and "transformer" in self.args.layerwise_upcasting_modules
+ ):
+ apply_layerwise_casting(
+ self.transformer,
+ storage_dtype=self.args.layerwise_upcasting_storage_dtype,
+ compute_dtype=self.args.transformer_dtype,
+ skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
+ non_blocking=True,
+ )
+
+ transformer_lora_config = None
+ if self.args.training_type == TrainingType.CONTROL_LORA:
+ transformer_lora_config = LoraConfig(
+ r=self.args.rank,
+ lora_alpha=self.args.lora_alpha,
+ init_lora_weights=True,
+ target_modules=self._get_lora_target_modules(),
+ rank_pattern={
+ model_spec.control_injection_layer_name: model_spec._original_control_layer_out_features
+ },
+ alpha_pattern={
+ model_spec.control_injection_layer_name: model_spec._original_control_layer_out_features
+ },
+ )
+ self.transformer.add_adapter(transformer_lora_config)
+
+ if self.args.train_qk_norm:
+ qk_norm_identifiers = model_spec._qk_norm_identifiers
+ qk_norm_module_names, qk_norm_modules = [], []
+
+ for name, module in self.transformer.named_modules():
+ regex_match = any(re.search(identifier, name) is not None for identifier in qk_norm_identifiers)
+ is_parametric = len(list(module.parameters())) > 0
+ if regex_match and is_parametric:
+ qk_norm_module_names.append(name)
+ qk_norm_modules.append(module)
+
+ if len(qk_norm_modules) > 0:
+ logger.info(f"Training QK norms for modules: {qk_norm_module_names}")
+ utils.set_requires_grad(qk_norm_modules, True)
+ else:
+ logger.warning(f"No QK norm modules found with identifiers: {qk_norm_identifiers}")
+
+ # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all
+ # parameters to be of the same dtype.
+ if parallel_backend.data_sharding_enabled:
+ self.transformer.to(dtype=self.args.transformer_dtype)
+ else:
+ if self.args.training_type == TrainingType.CONTROL_LORA:
+ cast_training_params([self.transformer], dtype=torch.float32)
+
+ def _prepare_for_training(self) -> None:
+ # 1. Apply parallelism
+ parallel_backend = self.state.parallel_backend
+ model_specification = self.model_specification
+
+ if parallel_backend.context_parallel_enabled:
+ parallel_backend.apply_context_parallel(self.transformer, parallel_backend.get_mesh()["cp"])
+
+ if parallel_backend.tensor_parallel_enabled:
+ # TODO(aryan): handle fp8 from TorchAO here
+ model_specification.apply_tensor_parallel(
+ backend=parallel.ParallelBackendEnum.PTD,
+ device_mesh=parallel_backend.get_mesh()["tp"],
+ transformer=self.transformer,
+ )
+
+ # Enable gradient checkpointing
+ if self.args.gradient_checkpointing:
+ # TODO(aryan): support other checkpointing types
+ utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full")
+
+ # Apply torch.compile
+ self._maybe_torch_compile()
+
+ # Enable DDP, FSDP or HSDP
+ if parallel_backend.data_sharding_enabled:
+ # TODO(aryan): remove this when supported
+ if self.args.parallel_backend == "accelerate":
+ raise NotImplementedError("Data sharding is not supported with Accelerate yet.")
+
+ dp_method = "HSDP" if parallel_backend.data_replication_enabled else "FSDP"
+ logger.info(f"Applying {dp_method} on the model")
+
+ if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled:
+ dp_mesh_names = ("dp_replicate", "dp_shard_cp")
+ else:
+ dp_mesh_names = ("dp_shard_cp",)
+
+ parallel_backend.apply_fsdp2(
+ model=self.transformer,
+ param_dtype=self.args.transformer_dtype,
+ reduce_dtype=torch.float32,
+ output_dtype=None,
+ pp_enabled=parallel_backend.pipeline_parallel_enabled,
+ cpu_offload=False, # TODO(aryan): needs to be tested and allowed for enabling later
+ device_mesh=parallel_backend.get_mesh()[dp_mesh_names],
+ )
+ elif parallel_backend.data_replication_enabled:
+ if parallel_backend.get_mesh().ndim > 1:
+ raise ValueError("DDP not supported for > 1D parallelism")
+ parallel_backend.apply_ddp(self.transformer, parallel_backend.get_mesh())
+ else:
+ parallel_backend.prepare_model(self.transformer)
+
+ self._move_components_to_device()
+
+ # 2. Prepare optimizer and lr scheduler
+ # For training LoRAs, we can be a little more optimal. Currently, the OptimizerWrapper only accepts torch::nn::Module.
+ # This causes us to loop over all the parameters (even ones that don't require gradients, as in LoRA) at each optimizer
+ # step. This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99)
+ # but can be optimized a bit by maybe creating a simple wrapper module encompassing the actual parameters that require
+ # gradients. TODO(aryan): look into it in the future.
+ model_parts = [self.transformer]
+ self.state.num_trainable_parameters = sum(
+ p.numel() for m in model_parts for p in m.parameters() if p.requires_grad
+ )
+
+ # Setup distributed optimizer and lr scheduler
+ logger.info("Initializing optimizer and lr scheduler")
+ self.state.train_state = TrainState()
+ self.optimizer = optimizer.get_optimizer(
+ parallel_backend=self.args.parallel_backend,
+ name=self.args.optimizer,
+ model_parts=model_parts,
+ learning_rate=self.args.lr,
+ beta1=self.args.beta1,
+ beta2=self.args.beta2,
+ beta3=self.args.beta3,
+ epsilon=self.args.epsilon,
+ weight_decay=self.args.weight_decay,
+ fused=False,
+ )
+ self.lr_scheduler = optimizer.get_lr_scheduler(
+ parallel_backend=self.args.parallel_backend,
+ name=self.args.lr_scheduler,
+ optimizer=self.optimizer,
+ num_warmup_steps=self.args.lr_warmup_steps,
+ num_training_steps=self.args.train_steps,
+ # TODO(aryan): handle last_epoch
+ )
+ self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler)
+
+ # 3. Initialize trackers, directories and repositories
+ self._init_logging()
+ self._init_trackers()
+ self._init_directories_and_repositories()
+
+ def _prepare_dataset(self) -> None:
+ logger.info("Initializing dataset and dataloader")
+
+ with open(self.args.dataset_config, "r") as file:
+ dataset_configs = json.load(file)["datasets"]
+ logger.info(f"Training configured to use {len(dataset_configs)} datasets")
+
+ datasets = []
+ for config in dataset_configs:
+ data_root = config.pop("data_root", None)
+ dataset_file = config.pop("dataset_file", None)
+ dataset_type = config.pop("dataset_type")
+ caption_options = config.pop("caption_options", {})
+
+ if data_root is not None and dataset_file is not None:
+ raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.")
+
+ dataset_name_or_root = data_root or dataset_file
+ dataset = data.initialize_dataset(
+ dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options
+ )
+
+ if not dataset._precomputable_once and self.args.precomputation_once:
+ raise ValueError(
+ f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once."
+ )
+
+ logger.info(f"Initialized dataset: {dataset_name_or_root}")
+ dataset = self.state.parallel_backend.prepare_dataset(dataset)
+ dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config)
+ datasets.append(dataset)
+
+ dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True)
+ dataset = IterableControlDataset(dataset, self.args.control_type, self.state.parallel_backend.device)
+ dataloader = self.state.parallel_backend.prepare_dataloader(
+ dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory
+ )
+
+ self.dataset = dataset
+ self.dataloader = dataloader
+
+ def _prepare_checkpointing(self) -> None:
+ parallel_backend = self.state.parallel_backend
+
+ def save_model_hook(state_dict: Dict[str, Any]) -> None:
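+ # Called by the checkpointer whenever a checkpoint is written. On the main process it extracts
+ # either the PEFT LoRA state dict (optionally together with the trained QK-norm parameters) or
+ # the full model weights, and saves them in a directly loadable format alongside the
+ # distributed training checkpoint.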
+ state_dict = utils.get_unwrapped_model_state_dict(state_dict)
+ if parallel_backend.is_main_process:
+ if self.args.training_type == TrainingType.CONTROL_LORA:
+ state_dict = get_peft_model_state_dict(self.transformer, state_dict)
+ qk_norm_state_dict = None
+ if self.args.train_qk_norm:
+ qk_norm_state_dict = {
+ name: parameter
+ for name, parameter in state_dict.items()
+ if any(
+ re.search(identifier, name) is not None
+ for identifier in self.model_specification._qk_norm_identifiers
+ )
+ and parameter.numel() > 0
+ }
+ if len(qk_norm_state_dict) == 0:
+ qk_norm_state_dict = None
+ # fmt: off
+ metadata = {
+ "r": self.args.rank,
+ "lora_alpha": self.args.lora_alpha,
+ "init_lora_weights": True,
+ "target_modules": self._get_lora_target_modules(),
+ "rank_pattern": {self.model_specification.control_injection_layer_name: self.model_specification._original_control_layer_out_features},
+ "alpha_pattern": {self.model_specification.control_injection_layer_name: self.model_specification._original_control_layer_out_features},
+ }
+ metadata = {"lora_config": json.dumps(metadata, indent=4)}
+ # fmt: on
+ self.model_specification._save_lora_weights(
+ os.path.join(self.args.output_dir, "lora_weights", f"{self.state.train_state.step:06d}"),
+ state_dict,
+ qk_norm_state_dict,
+ self.scheduler,
+ metadata,
+ )
+ elif self.args.training_type == TrainingType.CONTROL_FULL_FINETUNE:
+ self.model_specification._save_model(
+ os.path.join(self.args.output_dir, "model_weights", f"{self.state.train_state.step:06d}"),
+ self.transformer,
+ state_dict,
+ self.scheduler,
+ )
+ parallel_backend.wait_for_everyone()
+
+ enable_state_checkpointing = self.args.checkpointing_steps > 0
+ self.checkpointer = parallel_backend.get_checkpointer(
+ dataloader=self.dataloader,
+ model_parts=[self.transformer],
+ optimizers=self.optimizer,
+ schedulers=self.lr_scheduler,
+ states={"train_state": self.state.train_state},
+ checkpointing_steps=self.args.checkpointing_steps,
+ checkpointing_limit=self.args.checkpointing_limit,
+ output_dir=self.args.output_dir,
+ enable=enable_state_checkpointing,
+ _callback_fn=save_model_hook,
+ )
+
+ resume_from_checkpoint = self.args.resume_from_checkpoint
+ if resume_from_checkpoint == "latest":
+ resume_from_checkpoint = -1
+ if resume_from_checkpoint is not None:
+ self.checkpointer.load(resume_from_checkpoint)
+
+ def _train(self) -> None:
+ logger.info("Starting training")
+
+ parallel_backend = self.state.parallel_backend
+ train_state = self.state.train_state
+ device = parallel_backend.device
+ dtype = self.args.transformer_dtype
+
+ memory_statistics = utils.get_memory_statistics()
+ logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
+
+ global_batch_size = self.args.batch_size * parallel_backend._dp_degree
+ info = {
+ "trainable parameters": self.state.num_trainable_parameters,
+ "train steps": self.args.train_steps,
+ "per-replica batch size": self.args.batch_size,
+ "global batch size": global_batch_size,
+ "gradient accumulation steps": self.args.gradient_accumulation_steps,
+ }
+ logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
+
+ progress_bar = tqdm(
+ range(0, self.args.train_steps),
+ initial=train_state.step,
+ desc="Training steps",
+ disable=not parallel_backend.is_local_main_process,
+ )
+
+ generator = torch.Generator(device=device)
+ if self.args.seed is not None:
+ generator = generator.manual_seed(self.args.seed)
+ self.state.generator = generator
+
+ scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler)
+ scheduler_sigmas = (
+ scheduler_sigmas.to(device=device, dtype=torch.float32) if scheduler_sigmas is not None else None
+ )
+ scheduler_alphas = utils.get_scheduler_alphas(self.scheduler)
+ scheduler_alphas = (
+ scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None
+ )
+ # timesteps_buffer = []
+
+ self.transformer.train()
+ data_iterator = iter(self.dataloader)
+
+ compute_posterior = False if self.args.enable_precomputation else (not self.args.precomputation_once)
+ preprocessor = data.initialize_preprocessor(
+ rank=parallel_backend.rank,
+ world_size=parallel_backend.world_size,
+ num_items=self.args.precomputation_items if self.args.enable_precomputation else 1,
+ processor_fn={
+ "condition": self.model_specification.prepare_conditions,
+ "latent": functools.partial(
+ self.model_specification.prepare_latents, compute_posterior=compute_posterior
+ ),
+ },
+ save_dir=self.args.precomputation_dir,
+ enable_precomputation=self.args.enable_precomputation,
+ enable_reuse=self.args.precomputation_reuse,
+ )
+ condition_iterator: Iterable[Dict[str, Any]] = None
+ latent_iterator: Iterable[Dict[str, Any]] = None
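+ # The resolution sampler buffers preprocessed condition/latent pairs and only yields a batch
+ # once `batch_size` items with matching resolution keys are available, so tensors can be
+ # collated without mixing resolutions within a batch.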
+ sampler = data.ResolutionSampler(
+ batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys
+ )
+ requires_gradient_step = True
+ accumulated_loss = 0.0
+
+ while (
+ train_state.step < self.args.train_steps and train_state.observed_data_samples < self.args.max_data_samples
+ ):
+ # 1. Load & preprocess data if required
+ if preprocessor.requires_data:
+ condition_iterator, latent_iterator = self._prepare_data(preprocessor, data_iterator)
+
+ # 2. Prepare batch
+ with self.tracker.timed("timing/batch_preparation"):
+ try:
+ condition_item = next(condition_iterator)
+ latent_item = next(latent_iterator)
+ sampler.consume(condition_item, latent_item)
+ except StopIteration:
+ if requires_gradient_step:
+ self.optimizer.step()
+ self.lr_scheduler.step()
+ requires_gradient_step = False
+ logger.info("Data exhausted. Exiting training loop.")
+ break
+
+ if sampler.is_ready:
+ condition_batch, latent_batch = sampler.get_batch()
+ condition_model_conditions = self.model_specification.collate_conditions(condition_batch)
+ latent_model_conditions = self.model_specification.collate_latents(latent_batch)
+ else:
+ continue
+
+ train_state.step += 1
+ train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree
+
+ logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
+
+ latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype)
+ condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype)
+ latent_model_conditions = utils.make_contiguous(latent_model_conditions)
+ condition_model_conditions = utils.make_contiguous(condition_model_conditions)
+
+ # 3. Forward pass
+ sigmas = utils.prepare_sigmas(
+ scheduler=self.scheduler,
+ sigmas=scheduler_sigmas,
+ batch_size=self.args.batch_size,
+ num_train_timesteps=self.scheduler.config.num_train_timesteps,
+ flow_weighting_scheme=self.args.flow_weighting_scheme,
+ flow_logit_mean=self.args.flow_logit_mean,
+ flow_logit_std=self.args.flow_logit_std,
+ flow_mode_scale=self.args.flow_mode_scale,
+ device=device,
+ generator=self.state.generator,
+ )
+ sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim)
+
+ # NOTE: for planned refactor, make sure that forward and backward pass run under the context.
+ # If only forward runs under context, backward will most likely fail when using activation checkpointing
+ with self.attention_provider_ctx(training=True):
+ with self.tracker.timed("timing/forward"):
+ pred, target, sigmas = self.model_specification.forward(
+ transformer=self.transformer,
+ scheduler=self.scheduler,
+ condition_model_conditions=condition_model_conditions,
+ latent_model_conditions=latent_model_conditions,
+ sigmas=sigmas,
+ compute_posterior=compute_posterior,
+ )
+
+ timesteps = (sigmas * 1000.0).long()
+ weights = utils.prepare_loss_weights(
+ scheduler=self.scheduler,
+ alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
+ sigmas=sigmas,
+ flow_weighting_scheme=self.args.flow_weighting_scheme,
+ )
+ weights = utils.expand_tensor_dims(weights, pred.ndim)
+
+ # 4. Compute loss & backward pass
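+ # The objective is a sigma-weighted MSE between the model prediction and its target,
+ # i.e. mean(w(sigma) * (pred - target)^2), with w determined by the flow weighting scheme.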
+ with self.tracker.timed("timing/backward"):
+ loss = weights.float() * (pred.float() - target.float()).pow(2)
+ # Average loss across all but batch dimension (for per-batch debugging in case needed)
+ loss = loss.mean(list(range(1, loss.ndim)))
+ # Average loss across batch dimension
+ loss = loss.mean()
+ if self.args.gradient_accumulation_steps > 1:
+ loss = loss / self.args.gradient_accumulation_steps
+ loss.backward()
+
+ accumulated_loss += loss.detach().item()
+ requires_gradient_step = True
+
+ # 5. Clip gradients
+ model_parts = [self.transformer]
+ grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases(
+ [p for m in model_parts for p in m.parameters()],
+ self.args.max_grad_norm,
+ foreach=True,
+ pp_mesh=parallel_backend.get_mesh()["pp"] if parallel_backend.pipeline_parallel_enabled else None,
+ )
+
+ # 6. Step optimizer & log metrics
+ logs = {}
+
+ if train_state.step % self.args.gradient_accumulation_steps == 0:
+ # TODO(aryan): revisit no_sync() for FSDP
+ with self.tracker.timed("timing/optimizer_step"):
+ self.optimizer.step()
+ self.lr_scheduler.step()
+ self.optimizer.zero_grad()
+
+ if grad_norm is not None:
+ grad_norm = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item()
+ if (
+ parallel_backend.data_replication_enabled
+ or parallel_backend.data_sharding_enabled
+ or parallel_backend.context_parallel_enabled
+ ):
+ dp_cp_mesh = parallel_backend.get_mesh()["dp_cp"]
+ if grad_norm is not None:
+ grad_norm = parallel.dist_mean(torch.tensor([grad_norm], device=device), dp_cp_mesh)
+ global_avg_loss, global_max_loss = (
+ parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
+ parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
+ )
+ else:
+ global_avg_loss = global_max_loss = accumulated_loss
+
+ logs["train/global_avg_loss"] = global_avg_loss
+ logs["train/global_max_loss"] = global_max_loss
+ if grad_norm is not None:
+ logs["train/grad_norm"] = grad_norm
+ train_state.global_avg_losses.append(global_avg_loss)
+ train_state.global_max_losses.append(global_max_loss)
+ accumulated_loss = 0.0
+ requires_gradient_step = False
+
+ progress_bar.update(1)
+ progress_bar.set_postfix(logs)
+
+ # timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()])
+
+ if train_state.step % self.args.logging_steps == 0:
+ # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts
+ # TODO(aryan): causes NCCL hang for some reason. look into later
+ # logs.update(self.lr_scheduler.get_last_lr())
+
+ # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"])
+ # logs["timesteps"] = wandb.plot.scatter(
+ # timesteps_table, "step", "timesteps", title="Timesteps distribution"
+ # )
+ # timesteps_buffer = []
+
+ logs["train/observed_data_samples"] = train_state.observed_data_samples
+
+ parallel_backend.log(logs, step=train_state.step)
+ train_state.log_steps.append(train_state.step)
+
+ # 7. Save checkpoint if required
+ with self.tracker.timed("timing/checkpoint"):
+ self.checkpointer.save(
+ step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process
+ )
+
+ # 8. Perform validation if required
+ if train_state.step % self.args.validation_steps == 0:
+ self._validate(step=train_state.step, final_validation=False)
+
+ # 9. Final checkpoint, validation & cleanup
+ self.checkpointer.save(
+ train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process
+ )
+ parallel_backend.wait_for_everyone()
+ self._validate(step=train_state.step, final_validation=True)
+
+ self._delete_components()
+ memory_statistics = utils.get_memory_statistics()
+ logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
+
+ # 10. Upload artifacts to hub
+ if parallel_backend.is_main_process and self.args.push_to_hub:
+ upload_folder(
+ repo_id=self.state.repo_id,
+ folder_path=self.args.output_dir,
+ ignore_patterns=[f"{self.checkpointer._prefix}_*"],
+ )
+
+ parallel_backend.destroy()
+
+ def _validate(self, step: int, final_validation: bool = False) -> None:
+ if self.args.validation_dataset_file is None:
+ return
+
+ logger.info("Starting validation")
+
+ # 1. Load validation dataset
+ parallel_backend = self.state.parallel_backend
+ dataset = data.ValidationDataset(self.args.validation_dataset_file)
+
+ # Hack to make accelerate work. TODO(aryan): refactor
+ if parallel_backend._dp_degree > 1:
+ dp_mesh = parallel_backend.get_mesh()["dp"]
+ dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
+ dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
+ else:
+ dp_mesh = None
+ dp_local_rank, dp_world_size = parallel_backend.local_rank, 1
+
+ dataset = ValidationControlDataset(dataset, self.args.control_type, parallel_backend.device)
+ validation_dataloader = data.DPDataLoader(
+ dp_local_rank,
+ dataset,
+ batch_size=1,
+ num_workers=self.args.dataloader_num_workers,
+ collate_fn=lambda items: items,
+ )
+ data_iterator = iter(validation_dataloader)
+ main_process_prompts_to_filenames = {} # Used to save model card
+ all_processes_artifacts = [] # Used to gather artifacts from all processes
+
+ memory_statistics = utils.get_memory_statistics()
+ logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
+
+ seed = self.args.seed if self.args.seed is not None else 0
+ generator = torch.Generator(device=parallel_backend.device).manual_seed(seed)
+ pipeline = self._init_pipeline(final_validation=final_validation)
+
+ # 2. Run validation
+ # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we
+ # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset
+ # size is not divisible by dp_shards.
+ self.transformer.eval()
+ while True:
+ validation_data = next(data_iterator, None)
+ if validation_data is None:
+ break
+
+ validation_data = validation_data[0]
+ with self.attention_provider_ctx(training=False):
+ validation_artifacts = self.model_specification.validation(
+ pipeline=pipeline, generator=generator, **validation_data
+ )
+
+ if dp_local_rank != 0:
+ continue
+
+ PROMPT = validation_data["prompt"]
+ IMAGE = validation_data.get("image", None)
+ VIDEO = validation_data.get("video", None)
+ CONTROL_IMAGE = validation_data.get("control_image", None)
+ CONTROL_VIDEO = validation_data.get("control_video", None)
+ EXPORT_FPS = validation_data.get("export_fps", 30)
+
+ # 2.1. If there are any initial images or videos, they will be logged to keep track of them as
+ # conditioning for generation.
+ prompt_filename = utils.string_to_filename(PROMPT)[:25]
+ artifacts = {
+ "input_image": data.ImageArtifact(value=IMAGE),
+ "input_video": data.VideoArtifact(value=VIDEO),
+ "control_image": data.ImageArtifact(value=CONTROL_IMAGE),
+ "control_video": data.VideoArtifact(value=CONTROL_VIDEO),
+ }
+
+ # 2.2. Track the artifacts generated from validation
+ for i, validation_artifact in enumerate(validation_artifacts):
+ if validation_artifact.value is None:
+ continue
+ artifacts.update({f"artifact_{i}": validation_artifact})
+
+ # 2.3. Save the artifacts to the output directory and create appropriate logging objects
+ # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
+ for index, (key, artifact) in enumerate(list(artifacts.items())):
+ assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))
+ if artifact.value is None:
+ continue
+
+ time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension
+ filename = "validation-" if not final_validation else "final-"
+ filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}"
+
+ if parallel_backend.is_main_process and ext in ["mp4", "jpg", "jpeg", "png"]:
+ main_process_prompts_to_filenames[PROMPT] = filename
+
+ caption = PROMPT
+ if key == "control_image":
+ filename = f"control_image-{filename}"
+ caption = f"[control] {caption}"
+ elif key == "control_video":
+ filename = f"control_video-{filename}"
+ caption = f"[control] {caption}"
+
+ output_filename = os.path.join(self.args.output_dir, filename)
+
+ if isinstance(artifact, data.ImageArtifact):
+ artifact.value.save(output_filename)
+ all_processes_artifacts.append(wandb.Image(output_filename, caption=caption))
+ elif isinstance(artifact, data.VideoArtifact):
+ export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
+ all_processes_artifacts.append(wandb.Video(output_filename, caption=caption))
+
+ # 3. Cleanup & log artifacts
+ parallel_backend.wait_for_everyone()
+
+ memory_statistics = utils.get_memory_statistics()
+ logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+
+ # Remove all hooks that might have been added during pipeline initialization to the models
+ pipeline.remove_all_hooks()
+ del pipeline
+ module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
+ if self.args.enable_precomputation:
+ self._delete_components(module_names)
+ torch.cuda.reset_peak_memory_stats(parallel_backend.device)
+
+ # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts.
+ all_artifacts = [None] * dp_world_size
+ if dp_world_size > 1:
+ torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts)
+ else:
+ all_artifacts = [all_processes_artifacts]
+ all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts]
+
+ if parallel_backend.is_main_process:
+ tracker_key = "final" if final_validation else "validation"
+ artifact_log_dict = {}
+
+ image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
+ if len(image_artifacts) > 0:
+ artifact_log_dict["images"] = image_artifacts
+ video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+ if len(video_artifacts) > 0:
+ artifact_log_dict["videos"] = video_artifacts
+ parallel_backend.log({tracker_key: artifact_log_dict}, step=step)
+
+ if self.args.push_to_hub and final_validation:
+ video_filenames = list(main_process_prompts_to_filenames.values())
+ prompts = list(main_process_prompts_to_filenames.keys())
+ utils.save_model_card(
+ args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts
+ )
+
+ parallel_backend.wait_for_everyone()
+ if not final_validation:
+ self._move_components_to_device()
+ self.transformer.train()
+
+ def _evaluate(self) -> None:
+ raise NotImplementedError("Evaluation has not been implemented yet.")
+
+ def _init_directories_and_repositories(self) -> None:
+ if self.state.parallel_backend.is_main_process:
+ self.args.output_dir = Path(self.args.output_dir)
+ self.args.output_dir.mkdir(parents=True, exist_ok=True)
+ self.state.output_dir = Path(self.args.output_dir)
+
+ if self.args.push_to_hub:
+ repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
+ self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
+
+ def _move_components_to_device(
+ self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None
+ ) -> None:
+ if device is None:
+ device = self.state.parallel_backend.device
+ if components is None:
+ components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae]
+ components = utils.get_non_null_items(components)
+ components = list(filter(lambda x: hasattr(x, "to"), components))
+ for component in components:
+ component.to(device)
+
+ def _set_components(self, components: Dict[str, Any]) -> None:
+ for component_name in self._all_component_names:
+ existing_component = getattr(self, component_name, None)
+ new_component = components.get(component_name, existing_component)
+ setattr(self, component_name, new_component)
+
+ def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
+ if component_names is None:
+ component_names = self._all_component_names
+ for component_name in component_names:
+ setattr(self, component_name, None)
+ utils.free_memory()
+ utils.synchronize_device()
+
+ def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
+ parallel_backend = self.state.parallel_backend
+ module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"]
+
+ if not final_validation:
+ module_names.remove("transformer")
+ pipeline = self.model_specification.load_pipeline(
+ tokenizer=self.tokenizer,
+ tokenizer_2=self.tokenizer_2,
+ tokenizer_3=self.tokenizer_3,
+ text_encoder=self.text_encoder,
+ text_encoder_2=self.text_encoder_2,
+ text_encoder_3=self.text_encoder_3,
+ # TODO(aryan): handle unwrapping for compiled modules
+ # transformer=utils.unwrap_model(accelerator, self.transformer),
+ transformer=self.transformer,
+ vae=self.vae,
+ enable_slicing=self.args.enable_slicing,
+ enable_tiling=self.args.enable_tiling,
+ enable_model_cpu_offload=self.args.enable_model_cpu_offload,
+ training=True,
+ )
+ else:
+ self._delete_components()
+
+ # TODO(aryan): allow multiple control conditions instead of just one if there's a use case for it
+ new_in_features = self.model_specification._original_control_layer_in_features * 2
+ if self.args.frame_conditioning_concatenate_mask:
+ new_in_features += 1
+ transformer = self.model_specification.load_diffusion_models(new_in_features)["transformer"]
+
+ pipeline = self.model_specification.load_pipeline(
+ transformer=transformer,
+ enable_slicing=self.args.enable_slicing,
+ enable_tiling=self.args.enable_tiling,
+ enable_model_cpu_offload=self.args.enable_model_cpu_offload,
+ training=False,
+ device=parallel_backend.device,
+ )
+
+ # Load the LoRA weights if performing LoRA finetuning
+ if self.args.training_type == TrainingType.CONTROL_LORA:
+ load_lora_weights(
+ pipeline, os.path.join(self.args.output_dir, "lora_weights", f"{self.state.train_state.step:06d}")
+ )
+ norm_state_dict_path = os.path.join(
+ self.args.output_dir,
+ "lora_weights",
+ f"{self.state.train_state.step:06d}",
+ "norm_state_dict.safetensors",
+ )
+ if self.args.train_qk_norm and os.path.exists(norm_state_dict_path):
+ norm_state_dict = safetensors.torch.load_file(norm_state_dict_path, parallel_backend.device)
+ pipeline.transformer.load_state_dict(norm_state_dict, strict=False)
+
+ components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names}
+ self._set_components(components)
+ if not self.args.enable_model_cpu_offload:
+ self._move_components_to_device(list(components.values()))
+ self._maybe_torch_compile()
+ return pipeline
+
+ def _prepare_data(
+ self,
+ preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor],
+ data_iterator,
+ ):
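+ # Two modes are supported:
+ #   - Precomputation disabled: the text encoders and VAE stay resident alongside the
+ #     transformer and embeddings are computed on the fly for each batch.
+ #   - Precomputation enabled: the condition models and then the latent models are loaded one
+ #     group at a time, used to precompute `precomputation_items` embeddings, and freed again
+ #     before training continues.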
+ if not self.args.enable_precomputation:
+ if not self._are_condition_models_loaded:
+ logger.info(
+ "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs."
+ )
+ condition_components = self.model_specification.load_condition_models()
+ latent_components = self.model_specification.load_latent_models()
+ all_components = {**condition_components, **latent_components}
+ self._set_components(all_components)
+ self._move_components_to_device(list(all_components.values()))
+ utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
+ self._maybe_torch_compile()
+ else:
+ condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))}
+ latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))}
+
+ condition_iterator = preprocessor.consume(
+ "condition",
+ components=condition_components,
+ data_iterator=data_iterator,
+ generator=self.state.generator,
+ cache_samples=True,
+ )
+ latent_iterator = preprocessor.consume(
+ "latent",
+ components=latent_components,
+ data_iterator=data_iterator,
+ generator=self.state.generator,
+ use_cached_samples=True,
+ drop_samples=True,
+ )
+
+ self._are_condition_models_loaded = True
+ else:
+ logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
+
+ parallel_backend = self.state.parallel_backend
+ if parallel_backend.world_size == 1:
+ self._move_components_to_device([self.transformer], "cpu")
+ utils.free_memory()
+ utils.synchronize_device()
+ torch.cuda.reset_peak_memory_stats(parallel_backend.device)
+
+ consume_fn = preprocessor.consume_once if self.args.precomputation_once else preprocessor.consume
+
+ # Prepare condition iterators
+ condition_components, component_names, component_modules = {}, [], []
+ if not self.args.precomputation_reuse:
+ condition_components = self.model_specification.load_condition_models()
+ component_names = list(condition_components.keys())
+ component_modules = list(condition_components.values())
+ self._set_components(condition_components)
+ self._move_components_to_device(component_modules)
+ self._maybe_torch_compile()
+ condition_iterator = consume_fn(
+ "condition",
+ components=condition_components,
+ data_iterator=data_iterator,
+ generator=self.state.generator,
+ cache_samples=True,
+ )
+ self._delete_components(component_names)
+ del condition_components, component_names, component_modules
+
+ # Prepare latent iterators
+ latent_components, component_names, component_modules = {}, [], []
+ if not self.args.precomputation_reuse:
+ latent_components = self.model_specification.load_latent_models()
+ utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
+ component_names = list(latent_components.keys())
+ component_modules = list(latent_components.values())
+ self._set_components(latent_components)
+ self._move_components_to_device(component_modules)
+ self._maybe_torch_compile()
+ latent_iterator = consume_fn(
+ "latent",
+ components=latent_components,
+ data_iterator=data_iterator,
+ generator=self.state.generator,
+ use_cached_samples=True,
+ drop_samples=True,
+ )
+ self._delete_components(component_names)
+ del latent_components, component_names, component_modules
+
+ if parallel_backend.world_size == 1:
+ self._move_components_to_device([self.transformer])
+
+ return condition_iterator, latent_iterator
+
+ def _maybe_torch_compile(self):
+ for model_name, compile_scope in zip(self.args.compile_modules, self.args.compile_scopes):
+ model = getattr(self, model_name, None)
+ if model is not None:
+ logger.info(f"Applying torch.compile to '{model_name}' with scope '{compile_scope}'.")
+ compiled_model = utils.apply_compile(model, compile_scope)
+ setattr(self, model_name, compiled_model)
+
+ def _get_training_info(self) -> Dict[str, Any]:
+ info = self.args.to_dict()
+
+ # Removing flow matching arguments when not using flow-matching objective
+ diffusion_args = info.get("diffusion_arguments", {})
+ scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else ""
+ if scheduler_name != "FlowMatchEulerDiscreteScheduler":
+ filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
+ else:
+ filtered_diffusion_args = diffusion_args
+
+ info.update({"diffusion_arguments": filtered_diffusion_args})
+ return info
+
+ def _get_lora_target_modules(self):
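+ # The control injection layer is always appended to the user-provided LoRA target modules as an
+ # anchored regex. Combined with the rank/alpha pattern overrides in _prepare_trainable_parameters,
+ # this effectively trains the expanded input projection at full rank while the remaining targets
+ # use the configured low rank.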
+ target_modules = self.args.target_modules
+ if isinstance(target_modules, list):
+ target_modules = list(target_modules) # Make a copy to avoid modifying args
+ target_modules.append(f"^{self.model_specification.control_injection_layer_name}$")
+ if isinstance(target_modules, str):
+ target_modules = f"(^{self.model_specification.control_injection_layer_name}$)|({target_modules})"
+ return target_modules
+
+ # fmt: off
+ _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
+ _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"]
+ _latent_component_names = ["vae"]
+ _diffusion_component_names = ["transformer", "unet", "scheduler"]
+ # fmt: on
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/__init__.py b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..726049bd5a9059c0df70efeacc76ac9f3423315a
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/__init__.py
@@ -0,0 +1,2 @@
+from .config import SFTFullRankConfig, SFTLowRankConfig
+from .trainer import SFTTrainer
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/config.py b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c70c6503373ed7e6aaf8c2b60fc4ba0a0f0f81a6
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/config.py
@@ -0,0 +1,65 @@
+import argparse
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from finetrainers.utils import ArgsConfigMixin
+
+
+if TYPE_CHECKING:
+ from finetrainers.args import BaseArgs
+
+
+class SFTLowRankConfig(ArgsConfigMixin):
+ r"""
+ Configuration class for SFT low rank training.
+
+ Args:
+ rank (int):
+ Rank of the low rank approximation matrix.
+ lora_alpha (int):
+ The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices.
+ target_modules (`str` or `List[str]`):
+ Target modules for the low rank approximation matrices. Can be a regex string or a list of regex strings.
+ """
+
+ rank: int = 64
+ lora_alpha: int = 64
+ target_modules: Union[str, List[str]] = "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"
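+ # With these defaults the effective LoRA scaling is lora_alpha / rank = 64 / 64 = 1.0, and the
+ # regex targets the attention projections (to_q, to_k, to_v, to_out.0) of both
+ # transformer_blocks and single_transformer_blocks.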
+
+ def add_args(self, parser: argparse.ArgumentParser):
+ parser.add_argument("--rank", type=int, default=64)
+ parser.add_argument("--lora_alpha", type=int, default=64)
+ parser.add_argument(
+ "--target_modules",
+ type=str,
+ nargs="+",
+ default=["(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"],
+ )
+
+ def validate_args(self, args: "BaseArgs"):
+ assert self.rank > 0, "Rank must be a positive integer."
+ assert self.lora_alpha > 0, "lora_alpha must be a positive integer."
+
+ def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+ mapped_args.rank = argparse_args.rank
+ mapped_args.lora_alpha = argparse_args.lora_alpha
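+ # peft treats a string `target_modules` as a regex and a list as exact module-name suffixes,
+ # so a single-element list is unwrapped back to a plain string to preserve regex matching.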
+ mapped_args.target_modules = (
+ argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules
+ )
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {"rank": self.rank, "lora_alpha": self.lora_alpha, "target_modules": self.target_modules}
+
+
+class SFTFullRankConfig(ArgsConfigMixin):
+ r"""
+ Configuration class for SFT full rank training.
+ """
+
+ def add_args(self, parser: argparse.ArgumentParser):
+ pass
+
+ def validate_args(self, args: "BaseArgs"):
+ pass
+
+ def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+ pass
diff --git a/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/trainer.py b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..789545963e533b8ee159a528bece5185f13634cc
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/trainer/sft_trainer/trainer.py
@@ -0,0 +1,946 @@
+import functools
+import json
+import os
+import time
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+import datasets.distributed
+import torch
+import wandb
+from diffusers import DiffusionPipeline
+from diffusers.hooks import apply_layerwise_casting
+from diffusers.training_utils import cast_training_params
+from diffusers.utils import export_to_video
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, get_peft_model_state_dict
+from tqdm import tqdm
+
+from finetrainers import data, logging, models, optimizer, parallel, utils
+from finetrainers.args import BaseArgsType
+from finetrainers.config import TrainingType
+from finetrainers.state import TrainState
+
+from ..base import Trainer
+from .config import SFTFullRankConfig, SFTLowRankConfig
+
+
+ArgsType = Union[BaseArgsType, SFTFullRankConfig, SFTLowRankConfig]
+
+logger = logging.get_logger()
+
+
+class SFTTrainer(Trainer):
+ def __init__(self, args: ArgsType, model_specification: models.ModelSpecification) -> None:
+ super().__init__(args)
+
+ # Tokenizers
+ self.tokenizer = None
+ self.tokenizer_2 = None
+ self.tokenizer_3 = None
+
+ # Text encoders
+ self.text_encoder = None
+ self.text_encoder_2 = None
+ self.text_encoder_3 = None
+
+ # Image encoders
+ self.image_encoder = None
+ self.image_processor = None
+
+ # Denoisers
+ self.transformer = None
+ self.unet = None
+
+ # Autoencoders
+ self.vae = None
+
+ # Scheduler
+ self.scheduler = None
+
+ # Optimizer & LR scheduler
+ self.optimizer = None
+ self.lr_scheduler = None
+
+ # Checkpoint manager
+ self.checkpointer = None
+
+ self.model_specification = model_specification
+ self._are_condition_models_loaded = False
+
+ def run(self) -> None:
+ try:
+ self._prepare_models()
+ self._prepare_trainable_parameters()
+ self._prepare_for_training()
+ self._prepare_dataset()
+ self._prepare_checkpointing()
+ self._train()
+ # self._evaluate()
+ except Exception as e:
+ logger.error(f"Error during training: {e}")
+ self.state.parallel_backend.destroy()
+ raise e
+
+ def _prepare_models(self) -> None:
+ logger.info("Initializing models")
+
+ diffusion_components = self.model_specification.load_diffusion_models()
+ self._set_components(diffusion_components)
+
+ if self.state.parallel_backend.pipeline_parallel_enabled:
+ raise NotImplementedError(
+ "Pipeline parallelism is not supported yet. This will be supported in the future."
+ )
+
+ def _prepare_trainable_parameters(self) -> None:
+ logger.info("Initializing trainable parameters")
+
+ parallel_backend = self.state.parallel_backend
+
+ if self.args.training_type == TrainingType.FULL_FINETUNE:
+ logger.info("Finetuning transformer with no additional parameters")
+ utils.set_requires_grad([self.transformer], True)
+ else:
+ logger.info("Finetuning transformer with PEFT parameters")
+ utils.set_requires_grad([self.transformer], False)
+
+ # Layerwise upcasting must be applied before adding the LoRA adapter.
+ # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
+ # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
+ if self.args.training_type == TrainingType.LORA and "transformer" in self.args.layerwise_upcasting_modules:
+ apply_layerwise_casting(
+ self.transformer,
+ storage_dtype=self.args.layerwise_upcasting_storage_dtype,
+ compute_dtype=self.args.transformer_dtype,
+ skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
+ non_blocking=True,
+ )
+
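+ # Unlike the control trainer, no rank/alpha pattern overrides are needed here because the SFT
+ # trainer does not expand any input projection layers.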
+ transformer_lora_config = None
+ if self.args.training_type == TrainingType.LORA:
+ transformer_lora_config = LoraConfig(
+ r=self.args.rank,
+ lora_alpha=self.args.lora_alpha,
+ init_lora_weights=True,
+ target_modules=self.args.target_modules,
+ )
+ self.transformer.add_adapter(transformer_lora_config)
+
+ # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all
+ # parameters to be of the same dtype.
+ if parallel_backend.data_sharding_enabled:
+ self.transformer.to(dtype=self.args.transformer_dtype)
+ else:
+ if self.args.training_type == TrainingType.LORA:
+ cast_training_params([self.transformer], dtype=torch.float32)
+
+ def _prepare_for_training(self) -> None:
+ # 1. Apply parallelism
+ parallel_backend = self.state.parallel_backend
+ model_specification = self.model_specification
+
+ if parallel_backend.context_parallel_enabled:
+ parallel_backend.apply_context_parallel(self.transformer, parallel_backend.get_mesh()["cp"])
+
+ if parallel_backend.tensor_parallel_enabled:
+ # TODO(aryan): handle fp8 from TorchAO here
+ model_specification.apply_tensor_parallel(
+ backend=parallel.ParallelBackendEnum.PTD,
+ device_mesh=parallel_backend.get_mesh()["tp"],
+ transformer=self.transformer,
+ )
+
+ # Enable gradient checkpointing
+ if self.args.gradient_checkpointing:
+ # TODO(aryan): support other checkpointing types
+ utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full")
+
+ # Apply torch.compile
+ self._maybe_torch_compile()
+
+ # Enable DDP, FSDP or HSDP
+ if parallel_backend.data_sharding_enabled:
+ # TODO(aryan): remove this when supported
+ if self.args.parallel_backend == "accelerate":
+ raise NotImplementedError("Data sharding is not supported with Accelerate yet.")
+
+ dp_method = "HSDP" if parallel_backend.data_replication_enabled else "FSDP"
+ logger.info(f"Applying {dp_method} on the model")
+
+ if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled:
+ dp_mesh_names = ("dp_replicate", "dp_shard_cp")
+ else:
+ dp_mesh_names = ("dp_shard_cp",)
+
+ parallel_backend.apply_fsdp2(
+ model=self.transformer,
+ param_dtype=self.args.transformer_dtype,
+ reduce_dtype=torch.float32,
+ output_dtype=None,
+ pp_enabled=parallel_backend.pipeline_parallel_enabled,
+ cpu_offload=False, # TODO(aryan): needs to be tested and allowed for enabling later
+ device_mesh=parallel_backend.get_mesh()[dp_mesh_names],
+ )
+ elif parallel_backend.data_replication_enabled:
+ if parallel_backend.get_mesh().ndim > 1:
+ raise ValueError("DDP not supported for > 1D parallelism")
+ logger.info("Applying DDP to the model")
+ parallel_backend.apply_ddp(self.transformer, parallel_backend.get_mesh())
+ else:
+ parallel_backend.prepare_model(self.transformer)
+
+ self._move_components_to_device()
+
+ # 2. Prepare optimizer and lr scheduler
+ # For LoRA training, this could be made more efficient. Currently, the OptimizerWrapper only accepts torch.nn.Module,
+ # which means we loop over all parameters (even those that don't require gradients, as in LoRA) at each optimizer
+ # step. This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99)
+ # but could be optimized by creating a simple wrapper module that holds only the parameters requiring gradients.
+ # TODO(aryan): look into it in the future.
+ model_parts = [self.transformer]
+ self.state.num_trainable_parameters = sum(
+ p.numel() for m in model_parts for p in m.parameters() if p.requires_grad
+ )
+
+ # Setup distributed optimizer and lr scheduler
+ logger.info("Initializing optimizer and lr scheduler")
+ self.state.train_state = TrainState()
+ self.optimizer = optimizer.get_optimizer(
+ parallel_backend=self.args.parallel_backend,
+ name=self.args.optimizer,
+ model_parts=model_parts,
+ learning_rate=self.args.lr,
+ beta1=self.args.beta1,
+ beta2=self.args.beta2,
+ beta3=self.args.beta3,
+ epsilon=self.args.epsilon,
+ weight_decay=self.args.weight_decay,
+ fused=False,
+ )
+ self.lr_scheduler = optimizer.get_lr_scheduler(
+ parallel_backend=self.args.parallel_backend,
+ name=self.args.lr_scheduler,
+ optimizer=self.optimizer,
+ num_warmup_steps=self.args.lr_warmup_steps,
+ num_training_steps=self.args.train_steps,
+ # TODO(aryan): handle last_epoch
+ )
+ self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler)
+
+ # 3. Initialize trackers, directories and repositories
+ self._init_logging()
+ self._init_trackers()
+ self._init_directories_and_repositories()
+
+ def _prepare_dataset(self) -> None:
+ logger.info("Initializing dataset and dataloader")
+
+ with open(self.args.dataset_config, "r") as file:
+ dataset_configs = json.load(file)["datasets"]
+ logger.info(f"Training configured to use {len(dataset_configs)} datasets")
+
+ datasets = []
+ for config in dataset_configs:
+ data_root = config.pop("data_root", None)
+ dataset_file = config.pop("dataset_file", None)
+ dataset_type = config.pop("dataset_type")
+ caption_options = config.pop("caption_options", {})
+
+ if data_root is not None and dataset_file is not None:
+ raise ValueError("Only one of 'data_root' or 'dataset_file' may be provided in a dataset config, not both.")
+
+ dataset_name_or_root = data_root or dataset_file
+ dataset = data.initialize_dataset(
+ dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options
+ )
+
+ if not dataset._precomputable_once and self.args.precomputation_once:
+ raise ValueError(
+ f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once."
+ )
+
+ logger.info(f"Initialized dataset: {dataset_name_or_root}")
+ dataset = self.state.parallel_backend.prepare_dataset(dataset)
+ dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config)
+ datasets.append(dataset)
+
+ dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True)
+ dataloader = self.state.parallel_backend.prepare_dataloader(
+ dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory
+ )
+
+ self.dataset = dataset
+ self.dataloader = dataloader
+
+ def _prepare_checkpointing(self) -> None:
+ parallel_backend = self.state.parallel_backend
+
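+ # Callback passed to the checkpointer below. On the main process, it exports the LoRA adapter weights
+ # (or the full transformer weights) in a directly loadable format whenever a checkpoint is saved.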
+ def save_model_hook(state_dict: Dict[str, Any]) -> None:
+ state_dict = utils.get_unwrapped_model_state_dict(state_dict)
+ if parallel_backend.is_main_process:
+ if self.args.training_type == TrainingType.LORA:
+ state_dict = get_peft_model_state_dict(self.transformer, state_dict)
+ # fmt: off
+ metadata = {
+ "r": self.args.rank,
+ "lora_alpha": self.args.lora_alpha,
+ "init_lora_weights": True,
+ "target_modules": self.args.target_modules,
+ }
+ metadata = {"lora_config": json.dumps(metadata, indent=4)}
+ # fmt: on
+ self.model_specification._save_lora_weights(
+ os.path.join(self.args.output_dir, "lora_weights", f"{self.state.train_state.step:06d}"),
+ state_dict,
+ self.scheduler,
+ metadata,
+ )
+ elif self.args.training_type == TrainingType.FULL_FINETUNE:
+ self.model_specification._save_model(
+ os.path.join(self.args.output_dir, "model_weights", f"{self.state.train_state.step:06d}"),
+ self.transformer,
+ state_dict,
+ self.scheduler,
+ )
+ parallel_backend.wait_for_everyone()
+
+ enable_state_checkpointing = self.args.checkpointing_steps > 0
+ self.checkpointer = parallel_backend.get_checkpointer(
+ dataloader=self.dataloader,
+ model_parts=[self.transformer],
+ optimizers=self.optimizer,
+ schedulers=self.lr_scheduler,
+ states={"train_state": self.state.train_state},
+ checkpointing_steps=self.args.checkpointing_steps,
+ checkpointing_limit=self.args.checkpointing_limit,
+ output_dir=self.args.output_dir,
+ enable=enable_state_checkpointing,
+ _callback_fn=save_model_hook,
+ )
+
+ resume_from_checkpoint = self.args.resume_from_checkpoint
+ if resume_from_checkpoint == "latest":
+ resume_from_checkpoint = -1
+ if resume_from_checkpoint is not None:
+ self.checkpointer.load(resume_from_checkpoint)
+
+ def _train(self) -> None:
+ logger.info("Starting training")
+
+ parallel_backend = self.state.parallel_backend
+ train_state = self.state.train_state
+ device = parallel_backend.device
+ dtype = self.args.transformer_dtype
+
+ memory_statistics = utils.get_memory_statistics()
+ logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
+
+ global_batch_size = self.args.batch_size * parallel_backend._dp_degree
+ info = {
+ "trainable parameters": self.state.num_trainable_parameters,
+ "train steps": self.args.train_steps,
+ "per-replica batch size": self.args.batch_size,
+ "global batch size": global_batch_size,
+ "gradient accumulation steps": self.args.gradient_accumulation_steps,
+ }
+ logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
+
+ progress_bar = tqdm(
+ range(0, self.args.train_steps),
+ initial=train_state.step,
+ desc="Training steps",
+ disable=not parallel_backend.is_local_main_process,
+ )
+
+ generator = torch.Generator(device=device)
+ if self.args.seed is not None:
+ generator = generator.manual_seed(self.args.seed)
+ self.state.generator = generator
+
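+ # Cache the full sigma/alpha schedules once; per-step values are sampled from these in prepare_sigmas and
+ # prepare_loss_weights below.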
+ scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler)
+ scheduler_sigmas = (
+ scheduler_sigmas.to(device=device, dtype=torch.float32) if scheduler_sigmas is not None else None
+ )
+ scheduler_alphas = utils.get_scheduler_alphas(self.scheduler)
+ scheduler_alphas = (
+ scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None
+ )
+ # timesteps_buffer = []
+
+ self.transformer.train()
+ data_iterator = iter(self.dataloader)
+
+ compute_posterior = False if self.args.enable_precomputation else (not self.args.precomputation_once)
+ preprocessor = data.initialize_preprocessor(
+ rank=parallel_backend.rank,
+ world_size=parallel_backend.world_size,
+ num_items=self.args.precomputation_items if self.args.enable_precomputation else 1,
+ processor_fn={
+ "condition": self.model_specification.prepare_conditions,
+ "latent": functools.partial(
+ self.model_specification.prepare_latents, compute_posterior=compute_posterior
+ ),
+ },
+ save_dir=self.args.precomputation_dir,
+ enable_precomputation=self.args.enable_precomputation,
+ enable_reuse=self.args.precomputation_reuse,
+ )
+ condition_iterator: Optional[Iterable[Dict[str, Any]]] = None
+ latent_iterator: Optional[Iterable[Dict[str, Any]]] = None
+ sampler = data.ResolutionSampler(
+ batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys
+ )
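+ # `requires_gradient_step` tracks whether gradients have been accumulated since the last optimizer step, so a
+ # final (possibly partial) accumulation window can be flushed when the data iterator is exhausted.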
+ requires_gradient_step = True
+ accumulated_loss = 0.0
+
+ while (
+ train_state.step < self.args.train_steps and train_state.observed_data_samples < self.args.max_data_samples
+ ):
+ # 1. Load & preprocess data if required
+ if preprocessor.requires_data:
+ condition_iterator, latent_iterator = self._prepare_data(preprocessor, data_iterator)
+
+ # 2. Prepare batch
+ with self.tracker.timed("timing/batch_preparation"):
+ try:
+ condition_item = next(condition_iterator)
+ latent_item = next(latent_iterator)
+ sampler.consume(condition_item, latent_item)
+ except StopIteration:
+ if requires_gradient_step:
+ self.optimizer.step()
+ self.lr_scheduler.step()
+ requires_gradient_step = False
+ logger.info("Data exhausted. Exiting training loop.")
+ break
+
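+ # Consume items until the resolution sampler has gathered a full batch of matching resolution; until then,
+ # skip the training step and keep feeding it samples.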
+ if sampler.is_ready:
+ condition_batch, latent_batch = sampler.get_batch()
+ condition_model_conditions = self.model_specification.collate_conditions(condition_batch)
+ latent_model_conditions = self.model_specification.collate_latents(latent_batch)
+ else:
+ continue
+
+ train_state.step += 1
+ train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree
+
+ logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
+
+ latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype)
+ condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype)
+ latent_model_conditions = utils.make_contiguous(latent_model_conditions)
+ condition_model_conditions = utils.make_contiguous(condition_model_conditions)
+
+ # 3. Forward pass
+ sigmas = utils.prepare_sigmas(
+ scheduler=self.scheduler,
+ sigmas=scheduler_sigmas,
+ batch_size=self.args.batch_size,
+ num_train_timesteps=self.scheduler.config.num_train_timesteps,
+ flow_weighting_scheme=self.args.flow_weighting_scheme,
+ flow_logit_mean=self.args.flow_logit_mean,
+ flow_logit_std=self.args.flow_logit_std,
+ flow_mode_scale=self.args.flow_mode_scale,
+ device=device,
+ generator=self.state.generator,
+ )
+ sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim)
+
+ # NOTE: for the planned refactor, make sure that both the forward and backward passes run under this context.
+ # If only the forward pass runs under the context, the backward pass will most likely fail when activation checkpointing is used.
+ with self.attention_provider_ctx(training=True):
+ with self.tracker.timed("timing/forward"):
+ pred, target, sigmas = self.model_specification.forward(
+ transformer=self.transformer,
+ scheduler=self.scheduler,
+ condition_model_conditions=condition_model_conditions,
+ latent_model_conditions=latent_model_conditions,
+ sigmas=sigmas,
+ compute_posterior=compute_posterior,
+ )
+
+ timesteps = (sigmas * 1000.0).long()
+ weights = utils.prepare_loss_weights(
+ scheduler=self.scheduler,
+ alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
+ sigmas=sigmas,
+ flow_weighting_scheme=self.args.flow_weighting_scheme,
+ )
+ weights = utils.expand_tensor_dims(weights, pred.ndim)
+
+ # 4. Compute loss & backward pass
+ with self.tracker.timed("timing/backward"):
+ loss = weights.float() * (pred.float() - target.float()).pow(2)
+ # Average loss across all but batch dimension (for per-batch debugging in case needed)
+ loss = loss.mean(list(range(1, loss.ndim)))
+ # Average loss across batch dimension
+ loss = loss.mean()
+ if self.args.gradient_accumulation_steps > 1:
+ loss = loss / self.args.gradient_accumulation_steps
+ loss.backward()
+
+ accumulated_loss += loss.detach().item()
+ requires_gradient_step = True
+
+ # 5. Clip gradients
+ model_parts = [self.transformer]
+ grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases(
+ [p for m in model_parts for p in m.parameters()],
+ self.args.max_grad_norm,
+ foreach=True,
+ pp_mesh=parallel_backend.get_mesh()["pp"] if parallel_backend.pipeline_parallel_enabled else None,
+ )
+
+ # 6. Step optimizer & log metrics
+ logs = {}
+
+ if train_state.step % self.args.gradient_accumulation_steps == 0:
+ # TODO(aryan): revisit no_sync() for FSDP
+ with self.tracker.timed("timing/optimizer_step"):
+ self.optimizer.step()
+ self.lr_scheduler.step()
+ self.optimizer.zero_grad()
+
+ if grad_norm is not None:
+ grad_norm = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item()
+ if (
+ parallel_backend.data_replication_enabled
+ or parallel_backend.data_sharding_enabled
+ or parallel_backend.context_parallel_enabled
+ ):
+ dp_cp_mesh = parallel_backend.get_mesh()["dp_cp"]
+ if grad_norm is not None:
+ grad_norm = parallel.dist_mean(torch.tensor([grad_norm], device=device), dp_cp_mesh)
+ global_avg_loss, global_max_loss = (
+ parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
+ parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
+ )
+ else:
+ global_avg_loss = global_max_loss = accumulated_loss
+
+ logs["train/global_avg_loss"] = global_avg_loss
+ logs["train/global_max_loss"] = global_max_loss
+ if grad_norm is not None:
+ logs["train/grad_norm"] = grad_norm
+ train_state.global_avg_losses.append(global_avg_loss)
+ train_state.global_max_losses.append(global_max_loss)
+ accumulated_loss = 0.0
+ requires_gradient_step = False
+
+ progress_bar.update(1)
+ progress_bar.set_postfix(logs)
+
+ # timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()])
+
+ if train_state.step % self.args.logging_steps == 0:
+ # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts
+ # TODO(aryan): causes NCCL hang for some reason. look into later
+ # logs.update(self.lr_scheduler.get_last_lr())
+
+ # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"])
+ # logs["timesteps"] = wandb.plot.scatter(
+ # timesteps_table, "step", "timesteps", title="Timesteps distribution"
+ # )
+ # timesteps_buffer = []
+
+ logs["train/observed_data_samples"] = train_state.observed_data_samples
+
+ parallel_backend.log(logs, step=train_state.step)
+ train_state.log_steps.append(train_state.step)
+
+ # 7. Save checkpoint if required
+ with self.tracker.timed("timing/checkpoint"):
+ self.checkpointer.save(
+ step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process
+ )
+
+ # 8. Perform validation if required
+ if train_state.step % self.args.validation_steps == 0:
+ self._validate(step=train_state.step, final_validation=False)
+
+ # 9. Final checkpoint, validation & cleanup
+ self.checkpointer.save(
+ train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process
+ )
+ parallel_backend.wait_for_everyone()
+ self._validate(step=train_state.step, final_validation=True)
+
+ self._delete_components()
+ memory_statistics = utils.get_memory_statistics()
+ logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
+
+ # 10. Upload artifacts to hub
+ if parallel_backend.is_main_process and self.args.push_to_hub:
+ upload_folder(
+ repo_id=self.state.repo_id,
+ folder_path=self.args.output_dir,
+ ignore_patterns=[f"{self.checkpointer._prefix}_*"],
+ )
+
+ parallel_backend.destroy()
+
+ def _validate(self, step: int, final_validation: bool = False) -> None:
+ if self.args.validation_dataset_file is None:
+ return
+
+ logger.info("Starting validation")
+
+ # 1. Load validation dataset
+ parallel_backend = self.state.parallel_backend
+ dataset = data.ValidationDataset(self.args.validation_dataset_file)
+
+ # Hack to make accelerate work. TODO(aryan): refactor
+ if parallel_backend._dp_degree > 1:
+ dp_mesh = parallel_backend.get_mesh()["dp"]
+ dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
+ dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
+ else:
+ dp_mesh = None
+ dp_local_rank, dp_world_size = parallel_backend.local_rank, 1
+
+ validation_dataloader = data.DPDataLoader(
+ dp_local_rank,
+ dataset,
+ batch_size=1,
+ num_workers=self.args.dataloader_num_workers,
+ collate_fn=lambda items: items,
+ )
+ data_iterator = iter(validation_dataloader)
+ main_process_prompts_to_filenames = {} # Used to save model card
+ all_processes_artifacts = [] # Used to gather artifacts from all processes
+
+ memory_statistics = utils.get_memory_statistics()
+ logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
+
+ seed = self.args.seed if self.args.seed is not None else 0
+ generator = torch.Generator(device=parallel_backend.device).manual_seed(seed)
+ pipeline = self._init_pipeline(final_validation=final_validation)
+
+ # 2. Run validation
+ # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we
+ # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset
+ # size is not divisible by dp_shards.
+ self.transformer.eval()
+ while True:
+ validation_data = next(data_iterator, None)
+ if validation_data is None:
+ break
+
+ validation_data = validation_data[0]
+ with self.attention_provider_ctx(training=False):
+ validation_artifacts = self.model_specification.validation(
+ pipeline=pipeline, generator=generator, **validation_data
+ )
+
+ if dp_local_rank != 0:
+ continue
+
+ PROMPT = validation_data["prompt"]
+ IMAGE = validation_data.get("image", None)
+ VIDEO = validation_data.get("video", None)
+ EXPORT_FPS = validation_data.get("export_fps", 30)
+
+ # 2.1. If there are any initial images or videos, they will be logged to keep track of them as
+ # conditioning for generation.
+ prompt_filename = utils.string_to_filename(PROMPT)[:25]
+ artifacts = {
+ "input_image": data.ImageArtifact(value=IMAGE),
+ "input_video": data.VideoArtifact(value=VIDEO),
+ }
+
+ # 2.2. Track the artifacts generated from validation
+ for i, validation_artifact in enumerate(validation_artifacts):
+ if validation_artifact.value is None:
+ continue
+ artifacts.update({f"artifact_{i}": validation_artifact})
+
+ # 2.3. Save the artifacts to the output directory and create appropriate logging objects
+ # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
+ for index, (key, artifact) in enumerate(list(artifacts.items())):
+ assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))
+ if artifact.value is None:
+ continue
+
+ time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension
+ filename = "validation-" if not final_validation else "final-"
+ filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}"
+ output_filename = os.path.join(self.args.output_dir, filename)
+
+ if parallel_backend.is_main_process and ext in ["mp4", "jpg", "jpeg", "png"]:
+ main_process_prompts_to_filenames[PROMPT] = filename
+
+ if isinstance(artifact, data.ImageArtifact):
+ artifact.value.save(output_filename)
+ all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT))
+ elif isinstance(artifact, data.VideoArtifact):
+ export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
+ all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT))
+
+ # 3. Cleanup & log artifacts
+ parallel_backend.wait_for_everyone()
+
+ memory_statistics = utils.get_memory_statistics()
+ logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+
+ # Remove all hooks that might have been added during pipeline initialization to the models
+ pipeline.remove_all_hooks()
+ del pipeline
+ module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "image_processor", "vae"]
+ if self.args.enable_precomputation:
+ self._delete_components(module_names)
+ torch.cuda.reset_peak_memory_stats(parallel_backend.device)
+
+ # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts.
+ all_artifacts = [None] * dp_world_size
+ if dp_world_size > 1:
+ torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts)
+ else:
+ all_artifacts = [all_processes_artifacts]
+ all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts]
+
+ if parallel_backend.is_main_process:
+ tracker_key = "final" if final_validation else "validation"
+ artifact_log_dict = {}
+
+ image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
+ if len(image_artifacts) > 0:
+ artifact_log_dict["images"] = image_artifacts
+ video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+ if len(video_artifacts) > 0:
+ artifact_log_dict["videos"] = video_artifacts
+ parallel_backend.log({tracker_key: artifact_log_dict}, step=step)
+
+ if self.args.push_to_hub and final_validation:
+ video_filenames = list(main_process_prompts_to_filenames.values())
+ prompts = list(main_process_prompts_to_filenames.keys())
+ utils.save_model_card(
+ args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts
+ )
+
+ parallel_backend.wait_for_everyone()
+ if not final_validation:
+ self._move_components_to_device()
+ self.transformer.train()
+
+ def _evaluate(self) -> None:
+ raise NotImplementedError("Evaluation has not been implemented yet.")
+
+ def _init_directories_and_repositories(self) -> None:
+ if self.state.parallel_backend.is_main_process:
+ self.args.output_dir = Path(self.args.output_dir)
+ self.args.output_dir.mkdir(parents=True, exist_ok=True)
+ self.state.output_dir = Path(self.args.output_dir)
+
+ if self.args.push_to_hub:
+ repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
+ self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
+
+ def _move_components_to_device(
+ self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None
+ ) -> None:
+ if device is None:
+ device = self.state.parallel_backend.device
+ if components is None:
+ components = [
+ self.text_encoder,
+ self.text_encoder_2,
+ self.text_encoder_3,
+ self.image_encoder,
+ self.transformer,
+ self.vae,
+ ]
+ components = utils.get_non_null_items(components)
+ components = list(filter(lambda x: hasattr(x, "to"), components))
+ for component in components:
+ component.to(device)
+
+ def _set_components(self, components: Dict[str, Any]) -> None:
+ for component_name in self._all_component_names:
+ existing_component = getattr(self, component_name, None)
+ new_component = components.get(component_name, existing_component)
+ setattr(self, component_name, new_component)
+
+ def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
+ if component_names is None:
+ component_names = self._all_component_names
+ for component_name in component_names:
+ setattr(self, component_name, None)
+ utils.free_memory()
+ utils.synchronize_device()
+
+ def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
+ module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "transformer", "vae"]
+
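+ # For intermediate validation, reuse the in-memory training components; for the final validation, drop the
+ # training copies and reload the pipeline from the saved weights.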
+ if not final_validation:
+ module_names.remove("transformer")
+ pipeline = self.model_specification.load_pipeline(
+ tokenizer=self.tokenizer,
+ tokenizer_2=self.tokenizer_2,
+ tokenizer_3=self.tokenizer_3,
+ text_encoder=self.text_encoder,
+ text_encoder_2=self.text_encoder_2,
+ text_encoder_3=self.text_encoder_3,
+ image_encoder=self.image_encoder,
+ image_processor=self.image_processor,
+ # TODO(aryan): handle unwrapping for compiled modules
+ # transformer=utils.unwrap_model(accelerator, self.transformer),
+ transformer=self.transformer,
+ vae=self.vae,
+ enable_slicing=self.args.enable_slicing,
+ enable_tiling=self.args.enable_tiling,
+ enable_model_cpu_offload=self.args.enable_model_cpu_offload,
+ training=True,
+ )
+ else:
+ self._delete_components()
+
+ # Load the transformer weights from the final checkpoint if performing full-finetune
+ transformer = None
+ if self.args.training_type == TrainingType.FULL_FINETUNE:
+ transformer = self.model_specification.load_diffusion_models()["transformer"]
+
+ pipeline = self.model_specification.load_pipeline(
+ transformer=transformer,
+ enable_slicing=self.args.enable_slicing,
+ enable_tiling=self.args.enable_tiling,
+ enable_model_cpu_offload=self.args.enable_model_cpu_offload,
+ training=False,
+ )
+
+ # Load the LoRA weights if performing LoRA finetuning
+ if self.args.training_type == TrainingType.LORA:
+ pipeline.load_lora_weights(
+ os.path.join(self.args.output_dir, "lora_weights", f"{self.state.train_state.step:06d}")
+ )
+
+ components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names}
+ self._set_components(components)
+ if not self.args.enable_model_cpu_offload:
+ self._move_components_to_device(list(components.values()))
+ self._maybe_torch_compile()
+ return pipeline
+
+ def _prepare_data(
+ self,
+ preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor],
+ data_iterator,
+ ):
+ if not self.args.enable_precomputation:
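+ # When precomputation is disabled, the condition and latent models are loaded once, kept on the device, and
+ # embeddings are computed on the fly for every batch.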
+ if not self._are_condition_models_loaded:
+ logger.info(
+ "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs."
+ )
+ condition_components = self.model_specification.load_condition_models()
+ latent_components = self.model_specification.load_latent_models()
+ all_components = {**condition_components, **latent_components}
+ self._set_components(all_components)
+ self._move_components_to_device(list(all_components.values()))
+ utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
+ self._maybe_torch_compile()
+ else:
+ condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))}
+ latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))}
+
+ condition_iterator = preprocessor.consume(
+ "condition",
+ components=condition_components,
+ data_iterator=data_iterator,
+ generator=self.state.generator,
+ cache_samples=True,
+ )
+ latent_iterator = preprocessor.consume(
+ "latent",
+ components=latent_components,
+ data_iterator=data_iterator,
+ generator=self.state.generator,
+ use_cached_samples=True,
+ drop_samples=True,
+ )
+
+ self._are_condition_models_loaded = True
+ else:
+ logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
+
+ parallel_backend = self.state.parallel_backend
+ if parallel_backend.world_size == 1:
+ self._move_components_to_device([self.transformer], "cpu")
+ utils.free_memory()
+ utils.synchronize_device()
+ torch.cuda.reset_peak_memory_stats(parallel_backend.device)
+
+ consume_fn = preprocessor.consume_once if self.args.precomputation_once else preprocessor.consume
+
+ # Prepare condition iterators
+ condition_components, component_names, component_modules = {}, [], []
+ if not self.args.precomputation_reuse:
+ condition_components = self.model_specification.load_condition_models()
+ component_names = list(condition_components.keys())
+ component_modules = list(condition_components.values())
+ self._set_components(condition_components)
+ self._move_components_to_device(component_modules)
+ self._maybe_torch_compile()
+ condition_iterator = consume_fn(
+ "condition",
+ components=condition_components,
+ data_iterator=data_iterator,
+ generator=self.state.generator,
+ cache_samples=True,
+ )
+ self._delete_components(component_names)
+ del condition_components, component_names, component_modules
+
+ # Prepare latent iterators
+ latent_components, component_names, component_modules = {}, [], []
+ if not self.args.precomputation_reuse:
+ latent_components = self.model_specification.load_latent_models()
+ utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
+ component_names = list(latent_components.keys())
+ component_modules = list(latent_components.values())
+ self._set_components(latent_components)
+ self._move_components_to_device(component_modules)
+ self._maybe_torch_compile()
+ latent_iterator = consume_fn(
+ "latent",
+ components=latent_components,
+ data_iterator=data_iterator,
+ generator=self.state.generator,
+ use_cached_samples=True,
+ drop_samples=True,
+ )
+ self._delete_components(component_names)
+ del latent_components, component_names, component_modules
+
+ if parallel_backend.world_size == 1:
+ self._move_components_to_device([self.transformer])
+
+ return condition_iterator, latent_iterator
+
+ def _maybe_torch_compile(self):
+ for model_name, compile_scope in zip(self.args.compile_modules, self.args.compile_scopes):
+ model = getattr(self, model_name, None)
+ if model is not None:
+ logger.info(f"Applying torch.compile to '{model_name}' with scope '{compile_scope}'.")
+ compiled_model = utils.apply_compile(model, compile_scope)
+ setattr(self, model_name, compiled_model)
+
+ def _get_training_info(self) -> Dict[str, Any]:
+ info = self.args.to_dict()
+
+ # Removing flow matching arguments when not using flow-matching objective
+ diffusion_args = info.get("diffusion_arguments", {})
+ scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else ""
+ if scheduler_name != "FlowMatchEulerDiscreteScheduler":
+ filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
+ else:
+ filtered_diffusion_args = diffusion_args
+
+ info.update({"diffusion_arguments": filtered_diffusion_args})
+ return info
+
+ # fmt: off
+ _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder", "image_processor", "transformer", "unet", "vae", "scheduler"]
+ _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"]
+ _latent_component_names = ["image_encoder", "image_processor", "vae"]
+ _diffusion_component_names = ["transformer", "unet", "scheduler"]
+ # fmt: on
diff --git a/docs/finetrainers-src-codebase/finetrainers/typing.py b/docs/finetrainers-src-codebase/finetrainers/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7b3b339f252d8f47ef0ff67aa6c6733a2ccd7cf
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/typing.py
@@ -0,0 +1,11 @@
+from typing import Union
+
+from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
+from transformers import CLIPTokenizer, LlamaTokenizer, LlamaTokenizerFast, T5Tokenizer, T5TokenizerFast
+
+from .data import ImageArtifact, VideoArtifact
+
+
+ArtifactType = Union[ImageArtifact, VideoArtifact]
+SchedulerType = Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]
+TokenizerType = Union[CLIPTokenizer, T5Tokenizer, T5TokenizerFast, LlamaTokenizer, LlamaTokenizerFast]
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/__init__.py b/docs/finetrainers-src-codebase/finetrainers/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56fd3b2819959cc968b338bc6e7bd7051b05c77b
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/__init__.py
@@ -0,0 +1,51 @@
+import inspect
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
+from .activation_checkpoint import apply_activation_checkpointing
+from .args_config import ArgsConfigMixin
+from .data import determine_batch_size, should_perform_precomputation
+from .diffusion import (
+ _enable_vae_memory_optimizations,
+ default_flow_shift,
+ get_scheduler_alphas,
+ get_scheduler_sigmas,
+ prepare_loss_weights,
+ prepare_sigmas,
+ prepare_target,
+ resolution_dependent_timestep_flow_shift,
+)
+from .file import delete_files, find_files, string_to_filename
+from .hub import save_model_card
+from .memory import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous
+from .model import resolve_component_cls
+from .serialization import safetensors_torch_save_function
+from .timing import Timer, TimerDevice
+from .torch import (
+ align_device_and_dtype,
+ apply_compile,
+ clip_grad_norm_,
+ enable_determinism,
+ expand_tensor_dims,
+ get_device_info,
+ get_submodule_by_name,
+ get_unwrapped_model_state_dict,
+ is_compiled_module,
+ set_requires_grad,
+ synchronize_device,
+ unwrap_module,
+)
+
+
+def get_parameter_names(obj: Any, method_name: Optional[str] = None) -> Set[str]:
+ if method_name is not None:
+ obj = getattr(obj, method_name)
+ return {name for name, _ in inspect.signature(obj).parameters.items()}
+
+
+def get_non_null_items(
+ x: Union[List[Any], Tuple[Any], Dict[str, Any]],
+) -> Union[List[Any], Tuple[Any], Dict[str, Any]]:
+ if isinstance(x, dict):
+ return {k: v for k, v in x.items() if v is not None}
+ if isinstance(x, (list, tuple)):
+ return type(x)(v for v in x if v is not None)
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/_common.py b/docs/finetrainers-src-codebase/finetrainers/utils/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..a54abe26707181d6f4795e99f24fae34b911b2b9
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/_common.py
@@ -0,0 +1,7 @@
+DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [
+ "transformer_blocks",
+ "single_transformer_blocks",
+ "temporal_transformer_blocks",
+ "blocks",
+ "layers",
+]
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/activation_checkpoint.py b/docs/finetrainers-src-codebase/finetrainers/utils/activation_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc4193a6cc027a771fe1fc2c3cb34595fbc336b2
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/activation_checkpoint.py
@@ -0,0 +1,71 @@
+import collections
+from enum import Enum
+
+import torch
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
+
+from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES
+
+
+class CheckpointType(str, Enum):
+ FULL = "full"
+ OPS = "ops"
+ BLOCK_SKIP = "block_skip"
+
+
+_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = {
+ torch.ops.aten.mm.default,
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
+ torch.ops._c10d_functional.reduce_scatter_tensor.default,
+}
+
+
+def apply_activation_checkpointing(
+ module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1
+) -> torch.nn.Module:
+ if checkpointing_type == CheckpointType.FULL:
+ module = _apply_activation_checkpointing_blocks(module)
+ elif checkpointing_type == CheckpointType.OPS:
+ module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS)
+ elif checkpointing_type == CheckpointType.BLOCK_SKIP:
+ module = _apply_activation_checkpointing_blocks(module, n_layer)
+ else:
+ raise ValueError(
+ f"Checkpointing type '{checkpointing_type}' not supported. Supported types are {CheckpointType.__members__.keys()}"
+ )
+ return module
+
+
+def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: int = None) -> torch.nn.Module:
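+ # Wraps transformer blocks with `checkpoint_wrapper`. When `n_layer` is None every block is wrapped; otherwise
+ # only every `n_layer`-th block is wrapped (used by the "block_skip" checkpointing type).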
+ for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
+ blocks: torch.nn.Module = getattr(module, transformer_block_name, None)
+ if blocks is None:
+ continue
+ for index, (layer_id, block) in enumerate(blocks.named_children()):
+ if n_layer is None or index % n_layer == 0:
+ block = checkpoint_wrapper(block, preserve_rng_state=False)
+ blocks.register_module(layer_id, block)
+ return module
+
+
+def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module:
+ from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
+
+ def _get_custom_policy(meta):
+ def _custom_policy(ctx, func, *args, **kwargs):
+ mode = "recompute" if ctx.is_recompute else "forward"
+ mm_count_key = f"{mode}_mm_count"
+ if func == torch.ops.aten.mm.default:
+ meta[mm_count_key] += 1
+ # Saves output of all compute ops, except every second mm
+ to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0)
+ return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE
+
+ return _custom_policy
+
+ def selective_checkpointing_context_fn():
+ meta = collections.defaultdict(int)
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
+
+ return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False)
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/args_config.py b/docs/finetrainers-src-codebase/finetrainers/utils/args_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..64a1ed0754116615cc885202da218da920ce52fb
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/args_config.py
@@ -0,0 +1,20 @@
+import argparse
+from typing import TYPE_CHECKING, Any, Dict
+
+
+if TYPE_CHECKING:
+ from finetrainers.args import BaseArgs
+
+
+class ArgsConfigMixin:
+ def add_args(self, parser: argparse.ArgumentParser):
+ raise NotImplementedError("ArgsConfigMixin::add_args should be implemented by subclasses.")
+
+ def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+ raise NotImplementedError("ArgsConfigMixin::map_args should be implemented by subclasses.")
+
+ def validate_args(self, args: "BaseArgs"):
+ raise NotImplementedError("ArgsConfigMixin::validate_args should be implemented by subclasses.")
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {}
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/data.py b/docs/finetrainers-src-codebase/finetrainers/utils/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecebdcf90b5d1ff719f2d5b18d5946bf47bde97a
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/data.py
@@ -0,0 +1,51 @@
+from pathlib import Path
+from typing import Any, Union
+
+import torch
+
+from finetrainers.constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME
+from finetrainers.logging import get_logger
+
+
+logger = get_logger()
+
+
+def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool:
+ if isinstance(precomputation_dir, str):
+ precomputation_dir = Path(precomputation_dir)
+ conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
+ latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
+ if conditions_dir.exists() and latents_dir.exists():
+ num_files_conditions = len(list(conditions_dir.glob("*.pt")))
+ num_files_latents = len(list(latents_dir.glob("*.pt")))
+ if num_files_conditions != num_files_latents:
+ logger.warning(
+ f"Number of precomputed conditions ({num_files_conditions}) does not match number of precomputed latents ({num_files_latents}). "
+ "Cleaning up precomputed directories and re-running precomputation."
+ )
+ # clean up precomputed directories
+ for file in conditions_dir.glob("*.pt"):
+ file.unlink()
+ for file in latents_dir.glob("*.pt"):
+ file.unlink()
+ return True
+ if num_files_conditions > 0:
+ logger.info(f"Found {num_files_conditions} precomputed conditions and latents.")
+ return False
+ logger.info("Precomputed data not found. Running precomputation.")
+ return True
+
+
+def determine_batch_size(x: Any) -> int:
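+ # Infers the batch size from the first list or tensor found, recursing into dict values; falls back to 1 for
+ # dicts whose values carry no inferable batch dimension.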
+ if isinstance(x, list):
+ return len(x)
+ if isinstance(x, torch.Tensor):
+ return x.size(0)
+ if isinstance(x, dict):
+ for key in x:
+ try:
+ return determine_batch_size(x[key])
+ except ValueError:
+ pass
+ return 1
+ raise ValueError("Could not determine batch size from input.")
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/diffusion.py b/docs/finetrainers-src-codebase/finetrainers/utils/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ed3746c160b7aa1ea96fb382ccbece85db6ae42
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/diffusion.py
@@ -0,0 +1,152 @@
+import math
+from typing import Optional, Union
+
+import torch
+from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
+from diffusers.training_utils import compute_loss_weighting_for_sd3
+
+
+# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47
+def resolution_dependent_timestep_flow_shift(
+ latents: torch.Tensor,
+ sigmas: torch.Tensor,
+ base_image_seq_len: int = 256,
+ max_image_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+) -> torch.Tensor:
+ image_or_video_sequence_length = 0
+ if latents.ndim == 4:
+ image_or_video_sequence_length = latents.shape[2] * latents.shape[3]
+ elif latents.ndim == 5:
+ image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4]
+ else:
+ raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor")
+
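+ # Linearly interpolate the shift parameter (mu) between base_shift and max_shift based on the latent sequence length.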
+ m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
+ b = base_shift - m * base_image_seq_len
+ mu = m * image_or_video_sequence_length + b
+ sigmas = default_flow_shift(sigmas, shift=mu)
+ return sigmas
+
+
+def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
+ sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
+ return sigmas
+
+
+def compute_density_for_timestep_sampling(
+ weighting_scheme: str,
+ batch_size: int,
+ logit_mean: float = None,
+ logit_std: float = None,
+ mode_scale: float = None,
+ device: torch.device = torch.device("cpu"),
+ generator: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+ r"""
+ Compute the density for sampling the timesteps when doing SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "logit_normal":
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
+ u = torch.nn.functional.sigmoid(u)
+ elif weighting_scheme == "mode":
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+ else:
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
+ return u
+
+
+def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
+ return None
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
+ return scheduler.alphas_cumprod.clone()
+ else:
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
+
+
+def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
+ return scheduler.sigmas.clone()
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
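+ # Normalize DDIM timesteps to the [0, 1) range so they can be treated like sigmas downstream.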
+ return scheduler.timesteps.clone().float() / float(scheduler.config.num_train_timesteps)
+ else:
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
+
+
+def prepare_sigmas(
+ scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
+ sigmas: torch.Tensor,
+ batch_size: int,
+ num_train_timesteps: int,
+ flow_weighting_scheme: str = "none",
+ flow_logit_mean: float = 0.0,
+ flow_logit_std: float = 1.0,
+ flow_mode_scale: float = 1.29,
+ device: torch.device = torch.device("cpu"),
+ generator: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
+ weights = compute_density_for_timestep_sampling(
+ weighting_scheme=flow_weighting_scheme,
+ batch_size=batch_size,
+ logit_mean=flow_logit_mean,
+ logit_std=flow_logit_std,
+ mode_scale=flow_mode_scale,
+ device=device,
+ generator=generator,
+ )
+ indices = (weights * num_train_timesteps).long()
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
+ # TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes.
+ weights = torch.rand(size=(batch_size,), device=device, generator=generator)
+ indices = (weights * num_train_timesteps).long()
+ else:
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
+
+ return sigmas[indices]
+
+
+def prepare_loss_weights(
+ scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
+ alphas: Optional[torch.Tensor] = None,
+ sigmas: Optional[torch.Tensor] = None,
+ flow_weighting_scheme: str = "none",
+) -> torch.Tensor:
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
+ return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme)
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
+ # SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas).
+ # TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results.
+ return 1 / (1 - alphas)
+ else:
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
+
+
+def prepare_target(
+ scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
+ noise: torch.Tensor,
+ latents: torch.Tensor,
+) -> torch.Tensor:
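+ # Flow-matching training regresses the velocity field (noise - latents); DDIM-style training here regresses the clean latents.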
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
+ target = noise - latents
+ elif isinstance(scheduler, CogVideoXDDIMScheduler):
+ target = latents
+ else:
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
+
+ return target
+
+
+def _enable_vae_memory_optimizations(vae, enable_slicing: bool = False, enable_tiling: bool = False):
+ if hasattr(vae, "enable_slicing") and enable_slicing:
+ vae.enable_slicing()
+ if hasattr(vae, "enable_tiling") and enable_tiling:
+ vae.enable_tiling()
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/file.py b/docs/finetrainers-src-codebase/finetrainers/utils/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8308b32c5be0427e7c1b8fd3d1f978aad002a1
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/file.py
@@ -0,0 +1,51 @@
+import pathlib
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+from finetrainers.logging import get_logger
+
+
+logger = get_logger()
+
+
+def find_files(root: str, pattern: str, depth: int = 0) -> List[str]:
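+ # depth=0 matches files directly under `root` only; depth>=1 searches recursively, keeping files whose parent
+ # directory is at most `depth` levels below `root`.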
+ root_path = pathlib.Path(root)
+ result_files = []
+
+ def within_depth(path: pathlib.Path) -> bool:
+ return len(path.relative_to(root_path).parts) <= depth
+
+ if depth == 0:
+ result_files.extend([str(file) for file in root_path.glob(pattern)])
+ else:
+ for file in root_path.rglob(pattern):
+ if not file.is_file() or not within_depth(file.parent):
+ continue
+ result_files.append(str(file))
+
+ return result_files
+
+
+def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
+ if not isinstance(dirs, list):
+ dirs = [dirs]
+ dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
+ logger.debug(f"Deleting files: {dirs}")
+ for dir in dirs:
+ if not dir.exists():
+ continue
+ shutil.rmtree(dir, ignore_errors=True)
+
+
+def string_to_filename(s: str) -> str:
+ return (
+ s.replace(" ", "-")
+ .replace("/", "-")
+ .replace(":", "-")
+ .replace(".", "-")
+ .replace(",", "-")
+ .replace(";", "-")
+ .replace("!", "-")
+ .replace("?", "-")
+ )
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/hub.py b/docs/finetrainers-src-codebase/finetrainers/utils/hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1a16eb42cbb1f2848376440817a3e1680ce61c
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/hub.py
@@ -0,0 +1,77 @@
+import os
+from typing import List, Union
+
+import numpy as np
+import wandb
+from diffusers.utils import export_to_video
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from PIL import Image
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ videos: Union[List[str], Union[List[Image.Image], List[np.ndarray]]],
+ validation_prompts: List[str],
+ fps: int = 30,
+) -> None:
+ widget_dict = []
+ output_dir = str(args.output_dir)
+ if videos is not None and len(videos) > 0:
+ for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)):
+ if not isinstance(video, str):
+ export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps)
+ widget_dict.append(
+ {
+ "text": validation_prompt if validation_prompt else " ",
+ "output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"},
+ }
+ )
+
+ model_description = f"""
+# LoRA Finetune
+
+
+
+## Model description
+
+This is a LoRA finetune of model: `{args.pretrained_model_name_or_path}`.
+
+The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers).
+
+## Download model
+
+[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
+
+## Usage
+
+Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
+
+```py
+TODO
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
+"""
+ if wandb.run is not None and wandb.run.url:
+ model_description += f"""
+Find out the wandb run URL and training configurations [here]({wandb.run.url}).
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ base_model=args.pretrained_model_name_or_path,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-video",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(args.output_dir, "README.md"))
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/import_utils.py b/docs/finetrainers-src-codebase/finetrainers/utils/import_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4ac2470313ca116ec1dd6ec88cfd36fadd28b0b
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/import_utils.py
@@ -0,0 +1,129 @@
+import importlib
+import importlib.util
+import operator as op
+from typing import Union
+
+import importlib_metadata
+from packaging.version import Version, parse
+
+from finetrainers.logging import get_logger
+
+
+logger = get_logger()
+
+STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
+
+
+# This function was copied from: https://github.com/huggingface/diffusers/blob/5873377a660dac60a6bd86ef9b4fdfc385305977/src/diffusers/utils/import_utils.py#L57
+def _is_package_available(pkg_name: str):
+ pkg_exists = importlib.util.find_spec(pkg_name) is not None
+ pkg_version = "N/A"
+
+ if pkg_exists:
+ try:
+ pkg_version = importlib_metadata.version(pkg_name)
+ logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
+ except (ImportError, importlib_metadata.PackageNotFoundError):
+ pkg_exists = False
+
+ return pkg_exists, pkg_version
+
+
+# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
+def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
+ """
+ Compares a library version to some requirement using a given operation.
+
+ Args:
+ library_or_version (`str` or `packaging.version.Version`):
+ A library name or a version to check.
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`.
+ requirement_version (`str`):
+ The version to compare the library version against
+ """
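+ # e.g. compare_versions("torch", ">=", "2.4.0") checks the installed torch version; a parsed `Version` object
+ # may also be passed directly instead of a library name.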
+ if operation not in STR_OPERATION_TO_FUNC.keys():
+ raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
+ operation = STR_OPERATION_TO_FUNC[operation]
+ if isinstance(library_or_version, str):
+ library_or_version = parse(importlib_metadata.version(library_or_version))
+ return operation(library_or_version, parse(requirement_version))
+
+
+_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
+_datasets_available, _datasets_version = _is_package_available("datasets")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_kornia_available, _kornia_version = _is_package_available("kornia")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_torch_available, _torch_version = _is_package_available("torch")
+_xformers_available, _xformers_version = _is_package_available("xformers")
+
+
+def is_bitsandbytes_available():
+ return _bitsandbytes_available
+
+
+def is_datasets_available():
+ return _datasets_available
+
+
+def is_flash_attn_available():
+ return _flash_attn_available
+
+
+def is_kornia_available():
+ return _kornia_available
+
+
+def is_sageattention_available():
+ return _sageattention_available
+
+
+def is_torch_available():
+ return _torch_available
+
+
+def is_xformers_available():
+ return _xformers_available
+
+
+def is_bitsandbytes_version(operation: str, version: str):
+ if not _bitsandbytes_available:
+ return False
+ return compare_versions(parse(_bitsandbytes_version), operation, version)
+
+
+def is_datasets_version(operation: str, version: str):
+ if not _datasets_available:
+ return False
+ return compare_versions(parse(_datasets_version), operation, version)
+
+
+def is_flash_attn_version(operation: str, version: str):
+ if not _flash_attn_available:
+ return False
+ return compare_versions(parse(_flash_attn_version), operation, version)
+
+
+def is_kornia_version(operation: str, version: str):
+ if not _kornia_available:
+ return False
+ return compare_versions(parse(_kornia_version), operation, version)
+
+
+def is_sageattention_version(operation: str, version: str):
+ if not _sageattention_available:
+ return False
+ return compare_versions(parse(_sageattention_version), operation, version)
+
+
+def is_torch_version(operation: str, version: str):
+ if not _torch_available:
+ return False
+ return compare_versions(parse(_torch_version), operation, version)
+
+
+def is_xformers_version(operation: str, version: str):
+ if not _xformers_available:
+ return False
+ return compare_versions(parse(_xformers_version), operation, version)
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/memory.py b/docs/finetrainers-src-codebase/finetrainers/utils/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..99f5239f32b02ce78e8f38d0212ac58378fc2772
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/memory.py
@@ -0,0 +1,59 @@
+import gc
+from typing import Any, Dict, Union
+
+import torch
+
+from finetrainers.logging import get_logger
+
+
+logger = get_logger()
+
+
+def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
+ memory_allocated = None
+ memory_reserved = None
+ max_memory_allocated = None
+ max_memory_reserved = None
+
+ if torch.cuda.is_available():
+ device = torch.cuda.current_device()
+ memory_allocated = torch.cuda.memory_allocated(device)
+ memory_reserved = torch.cuda.memory_reserved(device)
+ max_memory_allocated = torch.cuda.max_memory_allocated(device)
+ max_memory_reserved = torch.cuda.max_memory_reserved(device)
+
+ elif torch.backends.mps.is_available():
+ memory_allocated = torch.mps.current_allocated_memory()
+
+ else:
+ logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
+
+ # Statistics may be None on backends that do not report them (e.g. MPS only reports allocated memory).
+ return {
+ "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision) if memory_allocated is not None else None,
+ "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision) if memory_reserved is not None else None,
+ "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision) if max_memory_allocated is not None else None,
+ "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision) if max_memory_reserved is not None else None,
+ }
+
+
+def bytes_to_gigabytes(x: int) -> float:
+ if x is not None:
+ return x / 1024**3
+
+
+def free_memory() -> None:
+ if torch.cuda.is_available():
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+ # TODO(aryan): handle non-cuda devices
+
+
+def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if isinstance(x, torch.Tensor):
+ return x.contiguous()
+ elif isinstance(x, dict):
+ return {k: make_contiguous(v) for k, v in x.items()}
+ else:
+ return x
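+
+
+# Illustrative usage (not part of the original module): log memory statistics around a heavy stage
+# and release cached allocator memory afterwards.
+#     stats = get_memory_statistics()
+#     logger.info(f"Memory statistics (GB): {stats}")
+#     free_memory()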
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/model.py b/docs/finetrainers-src-codebase/finetrainers/utils/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4427f97d25ed44b2d9832cf456b082f65d66c2a8
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/model.py
@@ -0,0 +1,32 @@
+import importlib
+import json
+import os
+from typing import Optional
+
+from huggingface_hub import hf_hub_download
+
+
+def resolve_component_cls(
+ pretrained_model_name_or_path: str,
+ component_name: str,
+ filename: str = "model_index.json",
+ revision: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+):
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.exists(str(pretrained_model_name_or_path)) and os.path.isdir(pretrained_model_name_or_path):
+ index_path = os.path.join(pretrained_model_name_or_path, filename)
+ else:
+ index_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path, filename=filename, revision=revision, cache_dir=cache_dir
+ )
+
+ with open(index_path, "r") as f:
+ model_index_dict = json.load(f)
+
+ if component_name not in model_index_dict:
+ raise ValueError(f"No {component_name} found in the model index dict.")
+
+ cls_config = model_index_dict[component_name]
+ library = importlib.import_module(cls_config[0])
+ return getattr(library, cls_config[1])
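+
+
+# Illustrative usage (not part of the original module; the repo id below is hypothetical):
+#     transformer_cls = resolve_component_cls("some-org/some-pipeline", "transformer")
+# `model_index.json` maps each pipeline component to a (library, class name) pair, which is what
+# the two-element lookup above relies on.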
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/serialization.py b/docs/finetrainers-src-codebase/finetrainers/utils/serialization.py
new file mode 100644
index 0000000000000000000000000000000000000000..d15b53ae28253cce99d13924885f5f9af7f1ff20
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/serialization.py
@@ -0,0 +1,10 @@
+from typing import Any, Dict, Optional
+
+import safetensors.torch
+
+
+def safetensors_torch_save_function(weights: Dict[str, Any], filename: str, metadata: Optional[Dict[str, str]] = None):
+ if metadata is None:
+ metadata = {}
+ metadata["format"] = "pt"
+ safetensors.torch.save_file(weights, filename, metadata)
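+
+
+# Illustrative usage (not part of the original module): pass this as the `save_function` of
+# checkpointing utilities that accept one, so saved files always carry the `"format": "pt"` metadata.
+#     safetensors_torch_save_function(model.state_dict(), "model.safetensors")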
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/timing.py b/docs/finetrainers-src-codebase/finetrainers/utils/timing.py
new file mode 100644
index 0000000000000000000000000000000000000000..99faf75770ae24fc7ad4dd863add4ea698b7968b
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/timing.py
@@ -0,0 +1,108 @@
+import time
+from dataclasses import dataclass
+from enum import Enum
+
+import torch
+
+from finetrainers.constants import FINETRAINERS_ENABLE_TIMING
+from finetrainers.logging import get_logger
+
+
+logger = get_logger()
+
+
+class TimerDevice(str, Enum):
+ CPU = "cpu"
+ CUDA = "cuda"
+
+
+@dataclass
+class TimerData:
+ name: str
+ device: TimerDevice
+ start_time: float = 0.0
+ end_time: float = 0.0
+
+
+class Timer:
+ def __init__(self, name: str, device: TimerDevice, device_sync: bool = False):
+ self.data = TimerData(name=name, device=device)
+
+ self._device_sync = device_sync
+ self._start_event = None
+ self._end_event = None
+ self._active = False
+ self._enabled = FINETRAINERS_ENABLE_TIMING
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.end()
+ return False
+
+ def start(self):
+ if self._active:
+ logger.warning(f"Timer {self.data.name} is already running. Please stop it before starting again.")
+ return
+ self._active = True
+ if not self._enabled:
+ return
+ if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
+ self._start_cuda()
+ else:
+ self._start_cpu()
+ if not self.data.device == TimerDevice.CPU:
+ logger.warning(
+                f"Timer device {self.data.device} is not supported or is unavailable. Falling back to CPU."
+ )
+
+ def end(self):
+ if not self._active:
+ logger.warning(f"Timer {self.data.name} is not running. Please start it before stopping.")
+ return
+ self._active = False
+ if not self._enabled:
+ return
+ if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
+ self._end_cuda()
+ else:
+ self._end_cpu()
+ if not self.data.device == TimerDevice.CPU:
+ logger.warning(
+                f"Timer device {self.data.device} is not supported or is unavailable. Falling back to CPU."
+ )
+
+ @property
+ def elapsed_time(self) -> float:
+ if self._active:
+ if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
+ premature_end_event = torch.cuda.Event(enable_timing=True)
+ premature_end_event.record()
+ premature_end_event.synchronize()
+ return self._start_event.elapsed_time(premature_end_event) / 1000.0
+ else:
+ return time.time() - self.data.start_time
+ else:
+ if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
+ return self._start_event.elapsed_time(self._end_event) / 1000.0
+ else:
+ return self.data.end_time - self.data.start_time
+
+ def _start_cpu(self):
+ self.data.start_time = time.time()
+
+ def _start_cuda(self):
+ torch.cuda.synchronize()
+ self._start_event = torch.cuda.Event(enable_timing=True)
+ self._end_event = torch.cuda.Event(enable_timing=True)
+ self._start_event.record()
+
+ def _end_cpu(self):
+ self.data.end_time = time.time()
+
+ def _end_cuda(self):
+ if self._device_sync:
+ torch.cuda.synchronize()
+ self._end_event.record()
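+
+
+# Illustrative usage (not part of the original module): time a step with the context manager.
+# Timings are recorded only when the FINETRAINERS_ENABLE_TIMING flag is enabled.
+#     with Timer("train_step", TimerDevice.CUDA, device_sync=True) as timer:
+#         ...  # run the step
+#     logger.info(f"train_step took {timer.elapsed_time:.3f}s")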
diff --git a/docs/finetrainers-src-codebase/finetrainers/utils/torch.py b/docs/finetrainers-src-codebase/finetrainers/utils/torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bf064db8c923bd0e345651d3463bcb9caa3baa6
--- /dev/null
+++ b/docs/finetrainers-src-codebase/finetrainers/utils/torch.py
@@ -0,0 +1,395 @@
+import math
+import os
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.backends
+import torch.distributed as dist
+import torch.distributed.tensor
+
+from finetrainers.logging import get_logger
+
+
+logger = get_logger()
+
+_STRING_TO_DTYPE = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+_DTYPE_TO_STRING = {v: k for k, v in _STRING_TO_DTYPE.items()}
+
+_HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False
+
+
+def align_device_and_dtype(
+ x: Union[torch.Tensor, Dict[str, torch.Tensor]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+):
+ if isinstance(x, torch.Tensor):
+ if device is not None:
+ x = x.to(device)
+ if dtype is not None:
+ x = x.to(dtype)
+ elif isinstance(x, dict):
+ if device is not None:
+ x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
+ if dtype is not None:
+ x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
+ return x
+
+
+def apply_compile(model: torch.nn.Module, compile_scope: str) -> torch.nn.Module:
+ r"""Apply torch.compile to a model or its submodules if not already compiled."""
+ if getattr(model, "_torch_compiled", False):
+ return model # Already compiled
+
+ if compile_scope == "full":
+ model = torch.compile(model)
+ setattr(model, "_torch_compiled", True)
+ elif compile_scope == "regional":
+ if isinstance(model, torch.nn.ModuleList):
+ for name, module in model.named_children():
+ if not getattr(module, "_torch_compiled", False):
+ compiled_module = torch.compile(module)
+ setattr(compiled_module, "_torch_compiled", True)
+ model.register_module(name, compiled_module)
+ else:
+ for name, module in model.named_children():
+ apply_compile(module, compile_scope)
+ else:
+ raise ValueError(f"Unknown compile mode: {compile_scope}. Use 'full' or 'regional'.")
+
+ return model
+
+
+def _clip_grad_norm_while_handling_failing_dtensor_cases(
+ parameters: Union[torch.Tensor, List[torch.Tensor]],
+ max_norm: float,
+ norm_type: float = 2.0,
+ error_if_nonfinite: bool = False,
+ foreach: Optional[bool] = None,
+ pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+) -> Optional[torch.Tensor]:
+ global _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES
+
+ if not _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES:
+ try:
+ return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach, pp_mesh)
+ except NotImplementedError as e:
+ if "DTensor does not support cross-mesh operation" in str(e):
+ # https://github.com/pytorch/pytorch/issues/134212
+ logger.warning(
+ "DTensor does not support cross-mesh operation. If you haven't fully tensor-parallelized your "
+ "model, while combining other parallelisms such as FSDP, it could be the reason for this error. "
+ "Gradient clipping will be skipped and gradient norm will not be logged."
+ )
+ except Exception as e:
+ logger.warning(
+ f"An error occurred while clipping gradients: {e}. Gradient clipping will be skipped and gradient "
+ f"norm will not be logged."
+ )
+ _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = True
+ return None
+
+
+# Copied from https://github.com/pytorch/torchtitan/blob/4a169701555ab9bd6ca3769f9650ae3386b84c6e/torchtitan/utils.py#L362
+@torch.no_grad()
+def clip_grad_norm_(
+ parameters: Union[torch.Tensor, List[torch.Tensor]],
+ max_norm: float,
+ norm_type: float = 2.0,
+ error_if_nonfinite: bool = False,
+ foreach: Optional[bool] = None,
+ pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+) -> torch.Tensor:
+ r"""
+ Clip the gradient norm of parameters.
+
+ Gradient norm clipping requires computing the gradient norm over the entire model.
+ `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions.
+ We need to manually reduce the gradient norm across PP stages.
+ See https://github.com/pytorch/torchtitan/issues/596 for details.
+
+ Args:
+ parameters (`torch.Tensor` or `List[torch.Tensor]`):
+ Tensors that will have gradients normalized.
+ max_norm (`float`):
+ Maximum norm of the gradients after clipping.
+ norm_type (`float`, defaults to `2.0`):
+ Type of p-norm to use. Can be `inf` for infinity norm.
+ error_if_nonfinite (`bool`, defaults to `False`):
+ If `True`, an error is thrown if the total norm of the gradients from `parameters` is `nan`, `inf`, or `-inf`.
+ foreach (`bool`, defaults to `None`):
+ Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU native tensors
+ and silently fall back to the slow implementation for other device types.
+ pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`):
+ Pipeline parallel device mesh. If not `None`, will reduce gradient norm across PP stages.
+
+ Returns:
+ `torch.Tensor`:
+ Total norm of the gradients
+ """
+ grads = [p.grad for p in parameters if p.grad is not None]
+
+ # TODO(aryan): Wait for next Pytorch release to use `torch.nn.utils.get_total_norm`
+ # total_norm = torch.nn.utils.get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
+ total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
+
+ # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
+ # We can simply reduce the DTensor to get the total norm in this tensor's process group
+ # and then convert it to a local tensor.
+ # It has two purposes:
+ # 1. to make sure the total norm is computed correctly when PP is used (see below)
+ # 2. to return a reduced total_norm tensor whose .item() would return the correct value
+ if isinstance(total_norm, torch.distributed.tensor.DTensor):
+ # Will reach here if any non-PP parallelism is used.
+ # If only using PP, total_norm will be a local tensor.
+ total_norm = total_norm.full_tensor()
+
+ if pp_mesh is not None:
+ if math.isinf(norm_type):
+ dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
+ else:
+ total_norm **= norm_type
+ dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
+ total_norm **= 1.0 / norm_type
+
+ _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
+ return total_norm
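+
+# Illustrative usage (not part of the original module; `model` and `pp_mesh` are hypothetical names):
+#     grad_norm = clip_grad_norm_(list(model.parameters()), max_norm=1.0, pp_mesh=pp_mesh)
+# Pass `pp_mesh` only when pipeline parallelism is enabled, so the norm is reduced across PP stages.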
+
+
+def enable_determinism(
+ seed: int,
+ world_mesh: Optional[torch.distributed.DeviceMesh] = None,
+ deterministic: bool = False,
+) -> None:
+ r"""
+ For all ranks within the same DTensor SPMD group, the same seed will be set.
+ For PP groups, different seeds will be set.
+ """
+ if deterministic:
+ logger.info("Deterministic algorithms are enabled (expect performance degradation).")
+ torch.use_deterministic_algorithms(True)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+ if not world_mesh:
+ if seed is not None:
+ torch.manual_seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
+ logger.debug(f"Single-process job using seed: {seed}")
+ return
+
+ # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
+ # and choose a unique seed for each rank on the PP mesh.
+ if torch.distributed.distributed_c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names:
+ pp_mesh = world_mesh["pp"]
+ seed += pp_mesh.get_local_rank()
+ seed %= 2**64
+
+ info = {
+ "pp_rank": pp_mesh.get_local_rank(),
+ "global_rank": torch.distributed.distributed_c10d.get_rank(),
+ "seed": seed,
+ }
+ logger.debug(f"Enabling determinism: {info}")
+ spmd_mesh_dims = list(filter(lambda name: name != "pp", world_mesh.mesh_dim_names))
+ spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None
+ else:
+ spmd_mesh = world_mesh
+ info = {"global_rank": torch.distributed.distributed_c10d.get_rank(), "seed": seed}
+ logger.debug(f"Enabling determinism: {info}")
+
+ # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency
+ torch.manual_seed(seed)
+ # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1]
+ os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
+
+ # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
+    # If PP is also used, this seed is unique per PP rank.
+ if spmd_mesh and spmd_mesh.get_coordinate() is not None:
+ torch.distributed.tensor._random.manual_seed(seed, spmd_mesh)
+
+
+def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor:
+ assert len(tensor.shape) <= ndim
+ return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape)))
+
+
+def get_device_info():
+ from torch._utils import _get_available_device_type, _get_device_module
+
+ device_type = _get_available_device_type()
+ if device_type is None:
+ device_type = "cuda"
+ device_module = _get_device_module(device_type)
+ return device_type, device_module
+
+
+def get_dtype_from_string(dtype: str):
+ return _STRING_TO_DTYPE[dtype]
+
+
+def get_string_from_dtype(dtype: torch.dtype):
+ return _DTYPE_TO_STRING[dtype]
+
+
+def get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
+ assert name.count("*") <= 1, "Wildcard '*' can only be used once in the name"
+ return _find_submodule_by_name(model, name)
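+
+# Illustrative usage (not part of the original module; the module path is hypothetical): with a
+# single "*" wildcard over a ModuleList, e.g. `get_submodule_by_name(model, "blocks.*.attn")`,
+# a list with the matching submodule of every block is returned; without a wildcard, a single
+# submodule is returned.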
+
+
+def get_unwrapped_model_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ # Remove _orig_mod occurrences from the state dict keys
+ return {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+
+
+def is_compiled_module(module) -> bool:
+ return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
+
+
+def set_requires_grad(models: Union[torch.nn.Module, List[torch.nn.Module]], value: bool) -> None:
+ if isinstance(models, torch.nn.Module):
+ models = [models]
+ for model in models:
+ if model is not None:
+ model.requires_grad_(value)
+
+
+def synchronize_device() -> None:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ elif torch.backends.mps.is_available():
+ torch.mps.synchronize()
+
+
+def unwrap_module(module):
+ """Unwraps a module if it was compiled with torch.compile()"""
+ return module._orig_mod if is_compiled_module(module) else module
+
+
+def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
+ if name == "":
+ return model
+ first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
+ if first_atom == "*":
+ # Wildcard '*' can only be used once in the name
+ assert isinstance(model, torch.nn.ModuleList), "Wildcard '*' can only be used with ModuleList"
+ submodules = []
+ for submodule in model:
+ subsubmodules = _find_submodule_by_name(submodule, remaining_name)
+ if not isinstance(subsubmodules, list):
+ subsubmodules = [subsubmodules]
+ submodules.extend(subsubmodules)
+ return submodules
+ else:
+ if hasattr(model, first_atom):
+ submodule = getattr(model, first_atom)
+ return _find_submodule_by_name(submodule, remaining_name)
+ else:
+ raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
+
+
+# TODO(aryan): remove everything below this after next torch release
+def _get_total_norm(
+ tensors: Union[torch.Tensor, List[torch.Tensor]],
+ norm_type: float = 2.0,
+ error_if_nonfinite: bool = False,
+ foreach: Optional[bool] = None,
+) -> torch.Tensor:
+ if isinstance(tensors, torch.Tensor):
+ tensors = [tensors]
+ else:
+ tensors = list(tensors)
+ norm_type = float(norm_type)
+ if len(tensors) == 0:
+ return torch.tensor(0.0)
+ first_device = tensors[0].device
+ grouped_tensors: dict[tuple[torch.device, torch.dtype], tuple[list[list[torch.Tensor]], list[int]]] = (
+ _group_tensors_by_device_and_dtype(
+ [tensors] # type: ignore[list-item]
+ )
+ ) # type: ignore[assignment]
+
+ norms: List[torch.Tensor] = []
+ for (device, _), ([device_tensors], _) in grouped_tensors.items():
+ if (foreach is None and _has_foreach_support(device_tensors, device)) or (
+ foreach and _device_has_foreach_support(device)
+ ):
+ norms.extend(torch._foreach_norm(device_tensors, norm_type))
+ elif foreach:
+ raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors")
+ else:
+ norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_tensors])
+
+ total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+ if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+ raise RuntimeError(
+ f"The total norm of order {norm_type} for gradients from "
+ "`parameters` is non-finite, so it cannot be clipped. To disable "
+ "this error and scale the gradients by the non-finite norm anyway, "
+ "set `error_if_nonfinite=False`"
+ )
+ return total_norm
+
+
+@torch.no_grad()
+def _clip_grads_with_norm_(
+ parameters: Union[torch.Tensor, List[torch.Tensor]],
+ max_norm: float,
+ total_norm: torch.Tensor,
+ foreach: Optional[bool] = None,
+) -> None:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grads = [p.grad for p in parameters if p.grad is not None]
+ max_norm = float(max_norm)
+ if len(grads) == 0:
+ return
+ grouped_grads: dict[Tuple[torch.device, torch.dtype], Tuple[List[List[torch.Tensor]], List[int]]] = (
+ _group_tensors_by_device_and_dtype([grads])
+ ) # type: ignore[assignment]
+
+ clip_coef = max_norm / (total_norm + 1e-6)
+ # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+ # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+ # when the gradients do not reside in CPU memory.
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+ for (device, _), ([device_grads], _) in grouped_grads.items():
+ if (foreach is None and _has_foreach_support(device_grads, device)) or (
+ foreach and _device_has_foreach_support(device)
+ ):
+ torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+ elif foreach:
+ raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors")
+ else:
+ clip_coef_clamped_device = clip_coef_clamped.to(device)
+ for g in device_grads:
+ g.mul_(clip_coef_clamped_device)
+
+
+def _get_foreach_kernels_supported_devices() -> list[str]:
+ r"""Return the device type list that supports foreach kernels."""
+ return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
+
+
+@torch.no_grad()
+def _group_tensors_by_device_and_dtype(
+ tensorlistlist: List[List[Optional[torch.Tensor]]],
+ with_indices: bool = False,
+) -> dict[tuple[torch.device, torch.dtype], tuple[List[List[Optional[torch.Tensor]]], List[int]]]:
+ return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
+
+
+def _device_has_foreach_support(device: torch.device) -> bool:
+ return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
+
+
+def _has_foreach_support(tensors: List[torch.Tensor], device: torch.device) -> bool:
+ return _device_has_foreach_support(device) and all(t is None or type(t) in [torch.Tensor] for t in tensors)
diff --git a/docs/finetrainers-src-codebase/pyproject.toml b/docs/finetrainers-src-codebase/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..79d64d3b5d8c741d94ebea0519d1f518dcfdf473
--- /dev/null
+++ b/docs/finetrainers-src-codebase/pyproject.toml
@@ -0,0 +1,28 @@
+[tool.ruff]
+line-length = 119
+
+[tool.ruff.lint]
+# Never enforce `E501` (line length violations).
+ignore = ["C901", "E501", "E741", "F402", "F823"]
+select = ["C", "E", "F", "I", "W"]
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-first-party = []
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
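+
+# Typical invocations (not part of this configuration): `ruff check .` applies the lint settings
+# above and `ruff format .` applies the formatting settings.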
diff --git a/docs/finetrainers-src-codebase/requirements.txt b/docs/finetrainers-src-codebase/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e5b32bd2364397588002c81b44d9226ed2c9ec9a
--- /dev/null
+++ b/docs/finetrainers-src-codebase/requirements.txt
@@ -0,0 +1,20 @@
+accelerate
+bitsandbytes
+datasets>=3.3.2
+diffusers>=0.32.1
+transformers>=4.45.2
+huggingface_hub
+hf_transfer>=0.1.8
+peft>=0.13.0
+decord>=0.6.0
+wandb
+pandas
+torch>=2.5.1
+torchvision>=0.20.1
+torchdata>=0.10.1
+torchao>=0.7.0
+sentencepiece>=0.2.0
+imageio-ffmpeg>=0.5.1
+numpy>=1.26.4
+kornia>=0.7.3
+ruff==0.9.10
diff --git a/docs/finetrainers-src-codebase/setup.py b/docs/finetrainers-src-codebase/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a4cad38e51b1df8178a13f363c4caafdb5917dc
--- /dev/null
+++ b/docs/finetrainers-src-codebase/setup.py
@@ -0,0 +1,46 @@
+from setuptools import find_packages, setup
+
+
+with open("README.md", "r", encoding="utf-8") as file:
+ long_description = file.read()
+
+with open("requirements.txt", "r", encoding="utf-8") as file:
+ requirements = [line for line in file.read().splitlines() if len(line) > 0]
+
+setup(
+ name="finetrainers",
+ version="0.2.0.dev0",
+ description="Finetrainers is a work-in-progress library to support (accessible) training of diffusion models",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ author="Aryan V S",
+ author_email="contact.aryanvs@gmail.com",
+ url="https://github.com/a-r-r-o-w/finetrainers",
+ python_requires=">=3.8.0",
+ license="Apache-2.0",
+ packages=find_packages(),
+ install_requires=requirements,
+ extras_require={"dev": ["pytest==8.3.2", "ruff==0.1.5"]},
+ classifiers=[
+ "Development Status :: 1 - Planning",
+ "Intended Audience :: Science/Research",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Operating System :: Microsoft :: Windows",
+ "Operating System :: Unix",
+        "License :: OSI Approved :: Apache Software License",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
+)
+
+# Steps to publish:
+# 1. Update version in setup.py
+# 2. python setup.py sdist bdist_wheel
+# 3. Check if everything works with testpypi:
+# twine upload --repository testpypi dist/*
+# 4. Upload to pypi:
+# twine upload dist/*
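+# (If the `build` package is installed, `python -m build` is an equivalent way to produce the
+# sdist and wheel in step 2.)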
diff --git a/docs/finetrainers-src-codebase/tests/README.md b/docs/finetrainers-src-codebase/tests/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8ef26d6044a2e628a98baf1e5f22a162a3e3d7d1
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/README.md
@@ -0,0 +1,53 @@
+# Running tests
+
+TODO(aryan): everything here needs to be improved.
+
+## `trainer/` fast tests
+
+- For SFT tests: `test_sft_trainer.py`
+- For Control tests: `test_control_trainer.py`
+
+Accelerate:
+
+```
+# world_size=1 tests
+accelerate launch --config_file accelerate_configs/uncompiled_1.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_1 and ___Accelerate"
+accelerate launch --config_file accelerate_configs/uncompiled_1.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_1___batch_size_1 and ___Accelerate"
+
+# world_size=2 tests
+accelerate launch --config_file accelerate_configs/uncompiled_2.yaml -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_1 and ___Accelerate"
+```
+
+PTD:
+
+```
+# world_size=1 tests
+torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_1 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_1___batch_size_1 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_1___batch_size_2 and ___PTD"
+
+# world_size=2 tests
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_1 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___layerwise_upcasting___dp_degree_2___batch_size_1 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___batch_size_2 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_shards_2___batch_size_1 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_shards_2___batch_size_2 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___tp_degree_2___batch_size_2 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___cp_degree_2___batch_size_1 and ___PTD"
+
+# world_size=4 tests
+torchrun --nnodes=1 --nproc_per_node 4 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___dp_shards_2___batch_size_1 and ___PTD"
+torchrun --nnodes=1 --nproc_per_node 4 -m pytest -s tests/trainer/test_sft_trainer.py -k "test___dp_degree_2___cp_degree_2___batch_size_1 and ___PTD"
+```
+
+## CP tests
+
+PTD:
+
+```
+# world_size=2 tests
+torchrun --nnodes 1 --nproc_per_node 2 -m pytest -s tests/models/attention_dispatch.py::RingAttentionCP2Test
+
+# world_size=4 tests
+torchrun --nnodes 1 --nproc_per_node 4 -m pytest -s tests/models/attention_dispatch.py::RingAttentionCP4Test
+```
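+
+The `-k` arguments above are standard pytest keyword expressions, so they can be narrowed further or combined with additional `and`/`or` clauses to select a subset of parameterizations.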
diff --git a/docs/finetrainers-src-codebase/tests/__init__.py b/docs/finetrainers-src-codebase/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/_test_dataset_old.py b/docs/finetrainers-src-codebase/tests/_test_dataset_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..740a9c91e710b1b9cfaebf74b43071cd837acdb9
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/_test_dataset_old.py
@@ -0,0 +1,104 @@
+# Run: python3 tests/test_dataset.py
+
+import sys
+
+
+def test_video_dataset():
+ from cogvideox.dataset import VideoDataset
+
+ dataset_dirs = VideoDataset(
+ data_root="assets/tests/",
+ caption_column="prompts.txt",
+ video_column="videos.txt",
+ max_num_frames=49,
+ id_token=None,
+ random_flip=None,
+ )
+ dataset_csv = VideoDataset(
+ data_root="assets/tests/",
+ dataset_file="assets/tests/metadata.csv",
+ caption_column="caption",
+ video_column="video",
+ max_num_frames=49,
+ id_token=None,
+ random_flip=None,
+ )
+
+ assert len(dataset_dirs) == 1
+ assert len(dataset_csv) == 1
+ assert dataset_dirs[0]["video"].shape == (49, 3, 480, 720)
+ assert (dataset_dirs[0]["video"] == dataset_csv[0]["video"]).all()
+
+ print(dataset_dirs[0]["video"].shape)
+
+
+def test_video_dataset_with_resizing():
+ from cogvideox.dataset import VideoDatasetWithResizing
+
+ dataset_dirs = VideoDatasetWithResizing(
+ data_root="assets/tests/",
+ caption_column="prompts.txt",
+ video_column="videos.txt",
+ max_num_frames=49,
+ id_token=None,
+ random_flip=None,
+ )
+ dataset_csv = VideoDatasetWithResizing(
+ data_root="assets/tests/",
+ dataset_file="assets/tests/metadata.csv",
+ caption_column="caption",
+ video_column="video",
+ max_num_frames=49,
+ id_token=None,
+ random_flip=None,
+ )
+
+ assert len(dataset_dirs) == 1
+ assert len(dataset_csv) == 1
+ assert dataset_dirs[0]["video"].shape == (48, 3, 480, 720) # Changes due to T2V frame bucket sampling
+ assert (dataset_dirs[0]["video"] == dataset_csv[0]["video"]).all()
+
+ print(dataset_dirs[0]["video"].shape)
+
+
+def test_video_dataset_with_bucket_sampler():
+ import torch
+ from cogvideox.dataset import BucketSampler, VideoDatasetWithResizing
+ from torch.utils.data import DataLoader
+
+ dataset_dirs = VideoDatasetWithResizing(
+ data_root="assets/tests/",
+ caption_column="prompts_multi.txt",
+ video_column="videos_multi.txt",
+ max_num_frames=49,
+ id_token=None,
+ random_flip=None,
+ )
+ sampler = BucketSampler(dataset_dirs, batch_size=8)
+
+ def collate_fn(data):
+ captions = [x["prompt"] for x in data[0]]
+ videos = [x["video"] for x in data[0]]
+ videos = torch.stack(videos)
+ return captions, videos
+
+ dataloader = DataLoader(dataset_dirs, batch_size=1, sampler=sampler, collate_fn=collate_fn)
+ first = False
+
+ for captions, videos in dataloader:
+ if not first:
+ assert len(captions) == 8 and isinstance(captions[0], str)
+ assert videos.shape == (8, 48, 3, 480, 720)
+ first = True
+ else:
+ assert len(captions) == 8 and isinstance(captions[0], str)
+ assert videos.shape == (8, 48, 3, 256, 360)
+ break
+
+
+if __name__ == "__main__":
+ sys.path.append("./training")
+
+ test_video_dataset()
+ test_video_dataset_with_resizing()
+ test_video_dataset_with_bucket_sampler()
diff --git a/docs/finetrainers-src-codebase/tests/data/__init__.py b/docs/finetrainers-src-codebase/tests/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/data/test_dataset.py b/docs/finetrainers-src-codebase/tests/data/test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..66186aba4718320767476b46486e27716062c640
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/data/test_dataset.py
@@ -0,0 +1,355 @@
+import pathlib
+import tempfile
+import unittest
+
+import torch
+from PIL import Image
+
+from finetrainers.data import (
+ ImageCaptionFilePairDataset,
+ ImageFileCaptionFileListDataset,
+ ImageFolderDataset,
+ ValidationDataset,
+ VideoCaptionFilePairDataset,
+ VideoFileCaptionFileListDataset,
+ VideoFolderDataset,
+ VideoWebDataset,
+ initialize_dataset,
+)
+from finetrainers.utils import find_files
+
+from .utils import create_dummy_directory_structure
+
+
+class DatasetTesterMixin:
+ num_data_files = None
+ directory_structure = None
+ caption = "A cat ruling the world"
+ metadata_extension = None
+
+ def setUp(self):
+ if self.num_data_files is None:
+ raise ValueError("num_data_files is not defined")
+ if self.directory_structure is None:
+            raise ValueError("directory_structure is not defined")
+
+ self.tmpdir = tempfile.TemporaryDirectory()
+ create_dummy_directory_structure(
+ self.directory_structure, self.tmpdir, self.num_data_files, self.caption, self.metadata_extension
+ )
+
+ def tearDown(self):
+ self.tmpdir.cleanup()
+
+
+class ImageDatasetTesterMixin(DatasetTesterMixin):
+ metadata_extension = "jpg"
+
+
+class VideoDatasetTesterMixin(DatasetTesterMixin):
+ metadata_extension = "mp4"
+
+
+class ImageCaptionFilePairDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
+ num_data_files = 3
+ directory_structure = [
+ "0.jpg",
+ "1.jpg",
+ "2.jpg",
+ "0.txt",
+ "1.txt",
+ "2.txt",
+ ]
+
+ def setUp(self):
+ super().setUp()
+ self.dataset = ImageCaptionFilePairDataset(self.tmpdir.name, infinite=False)
+
+ def test_getitem(self):
+ iterator = iter(self.dataset)
+ for _ in range(self.num_data_files):
+ item = next(iterator)
+ self.assertEqual(item["caption"], self.caption)
+ self.assertTrue(torch.is_tensor(item["image"]))
+ self.assertEqual(item["image"].shape, (3, 64, 64))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
+ self.assertIsInstance(dataset, ImageCaptionFilePairDataset)
+
+
+class ImageFileCaptionFileListDatasetFastTests(ImageDatasetTesterMixin, unittest.TestCase):
+ num_data_files = 3
+ directory_structure = [
+ "prompts.txt",
+ "images.txt",
+ "images/",
+ "images/0.jpg",
+ "images/1.jpg",
+ "images/2.jpg",
+ ]
+
+ def setUp(self):
+ super().setUp()
+ self.dataset = ImageFileCaptionFileListDataset(self.tmpdir.name, infinite=False)
+
+ def test_getitem(self):
+ iterator = iter(self.dataset)
+ for i in range(3):
+ item = next(iterator)
+ self.assertEqual(item["caption"], self.caption)
+ self.assertTrue(torch.is_tensor(item["image"]))
+ self.assertEqual(item["image"].shape, (3, 64, 64))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
+ self.assertIsInstance(dataset, ImageFileCaptionFileListDataset)
+
+
+class ImageFolderDatasetFastTests___CSV(ImageDatasetTesterMixin, unittest.TestCase):
+ num_data_files = 3
+ directory_structure = [
+ "metadata.csv",
+ "0.jpg",
+ "1.jpg",
+ "2.jpg",
+ ]
+
+ def setUp(self):
+ super().setUp()
+ self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)
+
+ def test_getitem(self):
+ iterator = iter(self.dataset)
+ for _ in range(3):
+ item = next(iterator)
+ self.assertIn("caption", item)
+ self.assertEqual(item["caption"], self.caption)
+ self.assertTrue(torch.is_tensor(item["image"]))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
+ self.assertIsInstance(dataset, ImageFolderDataset)
+
+
+class ImageFolderDatasetFastTests___JSONL(ImageDatasetTesterMixin, unittest.TestCase):
+ num_data_files = 3
+ directory_structure = [
+ "metadata.jsonl",
+ "0.jpg",
+ "1.jpg",
+ "2.jpg",
+ ]
+
+ def setUp(self):
+ super().setUp()
+ self.dataset = ImageFolderDataset(self.tmpdir.name, infinite=False)
+
+ def test_getitem(self):
+ iterator = iter(self.dataset)
+ for _ in range(3):
+ item = next(iterator)
+ self.assertIn("caption", item)
+ self.assertEqual(item["caption"], self.caption)
+ self.assertTrue(torch.is_tensor(item["image"]))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset(self.tmpdir.name, "image", infinite=False)
+ self.assertIsInstance(dataset, ImageFolderDataset)
+
+
+class VideoCaptionFilePairDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
+ num_data_files = 3
+ directory_structure = [
+ "0.mp4",
+ "1.mp4",
+ "2.mp4",
+ "0.txt",
+ "1.txt",
+ "2.txt",
+ ]
+
+ def setUp(self):
+ super().setUp()
+ self.dataset = VideoCaptionFilePairDataset(self.tmpdir.name, infinite=False)
+
+ def test_getitem(self):
+ iterator = iter(self.dataset)
+ for _ in range(self.num_data_files):
+ item = next(iterator)
+ self.assertEqual(item["caption"], self.caption)
+ self.assertTrue(torch.is_tensor(item["video"]))
+ self.assertEqual(len(item["video"]), 4)
+ self.assertEqual(item["video"][0].shape, (3, 64, 64))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
+ self.assertIsInstance(dataset, VideoCaptionFilePairDataset)
+
+
+class VideoFileCaptionFileListDatasetFastTests(VideoDatasetTesterMixin, unittest.TestCase):
+ num_data_files = 3
+ directory_structure = [
+ "prompts.txt",
+ "videos.txt",
+ "videos/",
+ "videos/0.mp4",
+ "videos/1.mp4",
+ "videos/2.mp4",
+ ]
+
+ def setUp(self):
+ super().setUp()
+ self.dataset = VideoFileCaptionFileListDataset(self.tmpdir.name, infinite=False)
+
+ def test_getitem(self):
+ iterator = iter(self.dataset)
+ for _ in range(3):
+ item = next(iterator)
+ self.assertEqual(item["caption"], self.caption)
+ self.assertTrue(torch.is_tensor(item["video"]))
+ self.assertEqual(len(item["video"]), 4)
+ self.assertEqual(item["video"][0].shape, (3, 64, 64))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
+ self.assertIsInstance(dataset, VideoFileCaptionFileListDataset)
+
+
+class VideoFolderDatasetFastTests___CSV(VideoDatasetTesterMixin, unittest.TestCase):
+ num_data_files = 3
+ directory_structure = [
+ "metadata.csv",
+ "0.mp4",
+ "1.mp4",
+ "2.mp4",
+ ]
+
+ def setUp(self):
+ super().setUp()
+ self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)
+
+ def test_getitem(self):
+ iterator = iter(self.dataset)
+ for _ in range(3):
+ item = next(iterator)
+ self.assertIn("caption", item)
+ self.assertEqual(item["caption"], self.caption)
+ self.assertTrue(torch.is_tensor(item["video"]))
+ self.assertEqual(len(item["video"]), 4)
+ self.assertEqual(item["video"][0].shape, (3, 64, 64))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
+ self.assertIsInstance(dataset, VideoFolderDataset)
+
+
+class VideoFolderDatasetFastTests___JSONL(VideoDatasetTesterMixin, unittest.TestCase):
+ num_data_files = 3
+ directory_structure = [
+ "metadata.jsonl",
+ "0.mp4",
+ "1.mp4",
+ "2.mp4",
+ ]
+
+ def setUp(self):
+ super().setUp()
+ self.dataset = VideoFolderDataset(self.tmpdir.name, infinite=False)
+
+ def test_getitem(self):
+ iterator = iter(self.dataset)
+ for _ in range(3):
+ item = next(iterator)
+ self.assertIn("caption", item)
+ self.assertEqual(item["caption"], self.caption)
+ self.assertTrue(torch.is_tensor(item["video"]))
+ self.assertEqual(len(item["video"]), 4)
+ self.assertEqual(item["video"][0].shape, (3, 64, 64))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset(self.tmpdir.name, "video", infinite=False)
+ self.assertIsInstance(dataset, VideoFolderDataset)
+
+
+class ImageWebDatasetFastTests(unittest.TestCase):
+ # TODO(aryan): setup a dummy dataset
+ pass
+
+
+class VideoWebDatasetFastTests(unittest.TestCase):
+ def setUp(self):
+ self.num_data_files = 15
+ self.dataset = VideoWebDataset("finetrainers/dummy-squish-wds", infinite=False)
+
+ def test_getitem(self):
+ for index, item in enumerate(self.dataset):
+ if index > 2:
+ break
+ self.assertIn("caption", item)
+ self.assertIn("video", item)
+ self.assertTrue(torch.is_tensor(item["video"]))
+ self.assertEqual(len(item["video"]), 121)
+ self.assertEqual(item["video"][0].shape, (3, 720, 1280))
+
+ def test_initialize_dataset(self):
+ dataset = initialize_dataset("finetrainers/dummy-squish-wds", "video", infinite=False)
+ self.assertIsInstance(dataset, VideoWebDataset)
+
+
+class DatasetUtilsFastTests(unittest.TestCase):
+ def test_find_files_depth_0(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ file1 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False)
+ file2 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False)
+ file3 = tempfile.NamedTemporaryFile(dir=tmpdir, suffix=".txt", delete=False)
+
+ files = find_files(tmpdir, "*.txt")
+ self.assertEqual(len(files), 3)
+ self.assertIn(file1.name, files)
+ self.assertIn(file2.name, files)
+ self.assertIn(file3.name, files)
+
+ def test_find_files_depth_n(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ dir1 = tempfile.TemporaryDirectory(dir=tmpdir)
+ dir2 = tempfile.TemporaryDirectory(dir=dir1.name)
+ file1 = tempfile.NamedTemporaryFile(dir=dir1.name, suffix=".txt", delete=False)
+ file2 = tempfile.NamedTemporaryFile(dir=dir2.name, suffix=".txt", delete=False)
+
+ files = find_files(tmpdir, "*.txt", depth=1)
+ self.assertEqual(len(files), 1)
+ self.assertIn(file1.name, files)
+ self.assertNotIn(file2.name, files)
+
+ files = find_files(tmpdir, "*.txt", depth=2)
+ self.assertEqual(len(files), 2)
+ self.assertIn(file1.name, files)
+ self.assertIn(file2.name, files)
+ self.assertNotIn(dir1.name, files)
+ self.assertNotIn(dir2.name, files)
+
+
+class ValidationDatasetFastTests(unittest.TestCase):
+ def setUp(self):
+ num_data_files = 3
+
+ self.tmpdir = tempfile.TemporaryDirectory()
+ metadata_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv"
+
+ with open(metadata_filename, "w") as f:
+ f.write("caption,image_path,video_path\n")
+ for i in range(num_data_files):
+ Image.new("RGB", (64, 64)).save((pathlib.Path(self.tmpdir.name) / f"{i}.jpg").as_posix())
+ f.write(f"test caption,{self.tmpdir.name}/{i}.jpg,\n")
+
+ self.dataset = ValidationDataset(metadata_filename.as_posix())
+
+ def tearDown(self):
+ self.tmpdir.cleanup()
+
+ def test_getitem(self):
+ for i, data in enumerate(self.dataset):
+ self.assertEqual(data["image_path"], f"{self.tmpdir.name}/{i}.jpg")
+ self.assertIsInstance(data["image"], Image.Image)
+ self.assertEqual(data["image"].size, (64, 64))
diff --git a/docs/finetrainers-src-codebase/tests/data/test_precomputation.py b/docs/finetrainers-src-codebase/tests/data/test_precomputation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f1879f3c82b2e900ffa229d50b2cfefd506b698
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/data/test_precomputation.py
@@ -0,0 +1,212 @@
+import os
+import tempfile
+import unittest
+
+from finetrainers.data import (
+ InMemoryDistributedDataPreprocessor,
+ PrecomputedDistributedDataPreprocessor,
+ VideoCaptionFilePairDataset,
+ initialize_preprocessor,
+ wrap_iterable_dataset_for_preprocessing,
+)
+from finetrainers.data.precomputation import PRECOMPUTED_DATA_DIR
+from finetrainers.utils import find_files
+
+from .utils import create_dummy_directory_structure
+
+
+class PreprocessorFastTests(unittest.TestCase):
+ def setUp(self):
+ self.rank = 0
+ self.world_size = 1
+ self.num_items = 3
+ self.processor_fn = {
+ "latent": self._latent_processor_fn,
+ "condition": self._condition_processor_fn,
+ }
+ self.save_dir = tempfile.TemporaryDirectory()
+
+ directory_structure = [
+ "0.mp4",
+ "1.mp4",
+ "2.mp4",
+ "0.txt",
+ "1.txt",
+ "2.txt",
+ ]
+ create_dummy_directory_structure(
+ directory_structure, self.save_dir, self.num_items, "a cat ruling the world", "mp4"
+ )
+
+ dataset = VideoCaptionFilePairDataset(self.save_dir.name, infinite=True)
+ dataset = wrap_iterable_dataset_for_preprocessing(
+ dataset,
+ dataset_type="video",
+ config={
+ "video_resolution_buckets": [[2, 32, 32]],
+ "reshape_mode": "bicubic",
+ },
+ )
+ self.dataset = dataset
+
+ def tearDown(self):
+ self.save_dir.cleanup()
+
+ @staticmethod
+ def _latent_processor_fn(**data):
+ video = data["video"]
+ video = video[:, :, :16, :16]
+ data["video"] = video
+ return data
+
+ @staticmethod
+ def _condition_processor_fn(**data):
+ caption = data["caption"]
+ caption = caption + " surrounded by mystical aura"
+ data["caption"] = caption
+ return data
+
+ def test_initialize_preprocessor(self):
+ preprocessor = initialize_preprocessor(
+ self.rank,
+ self.world_size,
+ self.num_items,
+ self.processor_fn,
+ self.save_dir.name,
+ enable_precomputation=False,
+ )
+ self.assertIsInstance(preprocessor, InMemoryDistributedDataPreprocessor)
+
+ preprocessor = initialize_preprocessor(
+ self.rank,
+ self.world_size,
+ self.num_items,
+ self.processor_fn,
+ self.save_dir.name,
+ enable_precomputation=True,
+ )
+ self.assertIsInstance(preprocessor, PrecomputedDistributedDataPreprocessor)
+
+ def test_in_memory_preprocessor_consume(self):
+ data_iterator = iter(self.dataset)
+ preprocessor = initialize_preprocessor(
+ self.rank,
+ self.world_size,
+ self.num_items,
+ self.processor_fn,
+ self.save_dir.name,
+ enable_precomputation=False,
+ )
+
+ condition_iterator = preprocessor.consume(
+ "condition", components={}, data_iterator=data_iterator, cache_samples=True
+ )
+ latent_iterator = preprocessor.consume(
+ "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True
+ )
+
+ self.assertFalse(preprocessor.requires_data)
+ for _ in range(self.num_items):
+ condition_item = next(condition_iterator)
+ latent_item = next(latent_iterator)
+ self.assertIn("caption", condition_item)
+ self.assertIn("video", latent_item)
+ self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura")
+ self.assertEqual(latent_item["video"].shape[-2:], (16, 16))
+ self.assertTrue(preprocessor.requires_data)
+
+ def test_in_memory_preprocessor_consume_once(self):
+ data_iterator = iter(self.dataset)
+ preprocessor = initialize_preprocessor(
+ self.rank,
+ self.world_size,
+ self.num_items,
+ self.processor_fn,
+ self.save_dir.name,
+ enable_precomputation=False,
+ )
+
+ condition_iterator = preprocessor.consume_once(
+ "condition", components={}, data_iterator=data_iterator, cache_samples=True
+ )
+ latent_iterator = preprocessor.consume_once(
+ "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True
+ )
+
+ self.assertFalse(preprocessor.requires_data)
+ for _ in range(self.num_items):
+ condition_item = next(condition_iterator)
+ latent_item = next(latent_iterator)
+ self.assertIn("caption", condition_item)
+ self.assertIn("video", latent_item)
+ self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura")
+ self.assertEqual(latent_item["video"].shape[-2:], (16, 16))
+ self.assertFalse(preprocessor.requires_data)
+
+ def test_precomputed_preprocessor_consume(self):
+ data_iterator = iter(self.dataset)
+ preprocessor = initialize_preprocessor(
+ self.rank,
+ self.world_size,
+ self.num_items,
+ self.processor_fn,
+ self.save_dir.name,
+ enable_precomputation=True,
+ )
+
+ condition_iterator = preprocessor.consume(
+ "condition", components={}, data_iterator=data_iterator, cache_samples=True
+ )
+ latent_iterator = preprocessor.consume(
+ "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True
+ )
+
+ precomputed_data_dir = os.path.join(self.save_dir.name, PRECOMPUTED_DATA_DIR)
+ condition_file_list = find_files(precomputed_data_dir, "condition-*")
+ latent_file_list = find_files(precomputed_data_dir, "latent-*")
+ self.assertEqual(len(condition_file_list), 3)
+ self.assertEqual(len(latent_file_list), 3)
+
+ self.assertFalse(preprocessor.requires_data)
+ for _ in range(self.num_items):
+ condition_item = next(condition_iterator)
+ latent_item = next(latent_iterator)
+ self.assertIn("caption", condition_item)
+ self.assertIn("video", latent_item)
+ self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura")
+ self.assertEqual(latent_item["video"].shape[-2:], (16, 16))
+ self.assertTrue(preprocessor.requires_data)
+
+ def test_precomputed_preprocessor_consume_once(self):
+ data_iterator = iter(self.dataset)
+ preprocessor = initialize_preprocessor(
+ self.rank,
+ self.world_size,
+ self.num_items,
+ self.processor_fn,
+ self.save_dir.name,
+ enable_precomputation=True,
+ )
+
+ condition_iterator = preprocessor.consume_once(
+ "condition", components={}, data_iterator=data_iterator, cache_samples=True
+ )
+ latent_iterator = preprocessor.consume_once(
+ "latent", components={}, data_iterator=data_iterator, use_cached_samples=True, drop_samples=True
+ )
+
+ precomputed_data_dir = os.path.join(self.save_dir.name, PRECOMPUTED_DATA_DIR)
+ condition_file_list = find_files(precomputed_data_dir, "condition-*")
+ latent_file_list = find_files(precomputed_data_dir, "latent-*")
+ self.assertEqual(len(condition_file_list), 3)
+ self.assertEqual(len(latent_file_list), 3)
+
+ self.assertFalse(preprocessor.requires_data)
+ for _ in range(self.num_items):
+ condition_item = next(condition_iterator)
+ latent_item = next(latent_iterator)
+ self.assertIn("caption", condition_item)
+ self.assertIn("video", latent_item)
+ self.assertEqual(condition_item["caption"], "a cat ruling the world surrounded by mystical aura")
+ self.assertEqual(latent_item["video"].shape[-2:], (16, 16))
+ self.assertFalse(preprocessor.requires_data)
diff --git a/docs/finetrainers-src-codebase/tests/data/utils.py b/docs/finetrainers-src-codebase/tests/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba4cc1fe38bbd311ce718d5f35f91eb2c12c313e
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/data/utils.py
@@ -0,0 +1,53 @@
+import pathlib
+from typing import List
+
+from diffusers.utils import export_to_video
+from PIL import Image
+
+from finetrainers.data.dataset import COMMON_CAPTION_FILES, COMMON_IMAGE_FILES, COMMON_VIDEO_FILES # noqa
+
+
+def create_dummy_directory_structure(
+ directory_structure: List[str], tmpdir, num_data_files: int, caption: str, metadata_extension: str
+):
+ for item in directory_structure:
+ # TODO(aryan): this should be improved
+ if item in COMMON_CAPTION_FILES:
+ data_file = pathlib.Path(tmpdir.name) / item
+ with open(data_file.as_posix(), "w") as f:
+ for _ in range(num_data_files):
+ f.write(f"{caption}\n")
+ elif item in COMMON_IMAGE_FILES:
+ data_file = pathlib.Path(tmpdir.name) / item
+ with open(data_file.as_posix(), "w") as f:
+ for i in range(num_data_files):
+ f.write(f"images/{i}.jpg\n")
+ elif item in COMMON_VIDEO_FILES:
+ data_file = pathlib.Path(tmpdir.name) / item
+ with open(data_file.as_posix(), "w") as f:
+ for i in range(num_data_files):
+ f.write(f"videos/{i}.mp4\n")
+ elif item == "metadata.csv":
+ data_file = pathlib.Path(tmpdir.name) / item
+ with open(data_file.as_posix(), "w") as f:
+ f.write("file_name,caption\n")
+ for i in range(num_data_files):
+ f.write(f"{i}.{metadata_extension},{caption}\n")
+ elif item == "metadata.jsonl":
+ data_file = pathlib.Path(tmpdir.name) / item
+ with open(data_file.as_posix(), "w") as f:
+ for i in range(num_data_files):
+ f.write(f'{{"file_name": "{i}.{metadata_extension}", "caption": "{caption}"}}\n')
+ elif item.endswith(".txt"):
+ data_file = pathlib.Path(tmpdir.name) / item
+ with open(data_file.as_posix(), "w") as f:
+ f.write(caption)
+ elif item.endswith(".jpg") or item.endswith(".png"):
+ data_file = pathlib.Path(tmpdir.name) / item
+ Image.new("RGB", (64, 64)).save(data_file.as_posix())
+ elif item.endswith(".mp4"):
+ data_file = pathlib.Path(tmpdir.name) / item
+ export_to_video([Image.new("RGB", (64, 64))] * 4, data_file.as_posix(), fps=2)
+ else:
+ data_file = pathlib.Path(tmpdir.name, item)
+ data_file.mkdir(exist_ok=True, parents=True)
diff --git a/docs/finetrainers-src-codebase/tests/models/__init__.py b/docs/finetrainers-src-codebase/tests/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/models/attention_dispatch.py b/docs/finetrainers-src-codebase/tests/models/attention_dispatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..1de978d248b2733bab0e049555a78cd3d1832ce0
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/attention_dispatch.py
@@ -0,0 +1,363 @@
+import os
+import random
+import unittest
+
+import numpy as np
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+
+from finetrainers.models.attention_dispatch import (
+ AttentionProvider,
+ _AttentionProviderRegistry,
+ _set_context_parallel_options,
+ attention_dispatch,
+ attention_provider,
+ flash_attn_flash_attention,
+ native_cudnn_attention,
+ native_efficient_attention,
+ native_flash_attention,
+)
+from finetrainers.parallel.ptd import _EquipartitionSharder
+
+
+def set_seed(seed):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def get_world_size():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_world_size()
+ return int(os.environ.get("WORLD_SIZE", 1))
+
+
+class AttentionDispatchTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ set_seed(0)
+
+ def test_forward(self):
+ if not torch.cuda.is_available():
+ self.skipTest("CUDA is not available")
+ cuda_capability = torch.cuda.get_device_capability()
+
+ query, key, value = self._create_dummy_inputs()
+
+ all_providers = [
+ (AttentionProvider._NATIVE_MATH, 0),
+ (AttentionProvider.NATIVE, 5e-3),
+ (AttentionProvider.FLASH, 5e-3),
+ (AttentionProvider.FLASH_VARLEN, 5e-3),
+ (AttentionProvider.FLEX, 2e-2),
+ (AttentionProvider._NATIVE_CUDNN, 5e-3),
+ (AttentionProvider._NATIVE_EFFICIENT, 5e-3),
+ (AttentionProvider._NATIVE_FLASH, 5e-3),
+ (AttentionProvider.SAGE, 1e-1),
+ (AttentionProvider.SAGE_VARLEN, 2e-0),
+ (AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, 2e-0), # TODO: look into the high difference threshold
+ (AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, 2e-0),
+ (AttentionProvider.XFORMERS, 5e-3),
+ ]
+
+ if cuda_capability >= (8, 9):
+ all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, 2e-0))
+ if cuda_capability >= (9, 0):
+ all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA_SM90, 2e-0))
+
+ ref_output = None
+ for i, (provider, threshold) in enumerate(all_providers):
+ try:
+ output = self._check_forward_pass(provider, query, key, value)
+ if i == 0:
+ ref_output = output.detach().clone()
+ else:
+ self.assertTrue(
+ torch.allclose(output, ref_output, atol=threshold), f"Forward pass mismatch for {provider}"
+ )
+ except Exception as e:
+ print(f"Warning: Forward pass test failed for {provider} with error: {e}")
+
+ def test_backward(self):
+ if not torch.cuda.is_available():
+ self.skipTest("CUDA is not available")
+
+ query, key, value = self._create_dummy_inputs()
+
+ selected_providers = [
+ AttentionProvider.FLASH,
+ AttentionProvider.FLASH_VARLEN,
+ AttentionProvider.FLEX,
+ AttentionProvider.NATIVE,
+ AttentionProvider.XFORMERS,
+ ]
+
+ ref_output = None
+ for i, provider in enumerate(selected_providers):
+ try:
+ output = self._check_backward_pass(provider, query, key, value)
+ if i == 0:
+ ref_output = output.detach().clone()
+ else:
+ if provider == AttentionProvider.FLEX:
+ threshold = 1e-2
+ else:
+ threshold = 1e-3
+ self.assertTrue(
+ torch.allclose(output, ref_output, atol=threshold), f"Backward pass mismatch for {provider}"
+ )
+ except Exception as e:
+ print(f"Warning: Backward pass test failed for {provider} with error: {e}")
+
+ def _create_dummy_inputs(
+ self, batch_size=2, num_heads=8, seq_len=256, head_dim=64, dtype=torch.bfloat16, device="cuda"
+ ):
+ torch.manual_seed(0)
+ query = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
+ key = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
+ value = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
+ return query, key, value
+
+ def _check_forward_pass(self, provider: AttentionProvider, query, key, value):
+ kwargs = {}
+ if provider == AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA:
+ kwargs["pv_accum_dtype"] = "fp32"
+ with attention_provider(provider):
+ output = attention_dispatch(query, key, value, attention_kwargs=kwargs)
+ self.assertIsNotNone(output)
+ self.assertEqual(output.shape, query.shape)
+ return output
+
+ def _check_backward_pass(self, provider: AttentionProvider, query, key, value):
+ query.requires_grad_(True)
+ key.requires_grad_(True)
+ value.requires_grad_(True)
+
+ with attention_provider(provider):
+ output = attention_dispatch(query, key, value)
+ loss = output.mean()
+ loss.backward()
+
+ self.assertTrue(query.grad is not None)
+ self.assertTrue(key.grad is not None)
+ self.assertTrue(value.grad is not None)
+
+ query.grad.zero_()
+ key.grad.zero_()
+ value.grad.zero_()
+ return output
+
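+ # A rough sketch of the dispatcher usage exercised by the checks above (based on this test;
+ # which providers actually work depends on the optional attention backends installed):
+ #
+ #     with attention_provider(AttentionProvider.FLASH):
+ #         out = attention_dispatch(query, key, value)
+ #
+ # Unavailable providers raise inside the try/except blocks above and are reported as warnings
+ # rather than as hard test failures.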
+
+class RingAttentionTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ torch.distributed.init_process_group(backend="nccl")
+ rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
+
+ cls.rank = rank
+ cls.world_size = world_size
+ torch.cuda.set_device(rank)
+ cls.mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,))
+
+ set_seed(0)
+ cls.batch_size = 2
+ cls.num_heads = 8
+ cls.seq_len = 256
+ cls.head_dim = 64
+ cls.dtype = torch.bfloat16
+ cls.device = "cuda"
+
+ _AttentionProviderRegistry._set_context_parallel(
+ mesh=cls.mesh, convert_to_fp32=True, rotate_method="allgather"
+ )
+ _set_context_parallel_options(is_causal=False)
+
+ cls.full_query = torch.randn(
+ cls.batch_size,
+ cls.num_heads,
+ cls.seq_len * cls.world_size,
+ cls.head_dim,
+ dtype=cls.dtype,
+ device=cls.device,
+ requires_grad=True,
+ )
+ cls.full_key = torch.randn(
+ cls.batch_size,
+ cls.num_heads,
+ cls.seq_len * cls.world_size,
+ cls.head_dim,
+ dtype=cls.dtype,
+ device=cls.device,
+ requires_grad=True,
+ )
+ cls.full_value = torch.randn(
+ cls.batch_size,
+ cls.num_heads,
+ cls.seq_len * cls.world_size,
+ cls.head_dim,
+ dtype=cls.dtype,
+ device=cls.device,
+ requires_grad=True,
+ )
+
+ # Ensure all ranks have the same data
+ with torch.no_grad():
+ torch.distributed.broadcast(cls.full_query, src=0)
+ torch.distributed.broadcast(cls.full_key, src=0)
+ torch.distributed.broadcast(cls.full_value, src=0)
+
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+ reference_output = scaled_dot_product_attention(cls.full_query, cls.full_key, cls.full_value)
+
+ cls.reference_output = reference_output.detach().clone()
+ reference_output.sum().backward()
+
+ cls.query, cls.key, cls.value = (
+ _EquipartitionSharder.shard(x, dim=2, mesh=cls.mesh).detach().clone()
+ for x in (cls.full_query, cls.full_key, cls.full_value)
+ )
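+
+ # Each rank now holds an equal-sized shard of the sequence dimension (dim=2); the checks below
+ # unshard the ring-attention outputs before comparing them against the single-device reference
+ # computed above.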
+
+ @classmethod
+ def tearDownClass(cls):
+ torch.distributed.destroy_process_group()
+
+ def _test_forward_native_cudnn_attention(self, atol: float = 1e-3):
+ output = native_cudnn_attention(self.query, self.key, self.value)
+ output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
+ self.assertEqual(output.shape, self.reference_output.shape)
+ self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))
+
+ def _test_forward_native_efficient_attention(self, atol: float = 1e-3):
+ output = native_efficient_attention(self.query, self.key, self.value)
+ output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
+ self.assertEqual(output.shape, self.reference_output.shape)
+ self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))
+
+ def _test_forward_native_flash_attention(self, atol: float = 1e-3):
+ output = native_flash_attention(self.query, self.key, self.value)
+ output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
+ self.assertEqual(output.shape, self.reference_output.shape)
+ self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))
+
+ def _test_forward_flash_attn_flash_attention(self, atol: float = 1e-3):
+ output = flash_attn_flash_attention(self.query, self.key, self.value)
+ output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
+ self.assertEqual(output.shape, self.reference_output.shape)
+ self.assertTrue(torch.allclose(output, self.reference_output, atol=atol))
+
+ def _test_backward_native_cudnn_attention(self, atol: float = 1e-3):
+ query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
+ query.requires_grad = True
+ key.requires_grad = True
+ value.requires_grad = True
+ output = native_cudnn_attention(query, key, value)
+ output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
+ output.sum().backward()
+ with torch.no_grad():
+ q_g, k_g, v_g = (
+ _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
+ for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
+ )
+ self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
+ self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
+ self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))
+
+ def _test_backward_native_efficient_attention(self, atol: float = 1e-3):
+ query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
+ query.requires_grad = True
+ key.requires_grad = True
+ value.requires_grad = True
+ output = native_efficient_attention(query, key, value)
+ output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
+ output.sum().backward()
+ with torch.no_grad():
+ q_g, k_g, v_g = (
+ _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
+ for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
+ )
+ self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
+ self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
+ self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))
+
+ def _test_backward_native_flash_attention(self, atol: float = 1e-3):
+ query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
+ query.requires_grad = True
+ key.requires_grad = True
+ value.requires_grad = True
+ output = native_flash_attention(query, key, value)
+ output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
+ output.sum().backward()
+ with torch.no_grad():
+ q_g, k_g, v_g = (
+ _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
+ for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
+ )
+ self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
+ self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
+ self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))
+
+ def _test_backward_flash_attn_flash_attention(self, atol: float = 1e-3):
+ query, key, value = (x.detach().clone() for x in (self.query, self.key, self.value))
+ query.requires_grad = True
+ key.requires_grad = True
+ value.requires_grad = True
+ output = flash_attn_flash_attention(query, key, value)
+ output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
+ output.sum().backward()
+ with torch.no_grad():
+ q_g, k_g, v_g = (
+ _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
+ for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
+ )
+ self.assertTrue(torch.allclose(query.grad, q_g, atol=atol))
+ self.assertTrue(torch.allclose(key.grad, k_g, atol=atol))
+ self.assertTrue(torch.allclose(value.grad, v_g, atol=atol))
+
+
+class RingAttentionCPTesterMixin:
+ def test_forward_native_cudnn_attention(self):
+ self._test_forward_native_cudnn_attention(atol=1e-2)
+
+ def test_forward_native_efficient_attention(self):
+ self._test_forward_native_efficient_attention(atol=1e-2)
+
+ def test_forward_native_flash_attention(self):
+ self._test_forward_native_flash_attention(atol=1e-2)
+
+ def test_forward_flash_attn_flash_attention(self):
+ self._test_forward_flash_attn_flash_attention(atol=1e-2)
+
+ def test_backward_native_cudnn_attention(self):
+ atol = 1e-2 * self.world_size # TODO: make bounds more strict
+ self._test_backward_native_cudnn_attention(atol=atol)
+
+ def test_backward_native_efficient_attention(self):
+ atol = 1e-2 * self.world_size # TODO: make bounds more strict
+ self._test_backward_native_efficient_attention(atol=atol)
+
+ def test_backward_native_flash_attention(self):
+ atol = 1e-2 * self.world_size # TODO: make bounds more strict
+ self._test_backward_native_flash_attention(atol=atol)
+
+ @unittest.skip(
+ """query diff: 0.298828125, key diff: 2.09375, value diff: 0.68359375; Needs further investigation"""
+ )
+ def test_backward_flash_attn_flash_attention(self):
+ # Seems to require much higher bound for some reason
+ atol = 1.5e-1 * self.world_size # TODO: make bounds more strict
+ self._test_backward_flash_attn_flash_attention(atol=atol)
+
+
+@unittest.skipIf(
+ not torch.cuda.is_available() or get_world_size() != 2, "CUDA is not available or world size is not 2"
+)
+class RingAttentionCP2Test(RingAttentionTest, RingAttentionCPTesterMixin):
+ pass
+
+
+@unittest.skipIf(
+ not torch.cuda.is_available() or get_world_size() != 4, "CUDA is not available or world size is not 4"
+)
+class RingAttentionCP4Test(RingAttentionTest, RingAttentionCPTesterMixin):
+ pass
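+
+
+# Note: these context-parallel tests initialize an NCCL process group in setUpClass, so they are
+# meant to be launched with torchrun rather than plain pytest, e.g. (path to this file assumed):
+#
+#   torchrun --nnodes=1 --nproc_per_node=2 -m pytest -s <path/to/this/test/file>
+#
+# The CP2/CP4 classes above skip themselves unless the world size is exactly 2 or 4.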
diff --git a/docs/finetrainers-src-codebase/tests/models/cogvideox/__init__.py b/docs/finetrainers-src-codebase/tests/models/cogvideox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/models/cogvideox/base_specification.py b/docs/finetrainers-src-codebase/tests/models/cogvideox/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02fd5e434246fafbc18e271404f7cb4a6edfd1c
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/cogvideox/base_specification.py
@@ -0,0 +1,71 @@
+import torch
+from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXTransformer3DModel
+from transformers import AutoTokenizer, T5EncoderModel
+
+from finetrainers.models.cogvideox import CogVideoXModelSpecification
+
+
+class DummyCogVideoXModelSpecification(CogVideoXModelSpecification):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def load_condition_models(self):
+ text_encoder = T5EncoderModel.from_pretrained(
+ "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ return {"text_encoder": text_encoder, "tokenizer": tokenizer}
+
+ def load_latent_models(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLCogVideoX(
+ in_channels=3,
+ out_channels=3,
+ down_block_types=(
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ ),
+ up_block_types=(
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ latent_channels=4,
+ layers_per_block=1,
+ norm_num_groups=2,
+ temporal_compression_ratio=4,
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ vae.to(self.vae_dtype)
+ self.vae_config = vae.config
+ return {"vae": vae}
+
+ def load_diffusion_models(self):
+ torch.manual_seed(0)
+ transformer = CogVideoXTransformer3DModel(
+ num_attention_heads=4,
+ attention_head_dim=16,
+ in_channels=4,
+ out_channels=4,
+ time_embed_dim=2,
+ text_embed_dim=32,
+ num_layers=2,
+ sample_width=24,
+ sample_height=24,
+ sample_frames=9,
+ patch_size=2,
+ temporal_compression_ratio=4,
+ max_text_seq_length=16,
+ use_rotary_positional_embeddings=True,
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ transformer.to(self.transformer_dtype)
+ self.transformer_config = transformer.config
+ scheduler = CogVideoXDDIMScheduler()
+ return {"transformer": transformer, "scheduler": scheduler}
diff --git a/docs/finetrainers-src-codebase/tests/models/cogview4/__init__.py b/docs/finetrainers-src-codebase/tests/models/cogview4/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/models/cogview4/base_specification.py b/docs/finetrainers-src-codebase/tests/models/cogview4/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa4f634c17c3dae14321994a129f8af005f15663
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/cogview4/base_specification.py
@@ -0,0 +1,35 @@
+import torch
+from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from transformers import AutoTokenizer, GlmModel
+
+from finetrainers.models.cogview4 import CogView4ModelSpecification
+
+
+class DummyCogView4ModelSpecification(CogView4ModelSpecification):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def load_condition_models(self):
+ text_encoder = GlmModel.from_pretrained(
+ "hf-internal-testing/tiny-random-cogview4", subfolder="text_encoder", torch_dtype=self.text_encoder_dtype
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ "hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True
+ )
+ return {"text_encoder": text_encoder, "tokenizer": tokenizer}
+
+ def load_latent_models(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKL.from_pretrained(
+ "hf-internal-testing/tiny-random-cogview4", subfolder="vae", torch_dtype=self.vae_dtype
+ )
+ self.vae_config = vae.config
+ return {"vae": vae}
+
+ def load_diffusion_models(self):
+ torch.manual_seed(0)
+ transformer = CogView4Transformer2DModel.from_pretrained(
+ "hf-internal-testing/tiny-random-cogview4", subfolder="transformer", torch_dtype=self.transformer_dtype
+ )
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ return {"transformer": transformer, "scheduler": scheduler}
diff --git a/docs/finetrainers-src-codebase/tests/models/cogview4/control_specification.py b/docs/finetrainers-src-codebase/tests/models/cogview4/control_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..eab8c670ce307adb4b9dc0bbf8b28850d37514e8
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/cogview4/control_specification.py
@@ -0,0 +1,61 @@
+import torch
+from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from transformers import AutoTokenizer, GlmConfig, GlmModel
+
+from finetrainers.models.cogview4 import CogView4ControlModelSpecification
+from finetrainers.models.utils import _expand_linear_with_zeroed_weights
+
+
+class DummyCogView4ControlModelSpecification(CogView4ControlModelSpecification):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ # This needs to be updated for the test to work correctly.
+ # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded
+ # with ModelSpecification::_load_configs
+ self.transformer_config.in_channels = 4
+
+ def load_condition_models(self):
+ text_encoder_config = GlmConfig(
+ hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
+ )
+ text_encoder = GlmModel(text_encoder_config).to(self.text_encoder_dtype)
+ # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
+ return {"text_encoder": text_encoder, "tokenizer": tokenizer}
+
+ def load_latent_models(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=128,
+ ).to(self.vae_dtype)
+ return {"vae": vae}
+
+ def load_diffusion_models(self, new_in_features: int):
+ torch.manual_seed(0)
+ transformer = CogView4Transformer2DModel(
+ patch_size=2,
+ in_channels=4,
+ num_layers=2,
+ attention_head_dim=4,
+ num_attention_heads=4,
+ out_channels=4,
+ text_embed_dim=32,
+ time_embed_dim=8,
+ condition_dim=4,
+ ).to(self.transformer_dtype)
+ actual_new_in_features = new_in_features * transformer.config.patch_size**2
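+ # The patch embedding projects patchified inputs, so its linear layer sees
+ # channels * patch_size**2 input features. The helper below (as its name suggests) appends the
+ # extra control channels with zero-initialized weights, so the expanded projection initially
+ # behaves like the original one.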
+ transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
+ transformer.patch_embed.proj, new_in_features=actual_new_in_features
+ )
+ transformer.register_to_config(in_channels=new_in_features)
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {"transformer": transformer, "scheduler": scheduler}
diff --git a/docs/finetrainers-src-codebase/tests/models/flux/__init__.py b/docs/finetrainers-src-codebase/tests/models/flux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/models/flux/base_specification.py b/docs/finetrainers-src-codebase/tests/models/flux/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..59a2391535527e7f078c0c9759808cb2fa1b0d00
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/flux/base_specification.py
@@ -0,0 +1,6 @@
+from finetrainers.models.flux import FluxModelSpecification
+
+
+class DummyFluxModelSpecification(FluxModelSpecification):
+ def __init__(self, **kwargs):
+ super().__init__(pretrained_model_name_or_path="hf-internal-testing/tiny-flux-pipe", **kwargs)
diff --git a/docs/finetrainers-src-codebase/tests/models/hunyuan_video/base_specification.py b/docs/finetrainers-src-codebase/tests/models/hunyuan_video/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..620a0ef1b9b431c7397a1db9d56cb6ebb17b9e6f
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/hunyuan_video/base_specification.py
@@ -0,0 +1,119 @@
+import torch
+from diffusers import AutoencoderKLHunyuanVideo, FlowMatchEulerDiscreteScheduler, HunyuanVideoTransformer3DModel
+from transformers import (
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTokenizer,
+ LlamaConfig,
+ LlamaModel,
+ LlamaTokenizer,
+)
+
+from finetrainers.models.hunyuan_video import HunyuanVideoModelSpecification
+
+
+class DummyHunyuanVideoModelSpecification(HunyuanVideoModelSpecification):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def load_condition_models(self):
+ llama_text_encoder_config = LlamaConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=16,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=8,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = LlamaModel(llama_text_encoder_config)
+ tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ text_encoder.to(self.text_encoder_dtype)
+ text_encoder_2.to(self.text_encoder_2_dtype)
+
+ return {
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ }
+
+ def load_latent_models(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ down_block_types=(
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ up_block_types=(
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ layers_per_block=1,
+ act_fn="silu",
+ norm_num_groups=4,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ temporal_compression_ratio=4,
+ mid_block_add_attention=True,
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ vae.to(self.vae_dtype)
+ self.vae_config = vae.config
+ return {"vae": vae}
+
+ def load_diffusion_models(self):
+ torch.manual_seed(0)
+ transformer = HunyuanVideoTransformer3DModel(
+ in_channels=4,
+ out_channels=4,
+ num_attention_heads=2,
+ attention_head_dim=10,
+ num_layers=2,
+ num_single_layers=2,
+ num_refiner_layers=1,
+ patch_size=1,
+ patch_size_t=1,
+ guidance_embeds=True,
+ text_embed_dim=16,
+ pooled_projection_dim=8,
+ rope_axes_dim=(2, 4, 4),
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ transformer.to(self.transformer_dtype)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ return {"transformer": transformer, "scheduler": scheduler}
diff --git a/docs/finetrainers-src-codebase/tests/models/ltx_video/__init__.py b/docs/finetrainers-src-codebase/tests/models/ltx_video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/models/ltx_video/_test_tp.py b/docs/finetrainers-src-codebase/tests/models/ltx_video/_test_tp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3432d716a88a98026620654d22a4a5bbcc63ae7
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/ltx_video/_test_tp.py
@@ -0,0 +1,245 @@
+import copy
+
+import torch
+import torch.distributed as dist
+from diffusers import LTXVideoTransformer3DModel
+from torch._utils import _get_device_module
+from torch.distributed.tensor import DTensor, Replicate
+from torch.distributed.tensor.debug import CommDebugMode
+from torch.distributed.tensor.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel.api import parallelize_module
+from torch.distributed.tensor.parallel.style import (
+ ColwiseParallel,
+ RowwiseParallel,
+)
+
+
+# from torch.utils._python_dispatch import TorchDispatchMode
+
+
+DEVICE_TYPE = "cuda"
+PG_BACKEND = "nccl"
+DEVICE_COUNT = _get_device_module(DEVICE_TYPE).device_count()
+
+
+def main(world_size: int, rank: int):
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats(rank)
+
+ CHANNELS = 128
+ CROSS_ATTENTION_DIM = 2048
+ CAPTION_CHANNELS = 4096
+ NUM_LAYERS = 28
+ NUM_ATTENTION_HEADS = 32
+ ATTENTION_HEAD_DIM = 64
+
+ # CHANNELS = 4
+ # CROSS_ATTENTION_DIM = 32
+ # CAPTION_CHANNELS = 64
+ # NUM_LAYERS = 1
+ # NUM_ATTENTION_HEADS = 4
+ # ATTENTION_HEAD_DIM = 8
+
+ config = {
+ "in_channels": CHANNELS,
+ "out_channels": CHANNELS,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "num_attention_heads": NUM_ATTENTION_HEADS,
+ "attention_head_dim": ATTENTION_HEAD_DIM,
+ "cross_attention_dim": CROSS_ATTENTION_DIM,
+ "num_layers": NUM_LAYERS,
+ "activation_fn": "gelu-approximate",
+ "qk_norm": "rms_norm_across_heads",
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "caption_channels": CAPTION_CHANNELS,
+ "attention_bias": True,
+ "attention_out_bias": True,
+ }
+
+ # Normal model
+ torch.manual_seed(0)
+ model = LTXVideoTransformer3DModel(**config).to(DEVICE_TYPE)
+
+ # TP model
+ model_tp = copy.deepcopy(model)
+ device_mesh = DeviceMesh(DEVICE_TYPE, torch.arange(world_size))
+ print(f"Device mesh: {device_mesh}")
+
+ transformer_tp_plan = {
+ # ===== Condition embeddings =====
+ # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
+ # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
+ # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
+ # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
+ # "caption_projection.linear_1": ColwiseParallel(),
+ # "caption_projection.linear_2": RowwiseParallel(),
+ # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
+ # ===== =====
+ }
+
+ for block in model_tp.transformer_blocks:
+ block_tp_plan = {}
+
+ # ===== Attention =====
+ # 8 all-to-all, 3 all-reduce
+ # block_tp_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
+ # block_tp_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
+ # block_tp_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
+ # block_tp_plan["attn1.norm_q"] = SequenceParallel()
+ # block_tp_plan["attn1.norm_k"] = SequenceParallel()
+ # block_tp_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
+ # block_tp_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
+ # block_tp_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
+ # block_tp_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
+ # block_tp_plan["attn2.norm_q"] = SequenceParallel()
+ # block_tp_plan["attn2.norm_k"] = SequenceParallel()
+ # block_tp_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
+ # ===== =====
+
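+ # Only the feed-forward projections are sharded: ColwiseParallel on the up-projection paired
+ # with RowwiseParallel on the down-projection (the usual Megatron-style split), so each block's
+ # FFN requires a single all-reduce in the forward pass while the attention sub-layers, whose
+ # plans are commented out above, stay replicated.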
+ block_tp_plan["ff.net.0.proj"] = ColwiseParallel()
+ block_tp_plan["ff.net.2"] = RowwiseParallel()
+ parallelize_module(block, device_mesh, block_tp_plan)
+
+ parallelize_module(model_tp, device_mesh, transformer_tp_plan)
+
+ comm_mode = CommDebugMode()
+
+ batch_size = 2
+ num_frames, height, width = 49, 512, 512
+ temporal_compression_ratio, spatial_compression_ratio = 8, 32
+ latent_num_frames, latent_height, latent_width = (
+ (num_frames - 1) // temporal_compression_ratio + 1,
+ height // spatial_compression_ratio,
+ width // spatial_compression_ratio,
+ )
+ video_sequence_length = latent_num_frames * latent_height * latent_width
+ caption_sequence_length = 64
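+ # With these settings: latent_num_frames = (49 - 1) // 8 + 1 = 7 and
+ # latent_height = latent_width = 512 // 32 = 16, so video_sequence_length = 7 * 16 * 16 = 1792.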
+
+ hidden_states = torch.randn(batch_size, video_sequence_length, CHANNELS, device=DEVICE_TYPE)
+ encoder_hidden_states = torch.randn(batch_size, caption_sequence_length, CAPTION_CHANNELS, device=DEVICE_TYPE)
+ encoder_attention_mask = None
+ timestep = torch.randint(0, 1000, (batch_size, 1), device=DEVICE_TYPE)
+ inputs = {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "timestep": timestep,
+ "num_frames": latent_num_frames,
+ "height": latent_height,
+ "width": latent_width,
+ "rope_interpolation_scale": [1 / (8 / 25), 8, 8],
+ "return_dict": False,
+ }
+
+ output = model(**inputs)[0]
+
+ with comm_mode:
+ output_tp = model_tp(**inputs)[0]
+
+ output_tp = (
+ output_tp.redistribute(output_tp.device_mesh, [Replicate()]).to_local()
+ if isinstance(output_tp, DTensor)
+ else output_tp
+ )
+
+ print("Output shapes:", output.shape, output_tp.shape)
+ print(
+ "Comparing output:",
+ rank,
+ torch.allclose(output, output_tp, atol=1e-5, rtol=1e-5),
+ (output - output_tp).abs().max(),
+ )
+ print(f"Max memory reserved ({rank=}): {torch.cuda.max_memory_reserved(rank) / 1024**3:.2f} GB")
+
+ if rank == 0:
+ print()
+ print("get_comm_counts:", comm_mode.get_comm_counts())
+ # print()
+ # print("get_parameter_info:", comm_mode.get_parameter_info()) # Too much noise
+ print()
+ print("Sharding info:\n" + "".join(f"{k} - {v}\n" for k, v in comm_mode.get_sharding_info().items()))
+ print()
+ print("get_total_counts:", comm_mode.get_total_counts())
+ comm_mode.generate_json_dump("dump_comm_mode_log.json", noise_level=1)
+ comm_mode.log_comm_debug_tracing_table_to_file("dump_comm_mode_tracing_table.txt", noise_level=1)
+
+
+dist.init_process_group(PG_BACKEND)
+WORLD_SIZE = dist.get_world_size()
+RANK = dist.get_rank()
+
+torch.cuda.set_device(RANK)
+
+if RANK == 0:
+ print(f"World size: {WORLD_SIZE}")
+ print(f"Device count: {DEVICE_COUNT}")
+
+try:
+ with torch.no_grad():
+ main(WORLD_SIZE, RANK)
+finally:
+ dist.destroy_process_group()
+
+
+# LTXVideoTransformer3DModel(
+# (proj_in): Linear(in_features=128, out_features=2048, bias=True)
+# (time_embed): AdaLayerNormSingle(
+# (emb): PixArtAlphaCombinedTimestepSizeEmbeddings(
+# (time_proj): Timesteps()
+# (timestep_embedder): TimestepEmbedding(
+# (linear_1): Linear(in_features=256, out_features=2048, bias=True)
+# (act): SiLU()
+# (linear_2): Linear(in_features=2048, out_features=2048, bias=True)
+# )
+# )
+# (silu): SiLU()
+# (linear): Linear(in_features=2048, out_features=12288, bias=True)
+# )
+# (caption_projection): PixArtAlphaTextProjection(
+# (linear_1): Linear(in_features=4096, out_features=2048, bias=True)
+# (act_1): GELU(approximate='tanh')
+# (linear_2): Linear(in_features=2048, out_features=2048, bias=True)
+# )
+# (rope): LTXVideoRotaryPosEmbed()
+# (transformer_blocks): ModuleList(
+# (0-27): 28 x LTXVideoTransformerBlock(
+# (norm1): RMSNorm()
+# (attn1): Attention(
+# (norm_q): RMSNorm()
+# (norm_k): RMSNorm()
+# (to_q): Linear(in_features=2048, out_features=2048, bias=True)
+# (to_k): Linear(in_features=2048, out_features=2048, bias=True)
+# (to_v): Linear(in_features=2048, out_features=2048, bias=True)
+# (to_out): ModuleList(
+# (0): Linear(in_features=2048, out_features=2048, bias=True)
+# (1): Dropout(p=0.0, inplace=False)
+# )
+# )
+# (norm2): RMSNorm()
+# (attn2): Attention(
+# (norm_q): RMSNorm()
+# (norm_k): RMSNorm()
+# (to_q): Linear(in_features=2048, out_features=2048, bias=True)
+# (to_k): Linear(in_features=2048, out_features=2048, bias=True)
+# (to_v): Linear(in_features=2048, out_features=2048, bias=True)
+# (to_out): ModuleList(
+# (0): Linear(in_features=2048, out_features=2048, bias=True)
+# (1): Dropout(p=0.0, inplace=False)
+# )
+# )
+# (ff): FeedForward(
+# (net): ModuleList(
+# (0): GELU(
+# (proj): Linear(in_features=2048, out_features=8192, bias=True)
+# )
+# (1): Dropout(p=0.0, inplace=False)
+# (2): Linear(in_features=8192, out_features=2048, bias=True)
+# )
+# )
+# )
+# )
+# (norm_out): LayerNorm((2048,), eps=1e-06, elementwise_affine=False)
+# (proj_out): Linear(in_features=2048, out_features=128, bias=True)
+# )
diff --git a/docs/finetrainers-src-codebase/tests/models/ltx_video/base_specification.py b/docs/finetrainers-src-codebase/tests/models/ltx_video/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf8e65aee6b4b7ee2b679e6e9c1cdf12d3597d97
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/ltx_video/base_specification.py
@@ -0,0 +1,63 @@
+import torch
+from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXVideoTransformer3DModel
+from transformers import AutoTokenizer, T5EncoderModel
+
+from finetrainers.models.ltx_video import LTXVideoModelSpecification
+
+
+class DummyLTXVideoModelSpecification(LTXVideoModelSpecification):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def load_condition_models(self):
+ text_encoder = T5EncoderModel.from_pretrained(
+ "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ return {"text_encoder": text_encoder, "tokenizer": tokenizer}
+
+ def load_latent_models(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLLTXVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=8,
+ block_out_channels=(8, 8, 8, 8),
+ decoder_block_out_channels=(8, 8, 8, 8),
+ layers_per_block=(1, 1, 1, 1, 1),
+ decoder_layers_per_block=(1, 1, 1, 1, 1),
+ spatio_temporal_scaling=(True, True, False, False),
+ decoder_spatio_temporal_scaling=(True, True, False, False),
+ decoder_inject_noise=(False, False, False, False, False),
+ upsample_residual=(False, False, False, False),
+ upsample_factor=(1, 1, 1, 1),
+ timestep_conditioning=False,
+ patch_size=1,
+ patch_size_t=1,
+ encoder_causal=True,
+ decoder_causal=False,
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ vae.to(self.vae_dtype)
+ self.vae_config = vae.config
+ return {"vae": vae}
+
+ def load_diffusion_models(self):
+ torch.manual_seed(0)
+ transformer = LTXVideoTransformer3DModel(
+ in_channels=8,
+ out_channels=8,
+ patch_size=1,
+ patch_size_t=1,
+ num_attention_heads=4,
+ attention_head_dim=8,
+ cross_attention_dim=32,
+ num_layers=1,
+ caption_channels=32,
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ transformer.to(self.transformer_dtype)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ return {"transformer": transformer, "scheduler": scheduler}
diff --git a/docs/finetrainers-src-codebase/tests/models/wan/__init__.py b/docs/finetrainers-src-codebase/tests/models/wan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/models/wan/base_specification.py b/docs/finetrainers-src-codebase/tests/models/wan/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..59bb4809edb6702672334e10e176b955b3ef6a5e
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/wan/base_specification.py
@@ -0,0 +1,54 @@
+import torch
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel
+from transformers import AutoTokenizer, T5EncoderModel
+
+from finetrainers.models.wan import WanModelSpecification
+
+
+class DummyWanModelSpecification(WanModelSpecification):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def load_condition_models(self):
+ text_encoder = T5EncoderModel.from_pretrained(
+ "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ return {"text_encoder": text_encoder, "tokenizer": tokenizer}
+
+ def load_latent_models(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ vae.to(self.vae_dtype)
+ self.vae_config = vae.config
+ return {"vae": vae}
+
+ def load_diffusion_models(self):
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ transformer.to(self.transformer_dtype)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ return {"transformer": transformer, "scheduler": scheduler}
diff --git a/docs/finetrainers-src-codebase/tests/models/wan/control_specification.py b/docs/finetrainers-src-codebase/tests/models/wan/control_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0498a144d54422b6ae38c00ce83f4e201d67ae
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/models/wan/control_specification.py
@@ -0,0 +1,66 @@
+import torch
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel
+from transformers import AutoTokenizer, T5EncoderModel
+
+from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights
+from finetrainers.models.wan import WanControlModelSpecification
+
+
+class DummyWanControlModelSpecification(WanControlModelSpecification):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ # This needs to be updated for the test to work correctly.
+ # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded
+ # with ModelSpecification::_load_configs
+ self.transformer_config.in_channels = 16
+
+ def load_condition_models(self):
+ text_encoder = T5EncoderModel.from_pretrained(
+ "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+ )
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ return {"text_encoder": text_encoder, "tokenizer": tokenizer}
+
+ def load_latent_models(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ vae.to(self.vae_dtype)
+ self.vae_config = vae.config
+ return {"vae": vae}
+
+ def load_diffusion_models(self, new_in_features: int):
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ ).to(self.transformer_dtype)
+
+ transformer.patch_embedding = _expand_conv3d_with_zeroed_weights(
+ transformer.patch_embedding, new_in_channels=new_in_features
+ )
+ transformer.register_to_config(in_channels=new_in_features)
+
+ # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+ # Doing so overrides things like _keep_in_fp32_modules
+ transformer.to(self.transformer_dtype)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ return {"transformer": transformer, "scheduler": scheduler}
diff --git a/docs/finetrainers-src-codebase/tests/test_lora_inference.py b/docs/finetrainers-src-codebase/tests/test_lora_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1c0439b19b16fbba898e61545a2d65a8c649393
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/test_lora_inference.py
@@ -0,0 +1,44 @@
+"""
+Run this test to check inference with a trained LoRA adapter:
+
+```shell
+python3 test_lora_inference.py --prompt "A girl is riding a bike." --model_path "THUDM/CogVideoX-5B" --lora_path "path/to/lora" --lora_name "lora_adapter" --output_file "output.mp4" --fps 8
+```
+
+"""
+
+import argparse
+
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
+
+
+def generate_video(model_path, prompt, lora_path, lora_name, output_file, fps):
+ pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")
+ pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=lora_name)
+ pipe.set_adapters([lora_name], [1.0])
+ pipe.enable_model_cpu_offload()
+ pipe.vae.enable_slicing()
+ pipe.vae.enable_tiling()
+
+ video = pipe(prompt=prompt).frames[0]
+ export_to_video(video, output_file, fps=fps)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Generate video using CogVideoX and LoRA weights")
+ parser.add_argument("--prompt", type=str, required=True, help="Text prompt for the video generation")
+ parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5B", help="Base Model path or HF ID")
+ parser.add_argument("--lora_path", type=str, required=True, help="Path to the LoRA weights")
+ parser.add_argument("--lora_name", type=str, default="lora_adapter", help="Name of the LoRA adapter")
+ parser.add_argument("--output_file", type=str, default="output.mp4", help="Output video file name")
+ parser.add_argument("--fps", type=int, default=8, help="Frames per second for the output video")
+
+ args = parser.parse_args()
+
+ generate_video(args.model_path, args.prompt, args.lora_path, args.lora_name, args.output_file, args.fps)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/finetrainers-src-codebase/tests/test_model_runs_minimally_lora.sh b/docs/finetrainers-src-codebase/tests/test_model_runs_minimally_lora.sh
new file mode 100755
index 0000000000000000000000000000000000000000..ebcab61605afadf1a984953547c771ea8e034f36
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/test_model_runs_minimally_lora.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# This shell script lets maintainers and contributors QUICKLY check whether the major changes
+# they are introducing still work with the rest of the models supported in `finetrainers`.
+# It DOES NOT verify implementation correctness, since that requires much longer training runs,
+# but it DOES ensure that basic functionality keeps working across the larger training setup.
+
+# Run it from the root of `finetrainers` like so: `bash tests/test_model_runs_minimally_lora.sh`
+
+######################################################
+# Set common variables.
+######################################################
+
+ROOT_DIR="$(cd "$(dirname "$0")/.." && pwd)"
+export ROOT_DIR
+export WANDB_MODE="offline"
+export NCCL_P2P_DISABLE=1
+export TORCH_NCCL_ENABLE_MONITORING=0
+export FINETRAINERS_LOG_LEVEL=DEBUG
+
+echo "Using $ROOT_DIR as rootdir."
+
+######################################################
+# Download Disney dataset.
+######################################################
+
+# Ensure dataset is downloaded
+DATA_ROOT="$ROOT_DIR/video-dataset-disney"
+if [ ! -d "$DATA_ROOT" ]; then
+ echo "Downloading Disney dataset to $DATA_ROOT..."
+ huggingface-cli download \
+ --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset \
+ --local-dir "$DATA_ROOT"
+else
+ echo "Dataset already exists at $DATA_ROOT. Skipping download."
+fi
+
+######################################################
+# Run models
+######################################################
+
+# Define models to test
+models=("dummy_ltx_video_lora" "dummy_cogvideox_lora" "dummy_hunyuanvideo_lora")
+for model_script in "${models[@]}"; do
+ echo "Running $model_script test..."
+ bash "$ROOT_DIR/tests/scripts/$model_script.sh"
+done
\ No newline at end of file
diff --git a/docs/finetrainers-src-codebase/tests/test_trackers.py b/docs/finetrainers-src-codebase/tests/test_trackers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9fee180a53669983b3e8709f68a9240131dcf48
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/test_trackers.py
@@ -0,0 +1,26 @@
+import logging
+import os
+import pathlib
+import tempfile
+import unittest
+
+from diffusers.utils.testing_utils import CaptureLogger
+
+from finetrainers.trackers import WandbTracker
+
+
+os.environ["WANDB_MODE"] = "offline"
+
+
+class WandbFastTests(unittest.TestCase):
+ def test_wandb_logdir(self):
+ logger = logging.getLogger("finetrainers")
+
+ with tempfile.TemporaryDirectory() as tempdir, CaptureLogger(logger) as cap_log:
+ tracker = WandbTracker("finetrainers-experiment", log_dir=tempdir, config={})
+ tracker.log({"loss": 0.1}, step=0)
+ tracker.log({"loss": 0.2}, step=1)
+ tracker.finish()
+ self.assertTrue(pathlib.Path(tempdir).exists())
+
+ self.assertTrue("WandB logging enabled" in cap_log.out)
diff --git a/docs/finetrainers-src-codebase/tests/trainer/__init__.py b/docs/finetrainers-src-codebase/tests/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/finetrainers-src-codebase/tests/trainer/test_control_trainer.py b/docs/finetrainers-src-codebase/tests/trainer/test_control_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4325904b81bcc077a16921adb2f47a999e44e62
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/trainer/test_control_trainer.py
@@ -0,0 +1,274 @@
+# torchrun --nnodes=1 --nproc_per_node=1 -m pytest -s tests/trainer/test_control_trainer.py
+
+import json
+import os
+import pathlib
+import tempfile
+import time
+import unittest
+
+import pytest
+from diffusers.utils import export_to_video
+from parameterized import parameterized
+from PIL import Image
+
+from finetrainers import BaseArgs, ControlTrainer, TrainingType, get_logger
+from finetrainers.trainer.control_trainer.config import ControlType
+
+
+os.environ["WANDB_MODE"] = "disabled"
+os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO"
+
+from ..models.cogview4.control_specification import DummyCogView4ControlModelSpecification # noqa
+from ..models.wan.control_specification import DummyWanControlModelSpecification # noqa
+
+
+logger = get_logger()
+
+
+@pytest.fixture(autouse=True)
+def slow_down_tests():
+ yield
+ # Sleep between each test so that process groups are cleaned and resources are released.
+ # Not doing so seems to randomly trigger failures in tests that pass when run individually.
+ # !!!Look into this in future!!!
+ time.sleep(5)
+
+
+class ControlTrainerFastTestsMixin:
+ model_specification_cls = None
+ num_data_files = 4
+ num_frames = 4
+ height = 64
+ width = 64
+
+ def setUp(self):
+ self.tmpdir = tempfile.TemporaryDirectory()
+ self.data_files = []
+ for i in range(self.num_data_files):
+ data_file = pathlib.Path(self.tmpdir.name) / f"{i}.mp4"
+ export_to_video(
+ [Image.new("RGB", (self.width, self.height))] * self.num_frames, data_file.as_posix(), fps=2
+ )
+ self.data_files.append(data_file.as_posix())
+
+ csv_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv"
+ with open(csv_filename.as_posix(), "w") as f:
+ f.write("file_name,caption\n")
+ for i in range(self.num_data_files):
+ prompt = f"A cat ruling the world - {i}"
+ f.write(f'{i}.mp4,"{prompt}"\n')
+
+ dataset_config = {
+ "datasets": [
+ {
+ "data_root": self.tmpdir.name,
+ "dataset_type": "video",
+ "id_token": "TEST",
+ "video_resolution_buckets": [[self.num_frames, self.height, self.width]],
+ "reshape_mode": "bicubic",
+ }
+ ]
+ }
+
+ self.dataset_config_filename = pathlib.Path(self.tmpdir.name) / "dataset_config.json"
+ with open(self.dataset_config_filename.as_posix(), "w") as f:
+ json.dump(dataset_config, f)
+
+ def tearDown(self):
+ self.tmpdir.cleanup()
+
+ def get_base_args(self) -> BaseArgs:
+ args = BaseArgs()
+ args.dataset_config = self.dataset_config_filename.as_posix()
+ args.train_steps = 10
+ args.max_data_samples = 25
+ args.batch_size = 1
+ args.gradient_checkpointing = True
+ args.output_dir = self.tmpdir.name
+ args.checkpointing_steps = 6
+ args.enable_precomputation = False
+ args.precomputation_items = self.num_data_files
+ args.precomputation_dir = os.path.join(self.tmpdir.name, "precomputed")
+ args.compile_scopes = "regional" # This will only be in effect when `compile_modules` is set
+
+ args.control_type = ControlType.CANNY
+ args.train_qk_norm = True
+ args.frame_conditioning_type = "random"
+ args.frame_conditioning_index = None
+ args.frame_conditioning_concatenate_mask = False
+
+ return args
+
+ def get_args(self) -> BaseArgs:
+ raise NotImplementedError("`get_args` must be implemented in the subclass.")
+
+ def _test_training(self, args: BaseArgs):
+ model_specification = self.model_specification_cls()
+ trainer = ControlTrainer(args, model_specification)
+ trainer.run()
+
+
+class ControlTrainerLoRATestsMixin___PTD(ControlTrainerFastTestsMixin):
+ def get_args(self) -> BaseArgs:
+ args = self.get_base_args()
+ args.parallel_backend = "ptd"
+ args.training_type = TrainingType.CONTROL_LORA
+ args.rank = 4
+ args.lora_alpha = 4
+ args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ return args
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_shards = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_shards = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.dp_shards = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.tp_degree = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+
+class ControlTrainerFullFinetuneTestsMixin___PTD(ControlTrainerFastTestsMixin):
+ def get_args(self) -> BaseArgs:
+ args = self.get_base_args()
+ args.parallel_backend = "ptd"
+ args.training_type = TrainingType.CONTROL_FULL_FINETUNE
+ return args
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_shards = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_shards = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.dp_shards = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.tp_degree = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+
+class ControlTrainerCogView4LoRATests___PTD(ControlTrainerLoRATestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyCogView4ControlModelSpecification
+
+
+class ControlTrainerCogView4FullFinetuneTests___PTD(ControlTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyCogView4ControlModelSpecification
+
+
+class ControlTrainerWanLoRATests___PTD(ControlTrainerLoRATestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyWanControlModelSpecification
+
+
+class ControlTrainerWanFullFinetuneTests___PTD(ControlTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyWanControlModelSpecification
diff --git a/docs/finetrainers-src-codebase/tests/trainer/test_sft_trainer.py b/docs/finetrainers-src-codebase/tests/trainer/test_sft_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..96a09ce47275d307354e2cac69899c44139a719a
--- /dev/null
+++ b/docs/finetrainers-src-codebase/tests/trainer/test_sft_trainer.py
@@ -0,0 +1,537 @@
+# torchrun --nnodes=1 --nproc_per_node=1 -m pytest -s tests/trainer/test_sft_trainer.py
+
+import json
+import os
+import pathlib
+import tempfile
+import time
+import unittest
+
+import pytest
+import torch
+from diffusers.utils import export_to_video
+from parameterized import parameterized
+from PIL import Image
+
+from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger
+
+
+os.environ["WANDB_MODE"] = "disabled"
+os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO"
+
+from ..models.cogvideox.base_specification import DummyCogVideoXModelSpecification # noqa
+from ..models.cogview4.base_specification import DummyCogView4ModelSpecification # noqa
+from ..models.flux.base_specification import DummyFluxModelSpecification # noqa
+from ..models.hunyuan_video.base_specification import DummyHunyuanVideoModelSpecification # noqa
+from ..models.ltx_video.base_specification import DummyLTXVideoModelSpecification # noqa
+from ..models.wan.base_specification import DummyWanModelSpecification # noqa
+
+
+logger = get_logger()
+
+
+@pytest.fixture(autouse=True)
+def slow_down_tests():
+ yield
+ # Sleep between each test so that process groups are cleaned and resources are released.
+ # Not doing so seems to randomly trigger failures in tests that pass when run individually.
+ # !!!Look into this in future!!!
+ time.sleep(5)
+
+
+class SFTTrainerFastTestsMixin:
+ model_specification_cls = None
+ num_data_files = 4
+ num_frames = 4
+ height = 64
+ width = 64
+
+ def setUp(self):
+ self.tmpdir = tempfile.TemporaryDirectory()
+ self.data_files = []
+ for i in range(self.num_data_files):
+ data_file = pathlib.Path(self.tmpdir.name) / f"{i}.mp4"
+ export_to_video(
+ [Image.new("RGB", (self.width, self.height))] * self.num_frames, data_file.as_posix(), fps=2
+ )
+ self.data_files.append(data_file.as_posix())
+
+ csv_filename = pathlib.Path(self.tmpdir.name) / "metadata.csv"
+ with open(csv_filename.as_posix(), "w") as f:
+ f.write("file_name,caption\n")
+ for i in range(self.num_data_files):
+ prompt = f"A cat ruling the world - {i}"
+ f.write(f'{i}.mp4,"{prompt}"\n')
+
+ dataset_config = {
+ "datasets": [
+ {
+ "data_root": self.tmpdir.name,
+ "dataset_type": "video",
+ "id_token": "TEST",
+ "video_resolution_buckets": [[self.num_frames, self.height, self.width]],
+ "reshape_mode": "bicubic",
+ }
+ ]
+ }
+
+ self.dataset_config_filename = pathlib.Path(self.tmpdir.name) / "dataset_config.json"
+ with open(self.dataset_config_filename.as_posix(), "w") as f:
+ json.dump(dataset_config, f)
+
+ def tearDown(self):
+ self.tmpdir.cleanup()
+ # For some reason, if the process group is not destroyed, the tests that follow will fail. Just manually
+ # make sure to destroy it here.
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+ time.sleep(3)
+
+ def get_base_args(self) -> BaseArgs:
+ args = BaseArgs()
+ args.dataset_config = self.dataset_config_filename.as_posix()
+ args.train_steps = 10
+ args.max_data_samples = 25
+ args.batch_size = 1
+ args.gradient_checkpointing = True
+ args.output_dir = self.tmpdir.name
+ args.checkpointing_steps = 6
+ args.enable_precomputation = False
+ args.precomputation_items = self.num_data_files
+ args.precomputation_dir = os.path.join(self.tmpdir.name, "precomputed")
+ args.compile_scopes = "regional" # This will only be in effect when `compile_modules` is set
+ # args.attn_provider_training = ["transformer:_native_cudnn"]
+ # args.attn_provider_inference = ["transformer:_native_cudnn"]
+ return args
+
+ def get_args(self) -> BaseArgs:
+ raise NotImplementedError("`get_args` must be implemented in the subclass.")
+
+ def _test_training(self, args: BaseArgs):
+ model_specification = self.model_specification_cls()
+ trainer = SFTTrainer(args, model_specification)
+ trainer.run()
+
+
+# =============== ===============
+
+
+class SFTTrainerLoRATestsMixin___Accelerate(SFTTrainerFastTestsMixin):
+ def get_args(self) -> BaseArgs:
+ args = self.get_base_args()
+ args.parallel_backend = "accelerate"
+ args.training_type = TrainingType.LORA
+ args.rank = 4
+ args.lora_alpha = 4
+ args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ return args
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ args.layerwise_upcasting_modules = ["transformer"]
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+
+class SFTTrainerFullFinetuneTestsMixin___Accelerate(SFTTrainerFastTestsMixin):
+ def get_args(self) -> BaseArgs:
+ args = self.get_base_args()
+ args.parallel_backend = "accelerate"
+ args.training_type = TrainingType.FULL_FINETUNE
+ return args
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+
+class SFTTrainerCogVideoXLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
+ model_specification_cls = DummyCogVideoXModelSpecification
+
+
+class SFTTrainerCogVideoXFullFinetuneTests___Accelerate(
+ SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
+):
+ model_specification_cls = DummyCogVideoXModelSpecification
+
+
+class SFTTrainerCogView4LoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
+ model_specification_cls = DummyCogView4ModelSpecification
+
+
+class SFTTrainerCogView4FullFinetuneTests___Accelerate(
+ SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
+):
+ model_specification_cls = DummyCogView4ModelSpecification
+
+
+class SFTTrainerFluxLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
+ model_specification_cls = DummyFluxModelSpecification
+
+
+class SFTTrainerFluxFullFinetuneTests___Accelerate(SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase):
+ model_specification_cls = DummyFluxModelSpecification
+
+
+class SFTTrainerHunyuanVideoLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
+ model_specification_cls = DummyHunyuanVideoModelSpecification
+
+
+class SFTTrainerHunyuanVideoFullFinetuneTests___Accelerate(
+ SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
+):
+ model_specification_cls = DummyHunyuanVideoModelSpecification
+
+
+class SFTTrainerLTXVideoLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
+ model_specification_cls = DummyLTXVideoModelSpecification
+
+
+class SFTTrainerLTXVideoFullFinetuneTests___Accelerate(
+ SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase
+):
+ model_specification_cls = DummyLTXVideoModelSpecification
+
+
+class SFTTrainerWanLoRATests___Accelerate(SFTTrainerLoRATestsMixin___Accelerate, unittest.TestCase):
+ model_specification_cls = DummyWanModelSpecification
+
+
+class SFTTrainerWanFullFinetuneTests___Accelerate(SFTTrainerFullFinetuneTestsMixin___Accelerate, unittest.TestCase):
+ model_specification_cls = DummyWanModelSpecification
+
+
+# =============== </ACCELERATE> ===============
+
+# =============== <PTD> ===============
+
+
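+# The PTD mixins mirror the Accelerate ones but additionally cover torch.compile, data
+# parallel sharding, tensor parallelism, and (currently skipped) context parallelism.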
+class SFTTrainerLoRATestsMixin___PTD(SFTTrainerFastTestsMixin):
+ def get_args(self) -> BaseArgs:
+ args = self.get_base_args()
+ args.parallel_backend = "ptd"
+ args.training_type = TrainingType.LORA
+ args.rank = 4
+ args.lora_alpha = 4
+ args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ return args
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ args.layerwise_upcasting_modules = ["transformer"]
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ args.compile_modules = ["transformer"]
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___layerwise_upcasting___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ args.layerwise_upcasting_modules = ["transformer"]
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ args.compile_modules = ["transformer"]
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_shards = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_shards = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.dp_shards = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.tp_degree = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @unittest.skip(
+ "TODO: The model specifications for CP with cudnn/flash/efficient backend require the attention head dim to be a multiple with 8. Land math backend first for fast tests and then enable this test."
+ )
+ @parameterized.expand([(True,)])
+ def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.cp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @unittest.skip(
+ "TODO: The model specifications for CP with cudnn/flash/efficient backend require the attention head dim to be a multiple with 8. Land math backend first for fast tests and then enable this test."
+ )
+ @parameterized.expand([(True,)])
+ def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.cp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+
+class SFTTrainerFullFinetuneTestsMixin___PTD(SFTTrainerFastTestsMixin):
+ def get_args(self) -> BaseArgs:
+ args = self.get_base_args()
+ args.parallel_backend = "ptd"
+ args.training_type = TrainingType.FULL_FINETUNE
+ return args
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___compile___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ args.compile_modules = ["transformer"]
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 1
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___compile___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ args.compile_modules = ["transformer"]
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_shards = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(True,)])
+ def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_shards = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.dp_shards = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @parameterized.expand([(False,), (True,)])
+ def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.tp_degree = 2
+ args.batch_size = 2
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @unittest.skip(
+ "TODO: The model specifications for CP with cudnn/flash/efficient backend require the attention head dim to be a multiple with 8. Land math backend first for fast tests and then enable this test."
+ )
+ @parameterized.expand([(True,)])
+ def test___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.cp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+ @unittest.skip(
+ "TODO: The model specifications for CP with cudnn/flash/efficient backend require the attention head dim to be a multiple with 8. Land math backend first for fast tests and then enable this test."
+ )
+ @parameterized.expand([(True,)])
+ def test___dp_degree_2___cp_degree_2___batch_size_1(self, enable_precomputation: bool):
+ args = self.get_args()
+ args.dp_degree = 2
+ args.cp_degree = 2
+ args.batch_size = 1
+ args.enable_precomputation = enable_precomputation
+ self._test_training(args)
+
+
+class SFTTrainerCogVideoXLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyCogVideoXModelSpecification
+
+
+class SFTTrainerCogVideoXFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyCogVideoXModelSpecification
+
+
+class SFTTrainerCogView4LoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyCogView4ModelSpecification
+
+
+class SFTTrainerCogView4FullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyCogView4ModelSpecification
+
+
+class SFTTrainerFluxLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyFluxModelSpecification
+
+
+class SFTTrainerFluxFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyFluxModelSpecification
+
+
+class SFTTrainerHunyuanVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyHunyuanVideoModelSpecification
+
+
+class SFTTrainerHunyuanVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyHunyuanVideoModelSpecification
+
+
+class SFTTrainerLTXVideoLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyLTXVideoModelSpecification
+
+
+class SFTTrainerLTXVideoFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyLTXVideoModelSpecification
+
+
+class SFTTrainerWanLoRATests___PTD(SFTTrainerLoRATestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyWanModelSpecification
+
+
+class SFTTrainerWanFullFinetuneTests___PTD(SFTTrainerFullFinetuneTestsMixin___PTD, unittest.TestCase):
+ model_specification_cls = DummyWanModelSpecification
+
+
+# =============== </PTD> ===============
diff --git a/docs/finetrainers-src-codebase/train.py b/docs/finetrainers-src-codebase/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..183ba1eddebf98c0d268442e9d8d79543b311c08
--- /dev/null
+++ b/docs/finetrainers-src-codebase/train.py
@@ -0,0 +1,86 @@
+import sys
+import traceback
+
+from finetrainers import BaseArgs, ControlTrainer, SFTTrainer, TrainingType, get_logger
+from finetrainers.config import _get_model_specifiction_cls
+from finetrainers.trainer.control_trainer.config import ControlFullRankConfig, ControlLowRankConfig
+from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig
+
+
+logger = get_logger()
+
+
+def main():
+ try:
+ import multiprocessing
+
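+ # "fork" lets child processes inherit the parent's state instead of re-importing
+ # everything under "spawn", which keeps dataloader workers cheap to start.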
+ multiprocessing.set_start_method("fork")
+ except Exception as e:
+ logger.error(
+ f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. '
+ f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n"
+ f"Error: {e}"
+ )
+
+ try:
+ args = BaseArgs()
+
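+ # Flatten sys.argv into individual whitespace-separated tokens before locating --training_type.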
+ argv = [y.strip() for x in sys.argv for y in x.split()]
+ if "--training_type" not in argv:
+ raise ValueError("Training type not provided in command line arguments.")
+ training_type_index = argv.index("--training_type")
+
+ training_type = argv[training_type_index + 1]
+ training_cls = None
+ if training_type == TrainingType.LORA:
+ training_cls = SFTLowRankConfig
+ elif training_type == TrainingType.FULL_FINETUNE:
+ training_cls = SFTFullRankConfig
+ elif training_type == TrainingType.CONTROL_LORA:
+ training_cls = ControlLowRankConfig
+ elif training_type == TrainingType.CONTROL_FULL_FINETUNE:
+ training_cls = ControlFullRankConfig
+ else:
+ raise ValueError(f"Training type {training_type} not supported.")
+
+ args.register_args(training_cls())
+ args = args.parse_args()
+
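+ # Resolve the model specification class for the chosen model and training type, then
+ # instantiate it with the component overrides and dtypes supplied on the command line.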
+ model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type)
+ model_specification = model_specification_cls(
+ pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+ tokenizer_id=args.tokenizer_id,
+ tokenizer_2_id=args.tokenizer_2_id,
+ tokenizer_3_id=args.tokenizer_3_id,
+ text_encoder_id=args.text_encoder_id,
+ text_encoder_2_id=args.text_encoder_2_id,
+ text_encoder_3_id=args.text_encoder_3_id,
+ transformer_id=args.transformer_id,
+ vae_id=args.vae_id,
+ text_encoder_dtype=args.text_encoder_dtype,
+ text_encoder_2_dtype=args.text_encoder_2_dtype,
+ text_encoder_3_dtype=args.text_encoder_3_dtype,
+ transformer_dtype=args.transformer_dtype,
+ vae_dtype=args.vae_dtype,
+ revision=args.revision,
+ cache_dir=args.cache_dir,
+ )
+
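+ # LoRA and full-finetune runs go through SFTTrainer; the control variants use ControlTrainer.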
+ if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]:
+ trainer = SFTTrainer(args, model_specification)
+ elif args.training_type in [TrainingType.CONTROL_LORA, TrainingType.CONTROL_FULL_FINETUNE]:
+ trainer = ControlTrainer(args, model_specification)
+ else:
+ raise ValueError(f"Training type {args.training_type} not supported.")
+
+ trainer.run()
+
+ except KeyboardInterrupt:
+ logger.info("Received keyboard interrupt. Exiting...")
+ except Exception as e:
+ logger.error(f"An error occurred during training: {e}")
+ logger.error(traceback.format_exc())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/finetrainers/documentation_global_README.md b/docs/finetrainers/documentation_global_README.md
deleted file mode 100644
index 052d7d45af231d88beb2f9a3d946e6c593556fa3..0000000000000000000000000000000000000000
--- a/docs/finetrainers/documentation_global_README.md
+++ /dev/null
@@ -1,99 +0,0 @@
-# finetrainers 🧪
-
-FineTrainers is a work-in-progress library to support (accessible) training of video models. Our first priority is to support LoRA training for all popular video models in [Diffusers](https://github.com/huggingface/diffusers), and eventually other methods like controlnets, control-loras, distillation, etc.
-
-`cogvideox-factory` was renamed to `finetrainers`. If you're looking to train CogVideoX or Mochi with the legacy training scripts, please refer to [this](./training/README.md) README instead. Everything in the `training/` directory will be eventually moved and supported under `finetrainers`.
-
-
-## News
-
-- 🔥 **2025-03-03**: Wan T2V support added!
-- 🔥 **2025-03-03**: We have shipped a complete refactor to support multi-backend distributed training, better precomputation handling for big datasets, model specification format (externally usable for training custom models), FSDP & more.
-- 🔥 **2025-02-12**: We have shipped a set of tooling to curate small and high-quality video datasets for fine-tuning. See [video-dataset-scripts](https://github.com/huggingface/video-dataset-scripts) documentation page for details!
-- 🔥 **2025-02-12**: Check out [eisneim/ltx_lora_training_i2v_t2v](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)! It builds off of `finetrainers` to support image to video training for LTX-Video and STG guidance for inference.
-- 🔥 **2025-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB upto specific resolutions.
-- 🔥 **2025-01-13**: Support for T2V full-finetuning added! Thanks to [@ArEnSc](https://github.com/ArEnSc) for taking up the initiative!
-- 🔥 **2025-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
-- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
-- 🔥 **2024-12-18**: Support for T2V LoRA finetuning of [LTX Video](https://huggingface.co/docs/diffusers/main/api/pipelines/ltx_video) added!
-
-## Table of Contents
-
-- [Quickstart](#quickstart)
-- [Support Matrix](#support-matrix)
-- [Featured Projects](#featured-projects)
-- [Acknowledgements](#acknowledgements)
-
-## Quickstart
-
-Clone the repository and install the requirements with `pip install -r requirements.txt`, then install `diffusers` from source with `pip install git+https://github.com/huggingface/diffusers`. The requirements specify `diffusers>=0.32.1`, but it is always recommended to use the `main` branch of Diffusers for the latest features and bugfixes. Note that the `main` branch of `finetrainers` is also the development branch; stable support should be expected only from the release tags.
-
-Check out the latest release tag:
-
-```bash
-git fetch --all --tags
-git checkout tags/v0.0.1
-```
-
-Follow the instructions mentioned in the [README](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1) for the release tag.
-
-#### Using the main branch
-
-To get started quickly with example training scripts on the main development branch, refer to the following:
-- [LTX-Video Pika Effects Crush](./examples/training/sft/ltx_video/crush_smol_lora/)
-- [CogVideoX Pika Effects Crush](./examples/training/sft/cogvideox/crush_smol_lora/)
-- [Wan T2V Pika Effects Crush](./examples/training/sft/wan/crush_smol_lora/)
-
-The following are some simple datasets/HF orgs with good datasets to test training with quickly:
-- [Disney Video Generation Dataset](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
-- [bigdatapw Video Dataset Collection](https://huggingface.co/bigdata-pw)
-- [Finetrainers HF Dataset Collection](https://huggingface.co/finetrainers)
-
-Please check out [`docs/models`](./docs/models/) and [`examples/training`](./examples/training/) to learn more about the supported models for training, along with reproducible example training launch scripts.
-
-> [!IMPORTANT]
-> It is recommended to use PyTorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested. For fully reproducible training, please use the same environment as mentioned in [environment.md](./docs/environment.md).
-
-## Support Matrix
-
-> [!NOTE]
-> The following numbers were obtained from the [release branch](https://github.com/a-r-r-o-w/finetrainers/tree/v0.0.1). The `main` branch is unstable at the moment and may use higher memory.
-
-
-
-| **Model Name** | **Tasks** | **Min. LoRA VRAM*** | **Min. Full Finetuning VRAM^** |
-|:----------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:|
-| [LTX-Video](./docs/models/ltx_video.md) | Text-to-Video | 5 GB | 21 GB |
-| [HunyuanVideo](./docs/models/hunyuan_video.md) | Text-to-Video | 32 GB | OOM |
-| [CogVideoX-5b](./docs/models/cogvideox.md) | Text-to-Video | 18 GB | 53 GB |
-| [Wan](./docs/models/wan.md) | Text-to-Video | TODO | TODO |
-
-
-
-*Measured for training only (no validation) at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights and gradient checkpointing. Pre-computation of conditions and latents may require more memory (but typically under 16 GB).
-^Measured for training only (no validation) at resolution `49x512x768`, with pre-computation, using **BF16** weights and gradient checkpointing.
-
-If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md).
-
-## Featured Projects 🔥
-
-Check out some amazing projects citing `finetrainers`:
-- [Diffusion as Shader](https://github.com/IGL-HKUST/DiffusionAsShader)
-- [SkyworkAI's SkyReels-A1](https://github.com/SkyworkAI/SkyReels-A1)
-- [eisneim's LTX Image-to-Video](https://github.com/eisneim/ltx_lora_training_i2v_t2v/)
-- [wileewang's TransPixar](https://github.com/wileewang/TransPixar)
-- [Feizc's Video-In-Context](https://github.com/feizc/Video-In-Context)
-
-Check out the following UIs built for `finetrainers`:
-- [jbilcke's VideoModelStudio](https://github.com/jbilcke-hf/VideoModelStudio)
-- [neph1's finetrainers-ui](https://github.com/neph1/finetrainers-ui)
-
-## Acknowledgements
-
-* `finetrainers` builds on top of & takes inspiration from great open-source libraries - `transformers`, `accelerate`, `torchtune`, `torchtitan`, `peft`, `diffusers`, `bitsandbytes`, `torchao` and `deepspeed` - to name a few.
-* Some of the design choices of `finetrainers` were inspired by [`SimpleTuner`](https://github.com/bghira/SimpleTuner).
diff --git a/vms/ui/project/tabs/manage_tab.py b/vms/ui/project/tabs/manage_tab.py
index 48a617f4017072a6dd6e94a229070ecd3cabb10f..4f1827f5466672f52efbb489c538ed230494c1be 100644
--- a/vms/ui/project/tabs/manage_tab.py
+++ b/vms/ui/project/tabs/manage_tab.py
@@ -79,7 +79,8 @@ class ManageTab(BaseTab):
self.components["download_output_btn"] = gr.DownloadButton(
"📁 Download output directory (.zip)",
variant="secondary",
- size="lg"
+ size="lg",
+ visible=False
)
with gr.Row():
with gr.Column():