import importlib import argparse import math import os import sys import random import time import json from multiprocessing import Value from typing import Any, List import toml from tqdm import tqdm # from comfy.utils import ProgressBar import torch from library.device_utils import init_ipex, clean_memory_on_device init_ipex() from accelerate.utils import set_seed from accelerate import Accelerator from diffusers import DDPMScheduler from library import deepspeed_utils, model_util, strategy_base, strategy_sd from library import train_util as train_util from library.train_util import DreamBoothDataset from library import config_util as config_util from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) from library import huggingface_util as huggingface_util from library import custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, apply_debiased_estimation, apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments setup_logging() import logging logger = logging.getLogger(__name__) class NetworkTrainer: def __init__(self): self.vae_scale_factor = 0.18215 self.is_sdxl = False # TODO 他のスクリプトと共通化する def generate_step_logs( self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer=None, keys_scaled=None, mean_norm=None, maximum_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} if keys_scaled is not None: logs["max_norm/keys_scaled"] = keys_scaled logs["max_norm/average_key_norm"] = mean_norm logs["max_norm/max_key_norm"] = maximum_norm lrs = lr_scheduler.get_last_lr() for i, lr in enumerate(lrs): if lr_descriptions is not None: lr_desc = lr_descriptions[i] else: idx = i - (0 if args.network_train_unet_only else -1) if idx == -1: lr_desc = "textencoder" else: if 
len(lrs) > 2: lr_desc = f"group{idx}" else: lr_desc = "unet" logs[f"lr/{lr_desc}"] = lr if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs[f"lr/d*lr/{lr_desc}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. logs["lr/d*lr"] = ( optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] ) else: idx = 0 if not args.network_train_unet_only: logs["lr/textencoder"] = float(lrs[0]) idx = 1 for i in range(idx, len(lrs)): logs[f"lr/group{i}"] = float(lrs[i]) if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): logs[f"lr/d*lr/group{i}"] = ( optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) return logs def assert_extra_args(self, args, train_dataset_group): train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # Incorporate xformers or memory efficient attention into the model train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) if torch.__version__ >= "2.0.0": # If you have xformers compatible with PyTorch 2.0.0 or higher, you can use the following vae.set_use_memory_efficient_attention_xformers(args.xformers) return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet def get_tokenize_strategy(self, args): return strategy_sd.SdTokenizeStrategy(args.v2, 
args.max_token_length, args.tokenizer_cache_dir) def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: return [tokenize_strategy.tokenizer] def get_latents_caching_strategy(self, args): latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check ) return latents_caching_strategy def get_text_encoding_strategy(self, args): return strategy_sd.SdTextEncodingStrategy(args.clip_skip) def get_text_encoder_outputs_caching_strategy(self, args): return None def get_models_for_text_encoding(self, args, accelerator, text_encoders): """ Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached). """ return text_encoders # returns a list of bool values indicating whether each text encoder should be trained def get_text_encoders_train_flags(self, args, text_encoders): return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) def is_train_text_encoder(self, args): return not args.network_train_unet_only def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred def all_reduce_network(self, accelerator, network): for param in network.parameters(): if param.grad is not None: param.grad = accelerator.reduce(param.grad, reduction="mean") def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, 
tokenizers[0], text_encoder, unet) # region SD/SDXL def post_process_network(self, args, accelerator, network, text_encoders, unet): pass def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) prepare_scheduler_for_custom_training(noise_scheduler, device) if args.zero_terminal_snr: custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) return noise_scheduler def encode_images_to_latents(self, args, accelerator, vae, images): return vae.encode(images).latent_dist.sample() def shift_scale_latents(self, args, latents): return latents * self.vae_scale_factor def get_noise_pred_and_target( self, args, accelerator, noise_scheduler, latents, batch, text_encoder_conds, unet, network, weight_dtype, train_unet, ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) # ensure the hidden state will require grad if args.gradient_checkpointing: for x in noisy_latents: x.requires_grad_(True) for t in text_encoder_conds: t.requires_grad_(True) # Predict the noise residual with accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype, ) if args.v_parameterization: # v-parameterization training target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise # differential output preservation if "custom_attributes" in batch: diff_output_pr_indices = [] for i, custom_attributes in enumerate(batch["custom_attributes"]): if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: 
diff_output_pr_indices.append(i) if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) with torch.no_grad(), accelerator.autocast(): noise_pred_prior = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype, indices=diff_output_pr_indices, ) network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) return noise_pred, target, timesteps, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) return loss def get_sai_model_spec(self, args): return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) def update_metadata(self, metadata, args): pass def is_text_encoder_not_needed_for_training(self, args): return False # use for sample images def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): # set top parameter requires_grad = True for gradient checkpointing works text_encoder.text_model.embeddings.requires_grad_(True) def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): text_encoder.text_model.embeddings.to(dtype=weight_dtype) def prepare_unet_with_accelerator( self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module ) -> torch.nn.Module: return accelerator.prepare(unet) def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): pass # endregion def init_train(self, args): 
session_id = random.randint(0, 2**32) training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) deepspeed_utils.prepare_deepspeed_args(args) setup_logging(args, reset=True) cache_latents = args.cache_latents use_dreambooth_method = args.in_json is None use_user_config = args.dataset_config is not None if args.seed is None: args.seed = random.randint(0, 2**32) set_seed(args.seed) tokenize_strategy = self.get_tokenize_strategy(args) strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. latents_caching_strategy = self.get_latents_caching_strategy(args) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # pbar = ProgressBar(5) # Prepare the dataset if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) if use_user_config: logger.info(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignoring the following options because config file is found: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: logger.info("Using DreamBooth method.") user_config = { "datasets": [ { "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( args.train_data_dir, args.reg_data_dir ) } ] } else: logger.info("Training with captions.") user_config = { "datasets": [ { "subsets": [ { "image_dir": args.train_data_dir, "metadata_file": args.in_json, } ] } ] } blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = 
config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: logger.error( "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" ) return if cache_latents: assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" self.assert_extra_args(args, train_dataset_group) # may change some args # prepare accelerator logger.info("preparing accelerator") accelerator = train_util.prepare_accelerator(args) # Prepare a type that supports mixed precision and cast it as appropriate. 
weight_dtype, save_dtype = train_util.prepare_dtype(args) vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # Load the model model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) # text_encoder is List[CLIPTextModel] or CLIPTextModel text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # pbar.update(1) # Load the model for incremental learning #sys.path.append(os.path.dirname(__file__)) accelerator.print("import network module:", args.network_module) package = __name__.split('.')[0] network_module = importlib.import_module(args.network_module, package=package) if args.base_weights is not None: # base_weights が指定されている場合は、指定された重みを読み込みマージする for i, weight_path in enumerate(args.base_weights): if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: multiplier = 1.0 else: multiplier = args.base_weights_multiplier[i] accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") module, weights_sd = network_module.create_network_from_weights( multiplier, weight_path, vae, text_encoder, unet, for_inference=True ) module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") # cache latents if cache_latents: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu text_encoding_strategy = self.get_text_encoding_strategy(args) strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args) if text_encoder_outputs_caching_strategy is not None: 
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) # pbar.update(1) # prepare network net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: key, value = net_arg.split("=") net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀') if args.dim_from_weights: network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) else: if "dropout" not in net_kwargs: # workaround for LyCORIS (;^ω^) net_kwargs["dropout"] = args.network_dropout network = network_module.create_network( 1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs, ) if network is None: return network_has_multiplier = hasattr(network, "set_multiplier") if hasattr(network, "prepare_network"): network.prepare_network(args) if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): logger.warning( "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" ) args.scale_weight_norms = False self.post_process_network(args, accelerator, network, text_encoders, unet) # apply network to unet and text_encoder train_unet = not args.network_train_text_encoder_only train_text_encoder = self.is_train_text_encoder(args) network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: # FIXME consider alpha of weights: this assumes that the alpha is not changed info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: if args.cpu_offload_checkpointing: unet.enable_gradient_checkpointing(cpu_offload=True) else: 
unet.enable_gradient_checkpointing() for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): if flag: if t_enc.supports_gradient_checkpointing: t_enc.gradient_checkpointing_enable() del t_enc network.enable_gradient_checkpointing() # may have no effect # Prepare classes necessary for learning accelerator.print("prepare optimizer, data loader etc.") # make backward compatibility for text_encoder_lr support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") if support_multiple_lrs: text_encoder_lr = args.text_encoder_lr else: # toml backward compatibility if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): text_encoder_lr = args.text_encoder_lr else: text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: if support_multiple_lrs: results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) else: results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) if type(results) is tuple: trainable_params = results[0] lr_descriptions = results[1] else: trainable_params = results lr_descriptions = None except TypeError as e: trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) lr_descriptions = None # if len(trainable_params) == 0: # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした") # for params in trainable_params: # for k, v in params.items(): # if type(v) == float: # pass # else: # v = len(v) # accelerator.print(f"trainable_params: {k} = {v}") optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) self.optimizer_train_fn, self.optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) # prepare dataloader # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset # some strategies can be None train_dataset_group.set_current_strategies() # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, shuffle=True, collate_fn=collator, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) accelerator.print( f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" ) # Send learning steps to the dataset side as well train_dataset_group.set_max_train_steps(args.max_train_steps) # lr scheduler init lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # Experimental function: performs fp16/bf16 learning including gradients, sets the entire model to fp16/bf16 if args.full_fp16: assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" accelerator.print("enable full fp16 training.") network.to(weight_dtype) elif args.full_bf16: assert ( args.mixed_precision == "bf16" ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") network.to(weight_dtype) unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16'" accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn if args.fp8_dtype == 
"e4m3" else torch.float8_e5m2 accelerator.print(f"unet_weight_dtype: {unet_weight_dtype}") if not args.fp8_base_unet and not args.network_train_unet_only: accelerator.print("enable fp8 training for Text Encoder.") te_weight_dtype = torch.float8_e4m3fn if args.fp8_dtype == "e4m3" else torch.float8_e5m2 # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.requires_grad_(False) unet.to(dtype=unet_weight_dtype) for i, t_enc in enumerate(text_encoders): t_enc.requires_grad_(False) # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 if t_enc.device.type != "cpu": t_enc.to(dtype=te_weight_dtype) # nn.Embedding not support FP8 if te_weight_dtype != weight_dtype: self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: flags = self.get_text_encoders_train_flags(args, text_encoders) ds_model = deepspeed_utils.prepare_deepspeed_model( args, unet=unet if train_unet else None, text_encoder1=text_encoders[0] if flags[0] else None, text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( ds_model, optimizer, train_dataloader, lr_scheduler ) training_model = ds_model else: if train_unet: # default implementation is: unet = accelerator.prepare(unet) unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator 
does some magic here else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: text_encoders = [ (accelerator.prepare(t_enc) if flag else t_enc) for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) ] if len(text_encoders) > 1: text_encoder = text_encoders else: text_encoder = text_encoders[0] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( network, optimizer, train_dataloader, lr_scheduler ) training_model = network if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works if frag: self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc) else: unet.eval() for t_enc in text_encoders: t_enc.eval() del t_enc accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet) if not cache_latents: # If you do not cache, VAE will be used, so enable VAE preparation. 
vae.requires_grad_(False) vae.eval() vae.to(accelerator.device, dtype=vae_dtype) # Experimental feature: Perform fp16 learning including gradients Apply a patch to PyTorch to enable grad scale in fp16 if args.full_fp16: train_util.patch_accelerator_for_fp16_training(accelerator) # pbar.update(1) # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 #if args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): if len(weights) > i: weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") # save current ecpoch and step train_state_file = os.path.join(output_dir, "train_state.json") # +1 is needed because the state is saved before current_step is set from global_step logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}") with open(train_state_file, "w", encoding="utf-8") as f: json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f) steps_from_state = None def load_model_hook(models, input_dir): # remove models except network remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): models.pop(i) # print(f"load model hook: {len(models)} models will be loaded") # load current epoch and step to nonlocal steps_from_state train_state_file = os.path.join(input_dir, "train_state.json") if os.path.exists(train_state_file): with open(train_state_file, "r", encoding="utf-8") as f: data = json.load(f) steps_from_state = data["current_step"] logger.info(f"load train 
state from {train_state_file}: {data}") accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) # resume from local or huggingface train_util.resume_from_local_or_hf_if_specified(accelerator, args) # pbar.update(1) # Calculate the number of epochs num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する # TODO: find a way to handle total batch size when there are multiple datasets total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps accelerator.print("running training") accelerator.print(f" num train images * repeats: {train_dataset_group.num_train_images}") accelerator.print(f" num reg images: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch: {len(train_dataloader)}") accelerator.print(f" num epochs: {num_train_epochs}") accelerator.print( f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" ) accelerator.print(f" gradient accumulation steps: {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps: {args.max_train_steps}") # TODO refactor metadata creation and move to util metadata = { "ss_session_id": session_id, # random integer indicating which group of epochs the model came from "ss_training_started_at": training_started_at, # unix timestamp "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, "ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": 
num_train_epochs, "ss_gradient_checkpointing": args.gradient_checkpointing, "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, "ss_max_train_steps": args.max_train_steps, "ss_lr_warmup_steps": args.lr_warmup_steps, "ss_lr_scheduler": args.lr_scheduler, "ss_network_module": args.network_module, "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim "ss_network_alpha": args.network_alpha, # some networks may not have alpha "ss_network_dropout": args.network_dropout, # some networks may not have dropout "ss_mixed_precision": args.mixed_precision, "ss_full_fp16": bool(args.full_fp16), "ss_v2": bool(args.v2), "ss_base_model_version": model_version, "ss_clip_skip": args.clip_skip, "ss_max_token_length": args.max_token_length, "ss_cache_latents": bool(args.cache_latents), "ss_seed": args.seed, "ss_lowram": args.lowram, "ss_noise_offset": args.noise_offset, "ss_multires_noise_iterations": args.multires_noise_iterations, "ss_multires_noise_discount": args.multires_noise_discount, "ss_adaptive_noise_scale": args.adaptive_noise_scale, "ss_zero_terminal_snr": args.zero_terminal_snr, "ss_training_comment": args.training_comment, # will not be updated after training "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), "ss_max_grad_norm": args.max_grad_norm, "ss_caption_dropout_rate": args.caption_dropout_rate, "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, "ss_face_crop_aug_range": args.face_crop_aug_range, "ss_prior_loss_weight": args.prior_loss_weight, "ss_min_snr_gamma": args.min_snr_gamma, "ss_scale_weight_norms": args.scale_weight_norms, "ss_ip_noise_gamma": args.ip_noise_gamma, "ss_debiased_estimation": bool(args.debiased_estimation_loss), "ss_noise_offset_random_strength": 
args.noise_offset_random_strength, "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_scale": args.huber_scale, "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), } self.update_metadata(metadata, args) # architecture specific metadata if use_user_config: # save metadata of multiple datasets # NOTE: pack "ss_datasets" value as json one time # or should also pack nested collections as json? datasets_metadata = [] tag_frequency = {} # merge tag frequency for metadata editor dataset_dirs_info = {} # merge subset dirs for metadata editor for dataset in train_dataset_group.datasets: is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) dataset_metadata = { "is_dreambooth": is_dreambooth_dataset, "batch_size_per_device": dataset.batch_size, "num_train_images": dataset.num_train_images, # includes repeating "num_reg_images": dataset.num_reg_images, "resolution": (dataset.width, dataset.height), "enable_bucket": bool(dataset.enable_bucket), "min_bucket_reso": dataset.min_bucket_reso, "max_bucket_reso": dataset.max_bucket_reso, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, } subsets_metadata = [] for subset in dataset.subsets: subset_metadata = { "img_count": subset.img_count, "num_repeats": subset.num_repeats, "color_aug": bool(subset.color_aug), "flip_aug": bool(subset.flip_aug), "random_crop": bool(subset.random_crop), "shuffle_caption": bool(subset.shuffle_caption), "keep_tokens": subset.keep_tokens, "keep_tokens_separator": subset.keep_tokens_separator, "secondary_separator": subset.secondary_separator, "enable_wildcard": bool(subset.enable_wildcard), "caption_prefix": subset.caption_prefix, "caption_suffix": subset.caption_suffix, } image_dir_or_metadata_file = None if subset.image_dir: image_dir = os.path.basename(subset.image_dir) subset_metadata["image_dir"] = 
image_dir image_dir_or_metadata_file = image_dir if is_dreambooth_dataset: subset_metadata["class_tokens"] = subset.class_tokens subset_metadata["is_reg"] = subset.is_reg if subset.is_reg: image_dir_or_metadata_file = None # not merging reg dataset else: metadata_file = os.path.basename(subset.metadata_file) subset_metadata["metadata_file"] = metadata_file image_dir_or_metadata_file = metadata_file # may overwrite subsets_metadata.append(subset_metadata) # merge dataset dir: not reg subset only # TODO update additional-network extension to show detailed dataset config from metadata if image_dir_or_metadata_file is not None: # datasets may have a certain dir multiple times v = image_dir_or_metadata_file i = 2 while v in dataset_dirs_info: v = image_dir_or_metadata_file + f" ({i})" i += 1 image_dir_or_metadata_file = v dataset_dirs_info[image_dir_or_metadata_file] = { "n_repeats": subset.num_repeats, "img_count": subset.img_count, } dataset_metadata["subsets"] = subsets_metadata datasets_metadata.append(dataset_metadata) # merge tag frequency: for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): # If a directory is used by multiple datasets, count only once # Since the number of repetitions is originally specified, the number of times a tag appears in the caption does not match the number of times it is used in training. # Therefore, it is not very meaningful to add up the number of times for multiple datasets here. if ds_dir_name in tag_frequency: continue tag_frequency[ds_dir_name] = ds_freq_for_dir metadata["ss_datasets"] = json.dumps(datasets_metadata) metadata["ss_tag_frequency"] = json.dumps(tag_frequency) metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) else: # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir assert ( len(train_dataset_group.datasets) == 1 ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug." 
dataset = train_dataset_group.datasets[0] dataset_dirs_info = {} reg_dataset_dirs_info = {} if use_dreambooth_method: for subset in dataset.subsets: info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} else: for subset in dataset.subsets: dataset_dirs_info[os.path.basename(subset.metadata_file)] = { "n_repeats": subset.num_repeats, "img_count": subset.img_count, } metadata.update( { "ss_batch_size_per_device": args.train_batch_size, "ss_total_batch_size": total_batch_size, "ss_resolution": args.resolution, "ss_color_aug": bool(args.color_aug), "ss_flip_aug": bool(args.flip_aug), "ss_random_crop": bool(args.random_crop), "ss_shuffle_caption": bool(args.shuffle_caption), "ss_enable_bucket": bool(dataset.enable_bucket), "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), "ss_min_bucket_reso": dataset.min_bucket_reso, "ss_max_bucket_reso": dataset.max_bucket_reso, "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), "ss_tag_frequency": json.dumps(dataset.tag_frequency), "ss_bucket_info": json.dumps(dataset.bucket_info), } ) # add extra args if args.network_args: metadata["ss_network_args"] = json.dumps(net_kwargs) # model name and hash if args.pretrained_model_name_or_path is not None: sd_model_name = args.pretrained_model_name_or_path if os.path.exists(sd_model_name): metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) sd_model_name = os.path.basename(sd_model_name) metadata["ss_sd_model_name"] = sd_model_name if args.vae is not None: vae_name = args.vae if os.path.exists(vae_name): metadata["ss_vae_hash"] = train_util.model_hash(vae_name) metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) vae_name = os.path.basename(vae_name) 
metadata["ss_vae_name"] = vae_name metadata = {k: str(v) for k, v in metadata.items()} # make minimum metadata for filtering minimum_metadata = {} for key in train_util.SS_METADATA_MINIMUM_KEYS: if key in metadata: minimum_metadata[key] = metadata[key] # calculate steps to skip when resuming or starting from a specific step initial_step = 0 if args.initial_epoch is not None or args.initial_step is not None: # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming if steps_from_state is not None: logger.warning( "steps from the state is ignored because initial_step is specified" ) if args.initial_step is not None: initial_step = args.initial_step else: # num steps per epoch is calculated by num_processes and gradient_accumulation_steps initial_step = (args.initial_epoch - 1) * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) else: # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming if steps_from_state is not None: initial_step = steps_from_state steps_from_state = None if initial_step > 0: assert ( args.max_train_steps > initial_step ), f"max_train_steps should be greater than initial step: {args.max_train_steps} vs {initial_step}" epoch_to_start = 0 if initial_step > 0: if args.skip_until_initial_step: # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used if not args.resume: logger.info( f"initial_step is specified but not resuming. 
lr scheduler will be started from the beginning" ) logger.info(f"skipping {initial_step} steps") initial_step *= args.gradient_accumulation_steps # set epoch to start to make initial_step less than len(train_dataloader) epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) else: # if not, only epoch no is skipped for informative purpose epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) initial_step = 0 # do not skip noise_scheduler = self.get_noise_scheduler(args, accelerator.device) init_kwargs = {} if args.wandb_run_name: init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs, ) self.loss_recorder = train_util.LossRecorder() del train_dataset_group # pbar.update(1) # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): on_step_start_for_network = accelerator.unwrap_model(network).on_step_start else: on_step_start_for_network = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) accelerator.print(f"\nsaving checkpoint: {ckpt_file}") metadata["ss_training_finished_at"] = str(time.time()) metadata["ss_steps"] = str(steps) metadata["ss_epoch"] = str(epoch_no) metadata_to_save = minimum_metadata if args.no_metadata else metadata sai_metadata = self.get_sai_model_spec(args) metadata_to_save.update(sai_metadata) unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, 
force_sync_upload=force_sync_upload) def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) if self.is_text_encoder_not_needed_for_training(args): logger.info("text_encoder is not needed for training. deleting to save memory.") for t_enc in text_encoders: del t_enc text_encoders = [] text_encoder = None # For --sample_at_first #self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) self.global_step = 0 # training loop if initial_step > 0: # only if skip_until_initial_step is specified for skip_epoch in range(epoch_to_start): # skip epochs logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") initial_step -= len(train_dataloader) # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") for i, t_enc in enumerate(text_encoders): params_itr = t_enc.parameters() params_itr.__next__() # skip the first parameter params_itr.__next__() # skip the second parameter. 
because CLIP first two parameters are embeddings param_3rd = params_itr.__next__() logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) self.epoch_to_start = epoch_to_start self.num_train_epochs = num_train_epochs self.accelerator = accelerator self.network = network self.text_encoder = text_encoder self.unet = unet self.vae = vae self.tokenizers = tokenizers self.args = args self.train_dataloader = train_dataloader self.initial_step = initial_step self.current_epoch = current_epoch self.metadata = metadata self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.save_model = save_model self.remove_model = remove_model # self.comfy_pbar = None progress_bar = tqdm(range(args.max_train_steps - initial_step), smoothing=0, disable=False, desc="steps") def training_loop(break_at_steps, epoch): steps_done = 0 #accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") progress_bar.set_description(f"Epoch {epoch + 1}/{num_train_epochs} - steps") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) skipped_dataloader = None if self.initial_step > 0: skipped_dataloader = accelerator.skip_first_batches(train_dataloader, self.initial_step - 1) self.initial_step = 1 for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = self.global_step if self.initial_step > 0: self.initial_step -= 1 continue with accelerator.accumulate(training_model): on_step_start_for_network(text_encoder, unet) # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: with torch.no_grad(): # encode latents latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype)) latents = 
latents.to(dtype=weight_dtype) # NaN check if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = self.shift_scale_latents(args, latents) # get multiplier for each sample if network_has_multiplier: multipliers = batch["network_multipliers"] # if all multipliers are same, use single multiplier if torch.all(multipliers == multipliers[0]): multipliers = multipliers[0].item() else: raise NotImplementedError("multipliers for each sample is not supported yet") # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids_list, weights_list, ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) if args.full_fp16: encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: text_encoder_conds = 
encoded_text_encoder_conds else: # if encoded_text_encoder_conds is not None, update cached text_encoder_conds for i in range(len(encoded_text_encoder_conds)): if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, latents, batch, text_encoder_conds, unet, network, weight_dtype, train_unet, ) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # weight for each sample loss = loss * loss_weights # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. 
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) loss = loss.mean() # No need to divide by batch_size since it's an average accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually if args.max_grad_norm != 0.0: params_to_clip = accelerator.unwrap_model(network).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) if args.scale_weight_norms: keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( args.scale_weight_norms, accelerator.device ) max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} else: keys_scaled, mean_norm, maximum_norm = None, None, None # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) self.global_step += 1 current_loss = loss.detach().item() self.loss_recorder.add(epoch=epoch, step=step, global_step=self.global_step, loss=current_loss) avr_loss: float = self.loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if args.scale_weight_norms: progress_bar.set_postfix(**{**max_mean_logs, **logs}) if len(accelerator.trackers) > 0: logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) accelerator.log(logs, step=self.global_step) if self.global_step >= break_at_steps: break steps_done += 1 # self.comfy_pbar.update(1) if len(accelerator.trackers) > 0: logs = {"loss/epoch": self.loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) return steps_done return training_loop # metadata["ss_epoch"] = str(num_train_epochs) # metadata["ss_training_finished_at"] = str(time.time()) # network = accelerator.unwrap_model(network) # 
# accelerator.end_training()

