import os
import hashlib
import json
import logging
import math
import datetime
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, List, Tuple

import diffusers
import torch
import transformers
import wandb
from accelerate.accelerator import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    gather_object,
    set_seed,
    broadcast_object_list,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from safetensors.torch import save_file, load_file
from finetune.constants import LOG_LEVEL, LOG_NAME
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize, I2VFlowDataset
from finetune.datasets.utils import (
    load_images,
    load_prompts,
    load_videos,
    preprocess_image_with_resize,
    preprocess_video_with_resize,
)
from finetune.schemas import Args, Components, State
from finetune.utils import (
    cast_training_params,
    free_memory,
    get_intermediate_ckpt_path,
    get_latest_ckpt_path_to_resume_from,
    get_memory_statistics,
    get_optimizer,
    string_to_filename,
    unload_model,
    unwrap_model,
)

logger = get_logger(LOG_NAME, LOG_LEVEL)

_DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,  # fp16 is only supported for CogVideoX-2B
    "bf16": torch.bfloat16,
}
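# For example, `args.mixed_precision == "bf16"` maps to torch.bfloat16 via
# `__get_training_dtype` below; non-trainable weights are kept in that dtype,
# while trainable parameters are cast back to fp32 in `prepare_optimizer`.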


class Trainer:
    # If set, should be a list of component names to unload from GPU during
    # training (names must match fields of `Components`).
    UNLOAD_LIST: List[str] | None = None
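    # Example (hypothetical subclass): keep only the transformer resident on the
    # GPU during training once latents and embeddings have been precomputed:
    #
    #     class MyTrainer(Trainer):
    #         UNLOAD_LIST = ["vae", "text_encoder"]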

    def __init__(self, args: Args) -> None:
        self.args = args
        self.state = State(
            weight_dtype=self.__get_training_dtype(),
            train_frames=self.args.train_resolution[0],
            train_height=self.args.train_resolution[1],
            train_width=self.args.train_resolution[2],
        )

        self.components: Components = self.load_components()
        self.accelerator: Accelerator = None
        self.dataset: Dataset = None
        self.data_loader: DataLoader = None

        self.optimizer = None
        self.lr_scheduler = None

        self._init_distributed()
        self._init_logging()

        self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None

    def _init_distributed(self):
        project_dir = Path(self.args.output_dir)
        logging_dir = project_dir / "tmp_logs"
        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        init_process_group_kwargs = InitProcessGroupKwargs(
            backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
        )
        mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
        report_to = None if self.args.report_to.lower() == "none" else self.args.report_to

        accelerator = Accelerator(
            project_config=project_config,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            mixed_precision=mixed_precision,
            log_with=report_to,
            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
        )

        # Generate the run id on the main process and broadcast it so all ranks
        # agree on the same output directory.
        run_id = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") if accelerator.is_main_process else ""
        [run_id] = broadcast_object_list([run_id])
        final_out_dir = project_dir / f"{self.args.run_name}-{run_id}"
        final_log_dir = final_out_dir / "logs"
        if accelerator.is_main_process:
            final_log_dir.mkdir(parents=True, exist_ok=True)
        accelerator.wait_for_everyone()

        self.args.output_dir = final_out_dir
        accelerator.project_configuration.project_dir = final_out_dir
        accelerator.project_configuration.logging_dir = final_log_dir

        accelerator.init_trackers(
            project_name=self.args.model_name,
            config=vars(self.args),
            init_kwargs={
                "wandb": {
                    "dir": final_log_dir,
                    "name": self.args.run_name,
                }
            },
        )

        # Disable AMP for MPS.
        if torch.backends.mps.is_available():
            accelerator.native_amp = False

        self.accelerator = accelerator

        if self.args.seed is not None:
            set_seed(self.args.seed)

    def _init_logging(self) -> None:
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=LOG_LEVEL,
        )
        if self.accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        logger.info("Initialized Trainer")
        logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)

    def check_setting(self) -> None:
        # Check for unload_list
        if self.UNLOAD_LIST is None:
            logger.warning(
                "\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
            )
        else:
            for name in self.UNLOAD_LIST:
                if name not in self.components.model_fields:
                    raise ValueError(f"Invalid component name in unload_list: {name}")

    def prepare_models(self) -> None:
        logger.info("Initializing models")

        if self.components.vae is not None:
            if self.args.enable_slicing:
                self.components.vae.enable_slicing()
            if self.args.enable_tiling:
                self.components.vae.enable_tiling()

        self.state.transformer_config = self.components.transformer.config

    def prepare_dataset(self) -> None:
        logger.info("Initializing dataset and dataloader")

        if self.args.model_type == "i2v":
            self.dataset = I2VDatasetWithResize(
                **(self.args.model_dump()),
                device=self.accelerator.device,
                max_num_frames=self.state.train_frames,
                height=self.state.train_height,
                width=self.state.train_width,
                trainer=self,
            )
        elif self.args.model_type == "t2v":
            self.dataset = T2VDatasetWithResize(
                **(self.args.model_dump()),
                device=self.accelerator.device,
                max_num_frames=self.state.train_frames,
                height=self.state.train_height,
                width=self.state.train_width,
                trainer=self,
            )
        elif self.args.model_type == "i2vFlow":
            self.dataset = I2VFlowDataset(
                **(self.args.model_dump()),
                device=self.accelerator.device,
                max_num_frames=self.state.train_frames,
                height=self.state.train_height,
                width=self.state.train_width,
                trainer=self,
            )
        else:
            raise ValueError(f"Invalid model type: {self.args.model_type}")

        # Prepare VAE and text encoder for encoding
        if self.args.training_type == "controlnet":
            self.components.transformer.requires_grad_(False)
        self.components.vae.requires_grad_(False)
        self.components.text_encoder.requires_grad_(False)
        self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
        self.components.text_encoder = self.components.text_encoder.to(
            self.accelerator.device, dtype=self.state.weight_dtype
        )

        if self.args.model_type != "i2vFlow":
            # Precompute video latents and prompt embeddings: iterating over the
            # dataset once (batch_size=1) triggers and caches the encoding work.
            logger.info("Precomputing video latents and prompt embeddings ...")
            tmp_data_loader = torch.utils.data.DataLoader(
                self.dataset,
                collate_fn=self.collate_fn,
                batch_size=1,
                num_workers=0,
                pin_memory=self.args.pin_memory,
            )
            tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
            for _ in tqdm(tmp_data_loader, desc="prepare dataloader"):
                ...
            self.accelerator.wait_for_everyone()
            logger.info("Precomputing video latents and prompt embeddings ... Done")

        unload_model(self.components.vae)
        unload_model(self.components.text_encoder)
        free_memory()

        self.data_loader = torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            shuffle=True,
        )

    def set_additional_trainable_parameters(self, block_names):
        # NOTE: generic unfreezing of arbitrary blocks via `block_names` is not
        # implemented yet; for now only the patch-embedding projection is trained.
        self.components.transformer.patch_embed.proj.requires_grad_(True)

    def prepare_trainable_parameters(self):
        logger.info("Initializing trainable parameters")

        # For mixed precision training we cast all non-trainable weights to half-precision,
        # as these weights are only used for inference; keeping them in full precision is not required.
        weight_dtype = self.state.weight_dtype

        if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
            # due to pytorch#99272, MPS does not yet support bfloat16.
            raise ValueError(
                "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
            )

        # For LoRA, we freeze all the parameters.
        # For SFT, we train all the parameters in the transformer model.
        for attr_name, component in vars(self.components).items():
            if hasattr(component, "requires_grad_"):
                if self.args.training_type == "sft" and attr_name == "transformer":
                    component.requires_grad_(True)
                elif self.args.training_type == "controlnet" and attr_name == "controlnet":
                    component.requires_grad_(True)
                    if self.args.notextinflow:
                        component.patch_embed.text_proj.requires_grad_(False)
                else:
                    component.requires_grad_(False)

        if self.args.training_type == "lora":
            transformer_lora_config = LoraConfig(
                r=self.args.rank,
                lora_alpha=self.args.lora_alpha,
                init_lora_weights=True,
                target_modules=self.args.target_modules,
            )
            self.components.transformer.add_adapter(transformer_lora_config)
            self.__prepare_saving_loading_hooks(transformer_lora_config, block_names=self.args.additional_save_blocks)
            # Add trainable blocks
            self.set_additional_trainable_parameters(block_names=self.args.additional_save_blocks)

        # Load components needed for training to GPU (except those in UNLOAD_LIST),
        # and cast them to the specified data type
        ignore_list = self.UNLOAD_LIST
        self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)

        if self.args.gradient_checkpointing:
            self.components.transformer.enable_gradient_checkpointing()
            if self.args.training_type == "controlnet":
                self.components.controlnet.enable_gradient_checkpointing()

    def prepare_optimizer(self) -> None:
        logger.info("Initializing optimizer and lr scheduler")

        # Make sure the trainable params are in float32
        if self.args.training_type in ("sft", "lora"):
            cast_training_params([self.components.transformer], dtype=torch.float32)
            # For LoRA, we only want to train the LoRA weights
            # For SFT, we want to train all the parameters
            trainable_parameters = [p for p in self.components.transformer.parameters() if p.requires_grad]
            trainable_parameters_name = [
                name for name, p in self.components.transformer.named_parameters() if p.requires_grad
            ]
        elif self.args.training_type == "controlnet":
            cast_training_params([self.components.controlnet], dtype=torch.float32)
            trainable_parameters = [p for p in self.components.controlnet.parameters() if p.requires_grad]
            trainable_parameters_name = [
                name for name, p in self.components.controlnet.named_parameters() if p.requires_grad
            ]
        else:
            raise NotImplementedError("Choose training_type among 'sft', 'lora', 'controlnet'")

        logger.info(f"Training type: {self.args.training_type}")
        logger.info(f"Trainable parameters: {trainable_parameters_name}")

        trainable_parameters_with_lr = {
            "params": trainable_parameters,
            "lr": self.args.learning_rate,
        }
        params_to_optimize = [trainable_parameters_with_lr]
        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)

        use_deepspeed_opt = (
            self.accelerator.state.deepspeed_plugin is not None
            and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        optimizer = get_optimizer(
            params_to_optimize=params_to_optimize,
            optimizer_name=self.args.optimizer,
            learning_rate=self.args.learning_rate,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            use_deepspeed=use_deepspeed_opt,
        )

        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
        if self.args.train_steps is None:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
            self.state.overwrote_max_train_steps = True

        use_deepspeed_lr_scheduler = (
            self.accelerator.state.deepspeed_plugin is not None
            and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        total_training_steps = self.args.train_steps * self.accelerator.num_processes
        num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes

        if use_deepspeed_lr_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=total_training_steps,
                num_warmup_steps=num_warmup_steps,
            )
        else:
            lr_scheduler = get_scheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=self.args.lr_num_cycles,
                power=self.args.lr_power,
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def prepare_for_training(self) -> None:
        if self.args.training_type in ("sft", "lora"):
            self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
                self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
            )
        elif self.args.training_type == "controlnet":
            self.components.controlnet, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
                self.components.controlnet, self.optimizer, self.data_loader, self.lr_scheduler
            )

        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
        if self.state.overwrote_max_train_steps:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
        # Afterwards we recalculate our number of training epochs
        self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
        self.state.num_update_steps_per_epoch = num_update_steps_per_epoch

    def prepare_for_validation(self):
        validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)

        if self.args.validation_images is not None:
            validation_images = load_images(self.args.validation_dir / self.args.validation_images)
        else:
            validation_images = [None] * len(validation_prompts)

        if self.args.validation_videos is not None:
            validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
        else:
            validation_videos = [None] * len(validation_prompts)

        self.state.validation_prompts = validation_prompts
        self.state.validation_images = validation_images
        self.state.validation_videos = validation_videos

        # Run an initial validation pass before training starts.
        self.validate(0)

    def prepare_trackers(self) -> None:
        logger.info("Initializing trackers")
        tracker_name = self.args.tracker_name or "finetrainers-experiment"
        self.accelerator.init_trackers(tracker_name, config=self.args.model_dump())

    def load_state_single_gpu(self, resume_from_checkpoint_path) -> None:
        state_dict_path = resume_from_checkpoint_path / "pytorch_model" / "mp_rank_00_model_states.pt"
        state_dict = torch.load(state_dict_path)["module"]
        if self.args.training_type == "controlnet":
            controlnet_ = unwrap_model(self.accelerator, self.components.controlnet)
            controlnet_.load_state_dict(state_dict)

    def train(self) -> None:
        logger.info("Starting training")

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

        self.state.total_batch_size_count = (
            self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
        )
        info = {
            "trainable parameters": self.state.num_trainable_parameters,
            "total samples": len(self.dataset),
            "train epochs": self.args.train_epochs,
            "train steps": self.args.train_steps,
            "batches per device": self.args.batch_size,
            "total batches observed per epoch": len(self.data_loader),
            "train batch size total count": self.state.total_batch_size_count,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

        global_step = 0
        first_epoch = 0
        initial_global_step = 0

        # Potentially load in the weights and states from a previous save
        (
            resume_from_checkpoint_path,
            initial_global_step,
            global_step,
            first_epoch,
        ) = get_latest_ckpt_path_to_resume_from(
            resume_from_checkpoint=self.args.resume_from_checkpoint,
            num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
        )
        if resume_from_checkpoint_path is not None:
            self.accelerator.load_state(resume_from_checkpoint_path)

        progress_bar = tqdm(
            range(0, self.args.train_steps),
            initial=initial_global_step,
            desc="Training steps",
            disable=not self.accelerator.is_local_main_process,
        )

        accelerator = self.accelerator
        generator = torch.Generator(device=accelerator.device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        last_validated_step = -1
        if global_step != 0:
            last_validated_step = global_step

        free_memory()
        for epoch in range(first_epoch, self.args.train_epochs):
            logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")

            if self.args.training_type in ("sft", "lora"):
                self.components.transformer.train()
                models_to_accumulate = [self.components.transformer]
            elif self.args.training_type == "controlnet":
                self.components.controlnet.train()
                models_to_accumulate = [self.components.controlnet]

            for step, batch in enumerate(self.data_loader):
                logger.debug(f"Starting step {step + 1}")
                logs = {}

                with accelerator.accumulate(models_to_accumulate):
                    # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
                    loss = self.compute_loss(batch)
                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
                            if self.args.training_type in ("sft", "lora"):
                                grad_norm = self.components.transformer.get_global_grad_norm()
                            elif self.args.training_type == "controlnet":
                                grad_norm = self.components.controlnet.get_global_grad_norm()
                            # In some cases the grad norm may not return a float
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()
                        else:
                            if self.args.training_type in ("sft", "lora"):
                                param_to_clip = self.components.transformer.parameters()
                            elif self.args.training_type == "controlnet":
                                param_to_clip = self.components.controlnet.parameters()
                            grad_norm = accelerator.clip_grad_norm_(param_to_clip, self.args.max_grad_norm)
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()
                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                    self.__maybe_save_checkpoint(global_step)

                logs["loss"] = loss.detach().item()
                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
                progress_bar.set_postfix(logs)

                # Maybe run validation
                should_run_validation = (
                    self.args.do_validation
                    and global_step % self.args.validation_steps == 0
                    and global_step != 0
                    and global_step != last_validated_step  # prevent duplicate validation
                )
                if should_run_validation:
                    del loss
                    free_memory()
                    self.validate(global_step)
                    last_validated_step = global_step

                accelerator.log(logs, step=global_step)

                if global_step >= self.args.train_steps:
                    break

            memory_statistics = get_memory_statistics()
            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")

        accelerator.wait_for_everyone()
        self.__maybe_save_checkpoint(global_step, must_save=True)
        if self.args.do_validation:
            free_memory()
            self.validate(global_step)

        del self.components
        free_memory()
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
        accelerator.end_training()

    def validate(self, step: int) -> None:
        logger.info("Starting validation")

        accelerator = self.accelerator
        num_validation_samples = len(self.state.validation_prompts)
        if num_validation_samples == 0:
            logger.warning("No validation samples found. Skipping validation.")
            return

        self.components.transformer.eval()
        torch.set_grad_enabled(False)

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        ##### Initialize pipeline #####
        pipe = self.initialize_pipeline()

        if self.state.using_deepspeed:
            # model_cpu_offload cannot be used with DeepSpeed,
            # so we move all pipeline components to the device instead.
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
        else:
            # If not using DeepSpeed, use model_cpu_offload to reduce memory usage
            # (pipe.enable_sequential_cpu_offload() would reduce it even further).
            pipe.enable_model_cpu_offload(device=self.accelerator.device)

            # Convert all model weights to the training dtype.
            # Note: this also casts the LoRA weights in self.components.transformer
            # to the training dtype, rather than keeping them in fp32.
            pipe = pipe.to(dtype=self.state.weight_dtype)
        #################################

        all_processes_artifacts = []
        for i in range(num_validation_samples):
            if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
                # Under ZeRO stages other than 3, shard the validation samples
                # across processes round-robin.
                if i % accelerator.num_processes != accelerator.process_index:
                    continue

            prompt = self.state.validation_prompts[i]
            image = self.state.validation_images[i]
            video = self.state.validation_videos[i]

            if image is not None:
                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
                # Convert image tensor (C, H, W) to PIL image
                image = image.to(torch.uint8)
                image = image.permute(1, 2, 0).cpu().numpy()
                image = Image.fromarray(image)

            if video is not None:
                video = preprocess_video_with_resize(
                    video, self.state.train_frames, self.state.train_height, self.state.train_width
                )
                # Convert video tensor (F, C, H, W) to list of PIL images
                video = video.round().clamp(0, 255).to(torch.uint8)
                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

            logger.debug(
                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                main_process_only=False,
            )
            validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)

            if (
                self.state.using_deepspeed
                and self.accelerator.deepspeed_plugin.zero_stage == 3
                and not accelerator.is_main_process
            ):
                continue

            prompt_filename = string_to_filename(prompt)[:25]
            # Calculate hash of reversed prompt as a unique identifier
            reversed_prompt = prompt[::-1]
            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]

            artifacts = {
                "image": {"type": "image", "value": image},
                "video": {"type": "video", "value": video},
            }
            # Use a separate index so the outer sample index `i` is not shadowed.
            for idx, (artifact_type, artifact_value) in enumerate(validation_artifacts):
                artifacts.update({f"artifact_{idx}": {"type": artifact_type, "value": artifact_value}})
            logger.debug(
                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                main_process_only=False,
            )

            for key, value in list(artifacts.items()):
                artifact_type = value["type"]
                artifact_value = value["value"]
                if artifact_type not in ["image", "video"] or artifact_value is None:
                    continue

                extension = "png" if artifact_type == "image" else "mp4"
                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
                validation_path = self.args.output_dir / "validation_res"
                validation_path.mkdir(parents=True, exist_ok=True)
                filename = str(validation_path / filename)

                if artifact_type == "image":
                    logger.debug(f"Saving image to {filename}")
                    artifact_value.save(filename)
                    artifact_value = wandb.Image(filename)
                elif artifact_type == "video":
                    logger.debug(f"Saving video to {filename}")
                    export_to_video(artifact_value, filename, fps=self.args.gen_fps)
                    artifact_value = wandb.Video(filename, caption=prompt)

                all_processes_artifacts.append(artifact_value)

        all_artifacts = gather_object(all_processes_artifacts)

        if accelerator.is_main_process:
            tracker_key = "validation"
            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
                    tracker.log(
                        {
                            tracker_key: {"images": image_artifacts, "videos": video_artifacts},
                        },
                        step=step,
                    )

        ########## Clean up ##########
        if self.state.using_deepspeed:
            del pipe
            # Unload models except those needed for training
            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
        else:
            pipe.remove_all_hooks()
            del pipe
            # Move the models needed for training back to the device (skipping those in UNLOAD_LIST)
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
            self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

            # Cast trainable weights back to fp32 to match the dtype set when the model was prepared
            cast_training_params([self.components.transformer], dtype=torch.float32)

        free_memory()
        accelerator.wait_for_everyone()
        ################################

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(accelerator.device)

        torch.set_grad_enabled(True)
        self.components.transformer.train()

    def fit(self):
        self.check_setting()
        self.prepare_models()
        self.prepare_dataset()
        self.prepare_trainable_parameters()
        self.prepare_optimizer()
        self.prepare_for_training()
        self.prepare_trackers()  # trackers must be ready before the initial validation below
        if self.args.do_validation:
            self.prepare_for_validation()
        self.train()

    def collate_fn(self, examples: List[Dict[str, Any]]):
        raise NotImplementedError

    def load_components(self) -> Components:
        raise NotImplementedError

    def initialize_pipeline(self) -> DiffusionPipeline:
        raise NotImplementedError

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        # shape of input video: [B, C, F, H, W], where B = 1
        # shape of output video: [B, C', F', H', W'], where B = 1
        raise NotImplementedError

    def encode_text(self, text: str) -> torch.Tensor:
        # shape of output text: [batch size, sequence length, embedding dimension]
        raise NotImplementedError

    def compute_loss(self, batch) -> torch.Tensor:
        raise NotImplementedError

    def validation_step(
        self, eval_data: Dict[str, Any], pipe: DiffusionPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        # Called from `validate` with {"prompt": ..., "image": ..., "video": ...}
        # and the inference pipeline; must return (artifact_type, artifact_value) pairs.
        raise NotImplementedError
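
    # A minimal subclass sketch (hypothetical; the component wiring and pipeline
    # are illustrative, not a real model setup):
    #
    #     class MyT2VTrainer(Trainer):
    #         UNLOAD_LIST = ["vae", "text_encoder"]
    #
    #         def load_components(self) -> Components:
    #             ...  # build tokenizer / text_encoder / vae / transformer / scheduler
    #
    #         def initialize_pipeline(self) -> DiffusionPipeline:
    #             ...  # assemble an inference pipeline from self.components
    #
    #         def compute_loss(self, batch) -> torch.Tensor:
    #             ...  # add noise to latents, run the transformer, return the loss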

    def __get_training_dtype(self) -> torch.dtype:
        if self.args.mixed_precision == "no":
            return _DTYPE_MAP["fp32"]
        elif self.args.mixed_precision == "fp16":
            return _DTYPE_MAP["fp16"]
        elif self.args.mixed_precision == "bf16":
            return _DTYPE_MAP["bf16"]
        else:
            raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")

    def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
        # `or []` guards against trainers that leave UNLOAD_LIST as None
        ignore_list = set(ignore_list or [])
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name not in ignore_list:
                    setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))

    def __move_components_to_cpu(self, unload_list: List[str] = []):
        # `or []` guards against trainers that leave UNLOAD_LIST as None
        unload_list = set(unload_list or [])
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name in unload_list:
                    setattr(self.components, name, component.to("cpu"))

    def __prepare_saving_loading_hooks(self, transformer_lora_config, block_names=[]):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if self.accelerator.is_main_process:
                transformer_lora_layers_to_save = None
                tensor_dict = {}  # initialized here so the save below never sees it unbound

                for model in models:
                    if isinstance(
                        unwrap_model(self.accelerator, model),
                        type(unwrap_model(self.accelerator, self.components.transformer)),
                    ):
                        model = unwrap_model(self.accelerator, model)
                        # 1) Collect LoRA weights
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                        # 2) Collect other weights designated by block_names
                        if len(block_names) != 0:
                            for block_name in block_names:
                                if hasattr(model, block_name):
                                    block = getattr(model, block_name)
                                    for k, v in block.state_dict().items():
                                        tensor_dict[f"{block_name}.{k}"] = v
                                else:
                                    raise ValueError(f"Model has no attribute '{block_name}'")
                    else:
                        raise ValueError(f"Unexpected save model: {model.__class__}")

                    # make sure to pop weight so that corresponding model is not saved again
                    if weights:
                        weights.pop()

                # 1) Save LoRA weights
                self.components.pipeline_cls.save_lora_weights(
                    output_dir,
                    transformer_lora_layers=transformer_lora_layers_to_save,
                )
                # 2) Save other weights
                if len(block_names) != 0:
                    save_path = os.path.join(output_dir, "selected_blocks.safetensors")
                    save_file(tensor_dict, save_path)
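
        # Per checkpoint directory, the hook above writes the pipeline's LoRA
        # weights (via `save_lora_weights`) plus `selected_blocks.safetensors`
        # when additional blocks are being trained.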
        def load_model_hook(models, input_dir):
            if not self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                while len(models) > 0:
                    model = models.pop()
                    if isinstance(
                        unwrap_model(self.accelerator, model),
                        type(unwrap_model(self.accelerator, self.components.transformer)),
                    ):
                        transformer_ = unwrap_model(self.accelerator, model)
                    else:
                        raise ValueError(f"Unexpected load model: {unwrap_model(self.accelerator, model).__class__}")
            else:
                transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
                    self.args.model_path, subfolder="transformer"
                )
                transformer_.add_adapter(transformer_lora_config)

            # 1) Load LoRA weights
            lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
            transformer_state_dict = {
                k.replace("transformer.", ""): v
                for k, v in lora_state_dict.items()
                if k.startswith("transformer.")
            }
            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )

            # 2) Load other weights
            load_path = os.path.join(input_dir, "selected_blocks.safetensors")
            if os.path.exists(load_path):
                tensor_dict = load_file(load_path)
                block_state_dicts = {}
                for k, v in tensor_dict.items():
                    block_name, param_name = k.split(".", 1)
                    if block_name not in block_state_dicts:
                        block_state_dicts[block_name] = {}
                    block_state_dicts[block_name][param_name] = v
                for block_name, state_dict in block_state_dicts.items():
                    if hasattr(transformer_, block_name):
                        getattr(transformer_, block_name).load_state_dict(state_dict)
                    else:
                        raise ValueError(f"Transformer has no attribute '{block_name}'")

            # 3) Move optimizer state to the desired device/dtype
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device=self.accelerator.device, dtype=torch.float32)

        self.accelerator.register_save_state_pre_hook(save_model_hook)
        self.accelerator.register_load_state_pre_hook(load_model_hook)

    def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
        if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
            if must_save or global_step % self.args.checkpointing_steps == 0:
                save_path = get_intermediate_ckpt_path(
                    checkpointing_limit=self.args.checkpointing_limit,
                    step=global_step,
                    output_dir=self.args.output_dir,
                )
                self.accelerator.save_state(save_path, safe_serialization=True)
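
# Usage (a minimal sketch; `MyTrainer` is a hypothetical subclass implementing the
# abstract methods above, and how `Args` is constructed depends on your entrypoint):
#
#     args = Args(...)            # parse/build training arguments
#     trainer = MyTrainer(args)
#     trainer.fit()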