import os
import hashlib
import json
import logging
import math
import datetime
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import diffusers
import torch
import transformers
import wandb
from accelerate.accelerator import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    gather_object,
    set_seed,
    broadcast_object_list,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from safetensors.torch import save_file, load_file

from finetune.constants import LOG_LEVEL, LOG_NAME
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize, I2VFlowDataset
from finetune.datasets.utils import (
    load_images,
    load_prompts,
    load_videos,
    preprocess_image_with_resize,
    preprocess_video_with_resize,
)
from finetune.schemas import Args, Components, State
from finetune.utils import (
    cast_training_params,
    free_memory,
    get_intermediate_ckpt_path,
    get_latest_ckpt_path_to_resume_from,
    get_memory_statistics,
    get_optimizer,
    string_to_filename,
    unload_model,
    unwrap_model,
)
from tqdm import tqdm
import pdb  # noqa: F401  (debugging leftover, kept so existing imports are unchanged)


logger = get_logger(LOG_NAME, LOG_LEVEL)

_DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,  # FP16 is Only Support for CogVideoX-2B
    "bf16": torch.bfloat16,
}


class Trainer:
    """Base trainer for fine-tuning a diffusion transformer (SFT, LoRA or ControlNet) with accelerate.

    Subclasses provide the model-specific pieces: ``load_components``, ``initialize_pipeline``,
    ``collate_fn``, ``encode_video``, ``encode_text``, ``compute_loss`` and ``validation_step``.
    ``fit`` runs the whole pipeline: setup, optional pre-training validation, then training.
    """

    # If set, should be a list of components to unload (refer to `Components`).
    # ``None`` means "unload nothing": every component stays on the GPU during training.
    UNLOAD_LIST: Optional[List[str]] = None

    def __init__(self, args: Args) -> None:
        """Build the training state, load model components and set up accelerate + logging.

        ``args.train_resolution`` is read as ``(frames, height, width)``.
        """
        self.args = args
        self.state = State(
            weight_dtype=self.__get_training_dtype(),
            train_frames=self.args.train_resolution[0],
            train_height=self.args.train_resolution[1],
            train_width=self.args.train_resolution[2],
        )

        self.components: Components = self.load_components()
        self.accelerator: Accelerator = None
        self.dataset: Dataset = None
        self.data_loader: DataLoader = None

        self.optimizer = None
        self.lr_scheduler = None

        self._init_distributed()
        self._init_logging()

        self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None

    def _init_distributed(self):
        """Create the Accelerator and a timestamped run directory, then initialise trackers.

        The run id is generated on the main process and broadcast so that every rank writes
        into the same ``<output_dir>/<run_name>-<timestamp>`` directory. ``self.args.output_dir``
        is rewritten to point at that directory.
        """
        project_dir = Path(self.args.output_dir)
        logging_dir = project_dir / "tmp_logs"
        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        init_process_group_kwargs = InitProcessGroupKwargs(
            backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
        )
        mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
        # Guard against a missing value so `.lower()` cannot raise on None.
        report_to = (
            None
            if not self.args.report_to or self.args.report_to.lower() == "none"
            else self.args.report_to
        )

        accelerator = Accelerator(
            project_config=project_config,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            mixed_precision=mixed_precision,
            log_with=report_to,
            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
        )

        # Only the main process picks the timestamp; all ranks must agree on the directory name.
        run_id = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") if accelerator.is_main_process else ""
        [run_id] = broadcast_object_list([run_id])
        final_out_dir = project_dir / f"{self.args.run_name}-{run_id}"
        final_log_dir = final_out_dir / "logs"

        if accelerator.is_main_process:
            final_log_dir.mkdir(parents=True, exist_ok=True)
        # Collective call: every rank must reach it, so it stays outside the main-process branch.
        accelerator.wait_for_everyone()

        self.args.output_dir = final_out_dir
        accelerator.project_configuration.project_dir = final_out_dir
        accelerator.project_configuration.logging_dir = final_log_dir

        accelerator.init_trackers(
            project_name=self.args.model_name,
            config=vars(self.args),
            init_kwargs={
                "wandb": {
                    "dir": final_log_dir,
                    "name": self.args.run_name,
                }
            },
        )

        # Disable AMP for MPS.
        if torch.backends.mps.is_available():
            accelerator.native_amp = False

        self.accelerator = accelerator

        if self.args.seed is not None:
            set_seed(self.args.seed)

    def _init_logging(self) -> None:
        """Configure Python logging and quiet transformers/diffusers on non-local-main ranks."""
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=LOG_LEVEL,
        )
        if self.accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        logger.info("Initialized Trainer")
        logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)

    def check_setting(self) -> None:
        """Validate ``UNLOAD_LIST`` against the fields of ``Components``.

        Raises:
            ValueError: if ``UNLOAD_LIST`` names a component that does not exist.
        """
        # Check for unload_list
        if self.UNLOAD_LIST is None:
            logger.warning(
                "\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
            )
        else:
            for name in self.UNLOAD_LIST:
                if name not in self.components.model_fields:
                    raise ValueError(f"Invalid component name in unload_list: {name}")

    def prepare_models(self) -> None:
        """Apply optional VAE slicing/tiling and record the transformer config in the state."""
        logger.info("Initializing models")

        if self.components.vae is not None:
            if self.args.enable_slicing:
                self.components.vae.enable_slicing()
            if self.args.enable_tiling:
                self.components.vae.enable_tiling()

        self.state.transformer_config = self.components.transformer.config

    def prepare_dataset(self) -> None:
        """Build the dataset for ``args.model_type``, optionally precompute latents, and build the dataloader.

        For every model type except ``"i2vFlow"`` the dataset is iterated once with batch size 1
        so that it can precompute video latents and prompt embeddings (the dataset receives
        ``trainer=self`` and is expected to call back into the encoders).

        Raises:
            ValueError: for an unknown ``args.model_type``.
        """
        logger.info("Initializing dataset and dataloader")

        dataset_classes = {
            "i2v": I2VDatasetWithResize,
            "t2v": T2VDatasetWithResize,
            "i2vFlow": I2VFlowDataset,
        }
        dataset_cls = dataset_classes.get(self.args.model_type)
        if dataset_cls is None:
            raise ValueError(f"Invalid model type: {self.args.model_type}")

        self.dataset = dataset_cls(
            **(self.args.model_dump()),
            device=self.accelerator.device,
            max_num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            trainer=self,
        )

        # Prepare VAE and text encoder for encoding
        if self.args.training_type == "controlnet":
            # The base transformer is frozen when only the ControlNet is trained.
            self.components.transformer.requires_grad_(False)
        self.components.vae.requires_grad_(False)
        self.components.text_encoder.requires_grad_(False)
        self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
        self.components.text_encoder = self.components.text_encoder.to(
            self.accelerator.device, dtype=self.state.weight_dtype
        )

        if not self.args.model_type == "i2vFlow":
            # Precompute latent for video and prompt embedding
            logger.info("Precomputing latent for video and prompt embedding ...")
            tmp_data_loader = torch.utils.data.DataLoader(
                self.dataset,
                collate_fn=self.collate_fn,
                batch_size=1,
                num_workers=0,
                pin_memory=self.args.pin_memory,
            )
            tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
            # Iterating is enough: the dataset does the encoding/caching as a side effect.
            for _ in tqdm(tmp_data_loader, desc="prepare dataloader"):
                ...
            self.accelerator.wait_for_everyone()
            logger.info("Precomputing latent for video and prompt embedding ... Done")

        # NOTE(review): indentation was lost in the collapsed source; this unload follows upstream
        # (unconditional). Confirm i2vFlow does not need the encoders on-device during training.
        unload_model(self.components.vae)
        unload_model(self.components.text_encoder)
        free_memory()

        self.data_loader = torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            shuffle=True,
        )

    def set_additional_trainable_parameters(self, block_names):
        """Unfreeze extra (non-LoRA) transformer parameters.

        NOTE(review): ``block_names`` is currently ignored; only ``transformer.patch_embed.proj``
        is unfrozen. The LoRA save hook, however, serialises every block in ``block_names``.
        A generic version would call ``requires_grad_(True)`` on
        ``getattr(self.components.transformer, name)`` for each name.
        """
        self.components.transformer.patch_embed.proj.requires_grad_(True)

    def prepare_trainable_parameters(self):
        """Freeze/unfreeze components for the training type, attach LoRA if needed, and move models to device.

        * ``sft``: the whole transformer is trainable.
        * ``controlnet``: the ControlNet is trainable (optionally excluding ``patch_embed.text_proj``).
        * ``lora``: everything is frozen, then a LoRA adapter plus extra blocks are made trainable.

        Raises:
            ValueError: for bf16 on MPS.
        """
        logger.info("Initializing trainable parameters")

        # For mixed precision training we cast all non-trainable weights to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        weight_dtype = self.state.weight_dtype

        if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
            # due to pytorch#99272, MPS does not yet support bfloat16.
            raise ValueError(
                "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
            )

        # For LoRA, we freeze all the parameters
        # For SFT, we train all the parameters in transformer model
        for attr_name, component in vars(self.components).items():
            if hasattr(component, "requires_grad_"):
                if self.args.training_type == "sft" and attr_name == "transformer":
                    component.requires_grad_(True)
                elif self.args.training_type == "controlnet" and attr_name == "controlnet":
                    component.requires_grad_(True)
                    if self.args.notextinflow:
                        component.patch_embed.text_proj.requires_grad_(False)
                else:
                    component.requires_grad_(False)

        if self.args.training_type == "lora":
            transformer_lora_config = LoraConfig(
                r=self.args.rank,
                lora_alpha=self.args.lora_alpha,
                init_lora_weights=True,
                target_modules=self.args.target_modules,
            )
            self.components.transformer.add_adapter(transformer_lora_config)
            self.__prepare_saving_loading_hooks(transformer_lora_config, block_names=self.args.additional_save_blocks)

            # Add trainable blocks. Only meaningful together with the LoRA save hook above;
            # for other training types the transformer's trainability is already decided.
            self.set_additional_trainable_parameters(block_names=self.args.additional_save_blocks)

        # Load components needed for training to GPU and cast them to the specified data type.
        # The transformer is not excluded here; trainable params are cast back to fp32 in
        # `prepare_optimizer`.
        ignore_list = self.UNLOAD_LIST
        self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)

        if self.args.gradient_checkpointing:
            self.components.transformer.enable_gradient_checkpointing()
            if self.args.training_type == "controlnet":
                self.components.controlnet.enable_gradient_checkpointing()

    def prepare_optimizer(self) -> None:
        """Cast trainable params to fp32, then build the optimizer and LR scheduler.

        If ``args.train_steps`` is unset it is derived from ``train_epochs``. When the DeepSpeed
        config defines an optimizer/scheduler, accelerate's dummy objects are used.

        Raises:
            NotImplementedError: for an unsupported ``training_type``.
        """
        logger.info("Initializing optimizer and lr scheduler")

        # Make sure the trainable params are in float32
        if self.args.training_type == "sft" or self.args.training_type == "lora":
            cast_training_params([self.components.transformer], dtype=torch.float32)

            # For LoRA, we only want to train the LoRA weights
            # For SFT, we want to train all the parameters
            trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
            trainable_parameters_name = [
                p[0] for p in filter(lambda p: p[1].requires_grad, self.components.transformer.named_parameters())
            ]
        elif self.args.training_type == "controlnet":
            cast_training_params([self.components.controlnet], dtype=torch.float32)
            trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.controlnet.parameters()))
            trainable_parameters_name = [
                p[0] for p in filter(lambda p: p[1].requires_grad, self.components.controlnet.named_parameters())
            ]
        else:
            raise NotImplementedError("Choose training_type among 'sft', 'lora', 'controlnet'")

        print("-" * 200)
        print(f"Training type: {self.args.training_type}")
        print(f"Trainable parameters: {trainable_parameters_name}")
        print("-" * 200)

        trainable_parameters_with_lr = {
            "params": trainable_parameters,
            "lr": self.args.learning_rate,
        }
        params_to_optimize = [trainable_parameters_with_lr]
        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)

        use_deepspeed_opt = (
            self.accelerator.state.deepspeed_plugin is not None
            and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        optimizer = get_optimizer(
            params_to_optimize=params_to_optimize,
            optimizer_name=self.args.optimizer,
            learning_rate=self.args.learning_rate,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            use_deepspeed=use_deepspeed_opt,
        )

        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
        if self.args.train_steps is None:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
            self.state.overwrote_max_train_steps = True

        use_deepspeed_lr_scheduler = (
            self.accelerator.state.deepspeed_plugin is not None
            and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        total_training_steps = self.args.train_steps * self.accelerator.num_processes
        num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes

        if use_deepspeed_lr_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=total_training_steps,
                num_warmup_steps=num_warmup_steps,
            )
        else:
            lr_scheduler = get_scheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=self.args.lr_num_cycles,
                power=self.args.lr_power,
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def prepare_for_training(self) -> None:
        """Wrap the trained model, optimizer, dataloader and scheduler with ``accelerator.prepare``.

        Recomputes steps-per-epoch afterwards because the prepared (sharded) dataloader can be shorter.
        """
        if self.args.training_type == "sft" or self.args.training_type == "lora":
            self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
                self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
            )
        elif self.args.training_type == "controlnet":
            self.components.controlnet, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
                self.components.controlnet, self.optimizer, self.data_loader, self.lr_scheduler
            )

        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
        if self.state.overwrote_max_train_steps:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
        # Afterwards we recalculate our number of training epochs
        self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
        self.state.num_update_steps_per_epoch = num_update_steps_per_epoch

    def prepare_for_validation(self):
        """Load validation prompts (and optional images/videos) into the state, then validate at step 0."""
        validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)

        if self.args.validation_images is not None:
            validation_images = load_images(self.args.validation_dir / self.args.validation_images)
        else:
            validation_images = [None] * len(validation_prompts)

        if self.args.validation_videos is not None:
            validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
        else:
            validation_videos = [None] * len(validation_prompts)

        self.state.validation_prompts = validation_prompts
        self.state.validation_images = validation_images
        self.state.validation_videos = validation_videos

        self.validate(0)

    def prepare_trackers(self) -> None:
        """Initialise experiment trackers unless ``_init_distributed`` already did so."""
        logger.info("Initializing trackers")

        # `_init_distributed` already called `init_trackers`; accelerate appends a fresh tracker on
        # every call, so initialising again would log every metric twice.
        if self.accelerator.trackers:
            logger.info("Trackers already initialized; skipping re-initialization")
            return

        tracker_name = self.args.tracker_name or "finetrainers-experiment"
        self.accelerator.init_trackers(tracker_name, config=self.args.model_dump())

    def load_state_single_gpu(self, resume_from_checkpoint_path) -> None:
        """Load the ``module`` weights of a DeepSpeed ``mp_rank_00`` checkpoint into the ControlNet.

        Fallback for resuming on a single GPU when ``accelerator.load_state`` fails. Only the
        ControlNet is restored; for other training types the loaded weights are discarded.
        """
        state_dict_path = resume_from_checkpoint_path / "pytorch_model" / "mp_rank_00_model_states.pt"
        state_dict = torch.load(state_dict_path)["module"]
        if self.args.training_type == "controlnet":
            controlnet_ = unwrap_model(self.accelerator, self.components.controlnet)
            controlnet_.load_state_dict(state_dict)

    def train(self) -> None:
        """Run the training loop: optional resume, per-step optimisation, checkpointing and validation."""
        logger.info("Starting training")

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

        self.state.total_batch_size_count = (
            self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
        )
        info = {
            "trainable parameters": self.state.num_trainable_parameters,
            "total samples": len(self.dataset),
            "train epochs": self.args.train_epochs,
            "train steps": self.args.train_steps,
            "batches per device": self.args.batch_size,
            "total batches observed per epoch": len(self.data_loader),
            "train batch size total count": self.state.total_batch_size_count,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

        global_step = 0
        first_epoch = 0
        initial_global_step = 0

        # Potentially load in the weights and states from a previous save
        (
            resume_from_checkpoint_path,
            initial_global_step,
            global_step,
            first_epoch,
        ) = get_latest_ckpt_path_to_resume_from(
            resume_from_checkpoint=self.args.resume_from_checkpoint,
            num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
        )
        if resume_from_checkpoint_path is not None:
            # NOTE: if this raises a ZeRO runtime error on a single GPU, `load_state_single_gpu`
            # is a manual fallback for ControlNet checkpoints.
            self.accelerator.load_state(resume_from_checkpoint_path)

        progress_bar = tqdm(
            range(0, self.args.train_steps),
            initial=initial_global_step,
            desc="Training steps",
            disable=not self.accelerator.is_local_main_process,
        )

        accelerator = self.accelerator
        generator = torch.Generator(device=accelerator.device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        # A resumed run has effectively been validated at its resume step already.
        last_validated_step = -1
        if global_step != 0:
            last_validated_step = global_step

        free_memory()
        for epoch in range(first_epoch, self.args.train_epochs):
            logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")

            if self.args.training_type == "sft" or self.args.training_type == "lora":
                self.components.transformer.train()
                models_to_accumulate = [self.components.transformer]
            elif self.args.training_type == "controlnet":
                self.components.controlnet.train()
                models_to_accumulate = [self.components.controlnet]

            for step, batch in enumerate(self.data_loader):
                logger.debug(f"Starting step {step + 1}")
                logs = {}

                with accelerator.accumulate(models_to_accumulate):
                    # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
                    loss = self.compute_loss(batch)
                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
                            if self.args.training_type == "sft" or self.args.training_type == "lora":
                                grad_norm = self.components.transformer.get_global_grad_norm()
                            elif self.args.training_type == "controlnet":
                                grad_norm = self.components.controlnet.get_global_grad_norm()
                            # In some cases the grad norm may not return a float
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()
                        else:
                            if self.args.training_type == "sft" or self.args.training_type == "lora":
                                param_to_clip = self.components.transformer.parameters()
                            elif self.args.training_type == "controlnet":
                                param_to_clip = self.components.controlnet.parameters()
                            grad_norm = accelerator.clip_grad_norm_(param_to_clip, self.args.max_grad_norm)
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()

                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                    self.__maybe_save_checkpoint(global_step)

                logs["loss"] = loss.detach().item()
                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
                progress_bar.set_postfix(logs)

                # Maybe run validation
                should_run_validation = (
                    self.args.do_validation
                    and global_step % self.args.validation_steps == 0
                    and global_step != 0
                    and global_step != last_validated_step  # prevent duplicate validation
                )
                if should_run_validation:
                    del loss
                    free_memory()
                    self.validate(global_step)
                    last_validated_step = global_step

                accelerator.log(logs, step=global_step)

                if global_step >= self.args.train_steps:
                    break

            memory_statistics = get_memory_statistics()
            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")

            # The inner `break` only leaves the dataloader loop; without this the next epoch would
            # still run (and step the optimizer) past `train_steps`, e.g. after a mid-epoch resume.
            if global_step >= self.args.train_steps:
                break

        accelerator.wait_for_everyone()
        self.__maybe_save_checkpoint(global_step, must_save=True)
        if self.args.do_validation:
            free_memory()
            self.validate(global_step)

        del self.components
        free_memory()
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

        accelerator.end_training()

    def validate(self, step: int) -> None:
        """Generate validation samples, save them under ``<output_dir>/validation_res`` and log to wandb.

        Gradients are disabled for the duration. Models are restored to their training
        device/dtype/mode afterwards.
        """
        logger.info("Starting validation")

        accelerator = self.accelerator
        num_validation_samples = len(self.state.validation_prompts)

        if num_validation_samples == 0:
            logger.warning("No validation samples found. Skipping validation.")
            return

        self.components.transformer.eval()
        if self.args.training_type == "controlnet":
            self.components.controlnet.eval()
        torch.set_grad_enabled(False)

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        ##### Initialize pipeline #####
        pipe = self.initialize_pipeline()

        if self.state.using_deepspeed:
            # Can't using model_cpu_offload in deepspeed,
            # so we need to move all components in pipe to device
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
        else:
            # if not using deepspeed, use model_cpu_offload to further reduce memory usage
            # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
            pipe.enable_model_cpu_offload(device=self.accelerator.device)

            # Convert all model weights to training dtype
            # Note, this will change trainable weights to training dtype, rather than keep them in fp32;
            # they are cast back to fp32 in the clean-up section below.
            pipe = pipe.to(dtype=self.state.weight_dtype)

        #################################

        all_processes_artifacts = []
        for sample_idx in range(num_validation_samples):
            if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
                # Skip current validation on all processes but one
                if sample_idx % accelerator.num_processes != accelerator.process_index:
                    continue

            prompt = self.state.validation_prompts[sample_idx]
            image = self.state.validation_images[sample_idx]
            video = self.state.validation_videos[sample_idx]

            if image is not None:
                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
                # Convert image tensor (C, H, W) to PIL images
                image = image.to(torch.uint8)
                image = image.permute(1, 2, 0).cpu().numpy()
                image = Image.fromarray(image)

            if video is not None:
                video = preprocess_video_with_resize(
                    video, self.state.train_frames, self.state.train_height, self.state.train_width
                )
                # Convert video tensor (F, C, H, W) to list of PIL images
                video = video.round().clamp(0, 255).to(torch.uint8)
                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

            logger.debug(
                f"Validating sample {sample_idx + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                main_process_only=False,
            )
            validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)

            # Under ZeRO-3 every rank takes part in generation, but only the main rank saves.
            if (
                self.state.using_deepspeed
                and self.accelerator.deepspeed_plugin.zero_stage == 3
                and not accelerator.is_main_process
            ):
                continue

            prompt_filename = string_to_filename(prompt)[:25]
            # Calculate hash of reversed prompt as a unique identifier
            reversed_prompt = prompt[::-1]
            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]

            artifacts = {
                "image": {"type": "image", "value": image},
                "video": {"type": "video", "value": video},
            }
            for artifact_idx, (artifact_type, artifact_value) in enumerate(validation_artifacts):
                artifacts.update({f"artifact_{artifact_idx}": {"type": artifact_type, "value": artifact_value}})
            logger.debug(
                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                main_process_only=False,
            )

            validation_path = self.args.output_dir / "validation_res"
            for key, value in list(artifacts.items()):
                artifact_type = value["type"]
                artifact_value = value["value"]
                if artifact_type not in ["image", "video"] or artifact_value is None:
                    continue

                extension = "png" if artifact_type == "image" else "mp4"
                # The artifact key is part of the name: without it the input video and a generated
                # video (or two generated videos) for one prompt share a path and overwrite each other.
                filename = f"validation-{step}-{accelerator.process_index}-{key}-{prompt_filename}-{hash_suffix}.{extension}"
                validation_path.mkdir(parents=True, exist_ok=True)
                filename = str(validation_path / filename)

                if artifact_type == "image":
                    logger.debug(f"Saving image to {filename}")
                    artifact_value.save(filename)
                    artifact_value = wandb.Image(filename)
                elif artifact_type == "video":
                    logger.debug(f"Saving video to {filename}")
                    export_to_video(artifact_value, filename, fps=self.args.gen_fps)
                    artifact_value = wandb.Video(filename, caption=prompt)

                all_processes_artifacts.append(artifact_value)

        all_artifacts = gather_object(all_processes_artifacts)

        if accelerator.is_main_process:
            tracker_key = "validation"
            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
                    tracker.log(
                        {
                            tracker_key: {"images": image_artifacts, "videos": video_artifacts},
                        },
                        step=step,
                    )

        ########## Clean up ##########
        if self.state.using_deepspeed:
            del pipe
            # Unload models except those needed for training
            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
        else:
            pipe.remove_all_hooks()
            del pipe
            # Load models except those not needed for training
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
            self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

            # Change trainable weights back to fp32 to keep with dtype after prepare the model
            cast_training_params([self.components.transformer], dtype=torch.float32)
            if self.args.training_type == "controlnet":
                # The ControlNet was also cast to weight_dtype above; restore its fp32 master weights.
                cast_training_params([self.components.controlnet], dtype=torch.float32)

        free_memory()
        accelerator.wait_for_everyone()
        ################################

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(accelerator.device)

        torch.set_grad_enabled(True)
        self.components.transformer.train()
        if self.args.training_type == "controlnet":
            self.components.controlnet.train()

    def fit(self):
        """Run the full pipeline: checks, setup, optional step-0 validation, then training."""
        self.check_setting()
        self.prepare_models()
        self.prepare_dataset()
        self.prepare_trainable_parameters()
        self.prepare_optimizer()
        self.prepare_for_training()
        self.prepare_trackers()
        # prepare for the first validation before training.
        if self.args.do_validation:
            self.prepare_for_validation()
        self.train()

    def collate_fn(self, examples: List[Dict[str, Any]]):
        raise NotImplementedError

    def load_components(self) -> Components:
        raise NotImplementedError

    def initialize_pipeline(self) -> DiffusionPipeline:
        raise NotImplementedError

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        # shape of input video: [B, C, F, H, W], where B = 1
        # shape of output video: [B, C', F', H', W'], where B = 1
        raise NotImplementedError

    def encode_text(self, text: str) -> torch.Tensor:
        # shape of output text: [batch size, sequence length, embedding dimension]
        raise NotImplementedError

    def compute_loss(self, batch) -> torch.Tensor:
        raise NotImplementedError

    def validation_step(
        self, eval_data: Dict[str, Any], pipe: DiffusionPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        """Generate artifacts for one validation sample.

        ``validate`` calls this with ``{"prompt", "image", "video"}`` and the pipeline, and expects
        a list of ``(artifact_type, value)`` pairs where ``artifact_type`` is ``"image"`` or ``"video"``.
        """
        raise NotImplementedError

    def __get_training_dtype(self) -> torch.dtype:
        """Map ``args.mixed_precision`` ("no" / "fp16" / "bf16") to a torch dtype.

        Raises:
            ValueError: for any other value.
        """
        precision_to_key = {"no": "fp32", "fp16": "fp16", "bf16": "bf16"}
        key = precision_to_key.get(self.args.mixed_precision)
        if key is None:
            raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
        return _DTYPE_MAP[key]

    def __move_components_to_device(self, dtype, ignore_list: Optional[List[str]] = None):
        """Move every ``.to``-capable component except those in ``ignore_list`` to the accelerator device.

        ``ignore_list=None`` (e.g. an unset ``UNLOAD_LIST``) means nothing is ignored.
        """
        ignore_set = set(ignore_list or ())
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name not in ignore_set:
                    setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))

    def __move_components_to_cpu(self, unload_list: Optional[List[str]] = None):
        """Move the components named in ``unload_list`` to CPU; ``None`` means nothing is moved."""
        unload_set = set(unload_list or ())
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name in unload_set:
                    setattr(self.components, name, component.to("cpu"))

    def __prepare_saving_loading_hooks(self, transformer_lora_config, block_names: Optional[List[str]] = None):
        """Register accelerate save/load hooks for LoRA weights plus extra transformer blocks.

        Saving writes the LoRA layers via ``pipeline_cls.save_lora_weights`` and, if ``block_names``
        is non-empty, those blocks' state dicts to ``selected_blocks.safetensors``. Loading reverses both.
        """
        # Copy so later mutation of the caller's list cannot change what the hooks save.
        block_names = list(block_names or [])

        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if self.accelerator.is_main_process:
                transformer_lora_layers_to_save = None
                tensor_dict = {}

                for model in models:
                    if isinstance(
                        unwrap_model(self.accelerator, model),
                        type(unwrap_model(self.accelerator, self.components.transformer)),
                    ):
                        model = unwrap_model(self.accelerator, model)
                        # 1) Set Lora weight
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                        # 2) Set Other weight designated by block_names
                        if len(block_names) != 0:
                            tensor_dict = {}
                            for block_name in block_names:
                                if hasattr(model, block_name):
                                    block = getattr(model, block_name)
                                    for k, v in block.state_dict().items():
                                        # safetensors refuses non-contiguous tensors
                                        tensor_dict[f"{block_name}.{k}"] = v.detach().contiguous()
                                else:
                                    raise ValueError(f"Model has no attribute '{block_name}'")
                    else:
                        raise ValueError(f"Unexpected save model: {model.__class__}")

                    # make sure to pop weight so that corresponding model is not saved again
                    if weights:
                        weights.pop()

                # 1) Save Lora weight
                self.components.pipeline_cls.save_lora_weights(
                    output_dir,
                    transformer_lora_layers=transformer_lora_layers_to_save,
                )
                # 2) Save Other weight
                if len(block_names) != 0:
                    save_path = os.path.join(output_dir, "selected_blocks.safetensors")
                    save_file(tensor_dict, save_path)

        def load_model_hook(models, input_dir):
            if not self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                while len(models) > 0:
                    model = models.pop()
                    if isinstance(
                        unwrap_model(self.accelerator, model),
                        type(unwrap_model(self.accelerator, self.components.transformer)),
                    ):
                        transformer_ = unwrap_model(self.accelerator, model)
                    else:
                        raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
            else:
                transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
                    self.args.model_path, subfolder="transformer"
                )
                transformer_.add_adapter(transformer_lora_config)

            # 1) Load Lora weight
            lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
            transformer_state_dict = {
                f'{k.replace("transformer.", "")}': v
                for k, v in lora_state_dict.items()
                if k.startswith("transformer.")
            }
            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )

            # 2) Load Other weight
            load_path = os.path.join(input_dir, "selected_blocks.safetensors")
            if os.path.exists(load_path):
                tensor_dict = load_file(load_path)
                # Regroup flat "block.param" keys into one state dict per top-level block.
                block_state_dicts = {}
                for k, v in tensor_dict.items():
                    block_name, param_name = k.split(".", 1)
                    block_state_dicts.setdefault(block_name, {})[param_name] = v
                for block_name, state_dict in block_state_dicts.items():
                    if hasattr(transformer_, block_name):
                        getattr(transformer_, block_name).load_state_dict(state_dict)
                    else:
                        raise ValueError(f"Transformer has no attribute '{block_name}'")

            # 3) Set optimizer state for desired device/dtype
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device=self.accelerator.device, dtype=torch.float32)

        self.accelerator.register_save_state_pre_hook(save_model_hook)
        self.accelerator.register_load_state_pre_hook(load_model_hook)

    def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
        """Save accelerate state every ``checkpointing_steps`` steps (or unconditionally if ``must_save``).

        Under DeepSpeed every rank takes part in saving; otherwise only the main process saves.
        """
        if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
            if must_save or global_step % self.args.checkpointing_steps == 0:
                # for training
                save_path = get_intermediate_ckpt_path(
                    checkpointing_limit=self.args.checkpointing_limit,
                    step=global_step,
                    output_dir=self.args.output_dir,
                )
                self.accelerator.save_state(save_path, safe_serialization=True)