# if (args.save_state or args.save_state_on_train_end):
#     train_util.save_state_on_train_end(args, accelerator)

# ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
# save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)

# logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for network (e.g. LoRA) training.

    First registers the shared option groups provided by the logging helpers,
    ``train_util``, ``deepspeed_utils``, ``config_util`` and
    ``custom_train_functions``, then adds the network-training specific
    options (network module/dim/alpha, per-component learning rates, initial
    epoch/step handling, ...).

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    parser = argparse.ArgumentParser()

    # Shared option groups; kept in the original registration order so the
    # --help layout and any option interplay stay unchanged.
    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, True)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    train_util.add_dit_training_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser)

    # ---- network-training specific options ----
    parser.add_argument(
        "--cpu_offload_checkpointing",
        action="store_true",
        help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
        " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)",
    )
    parser.add_argument(
        "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
    )
    parser.add_argument(
        "--save_model_as",
        type=str,
        default="safetensors",
        # NOTE(review): values parsed from the command line are always str, so
        # the None entry appears unreachable; kept so the accepted set is unchanged.
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
    )

    # Per-component learning rates; None means "not specified on the CLI".
    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
    parser.add_argument(
        "--text_encoder_lr",
        type=float,
        default=None,
        nargs="*",  # one value per text encoder may be given
        help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
    )
    parser.add_argument(
        "--fp8_base_unet",
        action="store_true",
        help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
        " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16",
    )

    parser.add_argument(
        "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
    )
    parser.add_argument(
        "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
    )
    parser.add_argument(
        "--network_dim",
        type=int,
        default=None,
        help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
    )
    parser.add_argument(
        "--network_alpha",
        type=float,
        default=1,
        help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRAの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
    )
    parser.add_argument(
        "--network_dropout",
        type=float,
        default=None,
        help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
    )
    parser.add_argument(
        "--network_args",
        type=str,
        default=None,
        nargs="*",
        help="additional arguments for network (key=value) / ネットワークへの追加の引数",
    )
    parser.add_argument(
        "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
    )
    parser.add_argument(
        "--network_train_text_encoder_only",
        action="store_true",
        help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
    )
    parser.add_argument(
        "--training_comment",
        type=str,
        default=None,
        help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
    )
    parser.add_argument(
        "--dim_from_weights",
        action="store_true",
        help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
    )
    parser.add_argument(
        "--scale_weight_norms",
        type=float,
        default=None,
        help="Scale the weight of each key pair to help prevent overtraining via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
    )
    parser.add_argument(
        "--base_weights",
        type=str,
        default=None,
        nargs="*",
        help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
    )
    parser.add_argument(
        "--base_weights_multiplier",
        type=float,
        default=None,
        nargs="*",
        help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
    )
    parser.add_argument(
        "--no_half_vae",
        action="store_true",
        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
    )

    # Resume-from-a-given-point controls.
    parser.add_argument(
        "--skip_until_initial_step",
        action="store_true",
        help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
    )
    parser.add_argument(
        "--initial_epoch",
        type=int,
        default=None,
        help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect the lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
        + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
    )
    parser.add_argument(
        "--initial_step",
        type=int,
        default=None,
        help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
        + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
    )
    # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
    # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
    # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    # Values from a config file (if one was given) are merged into args here.
    args = train_util.read_config_from_file(args, parser)

    trainer = NetworkTrainer()
    # NOTE(review): the training-setup method in this file ends with
    # `return training_loop` (a closure taking (break_at_steps, epoch)), and the
    # end-of-training save is commented out. The returned loop is discarded here,
    # so a CLI run appears to only prepare training. Confirm whether this entry
    # point is still expected to train.
    trainer.train(args)