import argparse
from typing import List, Optional

import torch
from accelerate import Accelerator
from .library.device_utils import init_ipex, clean_memory_on_device

# NOTE: init_ipex() is deliberately invoked before the remaining library imports;
# keep this ordering.
init_ipex()

from .library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util
from . import train_network
from .library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class SdxlNetworkTrainer(train_network.NetworkTrainer):
    """Network trainer specialised for SDXL (two text encoders plus size embeddings)."""

    def __init__(self):
        super().__init__()
        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
        self.is_sdxl = True

    def assert_extra_args(self, args, train_dataset_group):
        """Run the base-class and SDXL argument checks, then validate caching constraints.

        Raises:
            AssertionError: if Text Encoder output caching is requested but the
                dataset is not cacheable, or if caching is combined with
                training a network for the Text Encoder.
        """
        super().assert_extra_args(args, train_dataset_group)
        sdxl_train_util.verify_sdxl_training_args(args)

        # Read the flag only after the verification calls above have run.
        if args.cache_text_encoder_outputs:
            dataset_is_cacheable = train_dataset_group.is_text_encoder_output_cacheable()
            assert dataset_is_cacheable, (
                "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
            )

        caching_is_compatible = args.network_train_unet_only or not args.cache_text_encoder_outputs
        assert caching_is_compatible, (
            "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
        )

        train_dataset_group.verify_bucket_reso_steps(32)

    def load_target_model(self, args, weight_dtype, accelerator):
        """Load the SDXL base model and return ``(model_version, text_encoders, vae, unet)``.

        Side effects: stores ``load_stable_diffusion_format``, ``logit_scale`` and
        ``ckpt_info`` on ``self`` and patches the attention modules of the U-Net.
        """
        model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
        (
            sd_format,
            te1,
            te2,
            vae,
            unet,
            logit_scale,
            ckpt_info,
        ) = sdxl_train_util.load_target_model(args, accelerator, model_version, weight_dtype)

        self.load_stable_diffusion_format = sd_format
        self.logit_scale = logit_scale
        self.ckpt_info = ckpt_info

        # Incorporate xformers / memory efficient attention into the model.
        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
        # With an xformers build that supports PyTorch 2.0.0 or later, the VAE can use it too.
        if torch.__version__ >= "2.0.0":
            vae.set_use_memory_efficient_attention_xformers(args.xformers)

        return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [te1, te2], vae, unet

    def get_tokenize_strategy(self, args):
        """Build the SDXL tokenize strategy from the token-length and cache-dir arguments."""
        max_len = args.max_token_length
        cache_dir = args.tokenizer_cache_dir
        return strategy_sdxl.SdxlTokenizeStrategy(max_len, cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
        """Return both tokenizers of the strategy as a two-element list."""
        first = tokenize_strategy.tokenizer1
        second = tokenize_strategy.tokenizer2
        return [first, second]

    def get_latents_caching_strategy(self, args):
        """Build the latents caching strategy shared by SD and SDXL."""
        return strategy_sd.SdSdxlLatentsCachingStrategy(
            False,
            args.cache_latents_to_disk,
            args.vae_batch_size,
            args.skip_cache_check,
        )

    def get_text_encoding_strategy(self, args):
        """Return a fresh SDXL text encoding strategy (``args`` is not consulted)."""
        return strategy_sdxl.SdxlTextEncodingStrategy()

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        """Return the text encoders followed by the unwrapped last text encoder."""
        unwrapped_last = accelerator.unwrap_model(text_encoders[-1])
        return text_encoders + [unwrapped_last]

    def get_text_encoder_outputs_caching_strategy(self, args):
        """Return the Text Encoder output caching strategy, or ``None`` when caching is off."""
        if not args.cache_text_encoder_outputs:
            return None
        return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
            args.cache_text_encoder_outputs_to_disk,
            None,
            args.skip_cache_check,
            is_weighted=args.weighted_captions,
        )

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        """Cache Text Encoder outputs for ``dataset`` when requested.

        Without caching, both text encoders are simply placed on the accelerator
        device. With caching, the outputs are computed once and the text
        encoders are then moved to the CPU as float32.
        """
        if not args.cache_text_encoder_outputs:
            # Outputs are fetched from the Text Encoders on every step, so keep them on the device.
            for idx in (0, 1):
                text_encoders[idx].to(accelerator.device, dtype=weight_dtype)
            return

        offload_to_cpu = not args.lowram
        if offload_to_cpu:
            # Reduce memory consumption while the text encoders are running.
            logger.info("move vae and unet to cpu to save memory")
            vae_home = vae.device
            unet_home = unet.device
            vae.to("cpu")
            unet.to("cpu")
            clean_memory_on_device(accelerator.device)

        # When the TE is not being trained it is not prepared by the accelerator,
        # so explicit autocast is needed.
        for idx in (0, 1):
            text_encoders[idx].to(accelerator.device, dtype=weight_dtype)
        with accelerator.autocast():
            encoding_models = text_encoders + [accelerator.unwrap_model(text_encoders[-1])]
            dataset.new_cache_text_encoder_outputs(encoding_models, accelerator)
        accelerator.wait_for_everyone()

        # Text Encoder doesn't work with fp16 on CPU, hence float32.
        for idx in (0, 1):
            text_encoders[idx].to("cpu", dtype=torch.float32)
        clean_memory_on_device(accelerator.device)

        if offload_to_cpu:
            logger.info("move vae and unet back to original device")
            vae.to(vae_home)
            unet.to(unet_home)

    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
        """Return ``(encoder_hidden_states1, encoder_hidden_states2, pool2)`` for ``batch``.

        Cached outputs stored in the batch are used when present; otherwise the
        hidden states are computed from the token ids with gradients enabled.
        """
        has_cached_outputs = (
            "text_encoder_outputs1_list" in batch and batch["text_encoder_outputs1_list"] is not None
        )

        if has_cached_outputs:
            device = accelerator.device
            hidden1 = batch["text_encoder_outputs1_list"].to(device).to(weight_dtype)
            hidden2 = batch["text_encoder_outputs2_list"].to(device).to(weight_dtype)
            pooled = batch["text_encoder_pool2_list"].to(device).to(weight_dtype)
            return hidden1, hidden2, pooled

        ids1 = batch["input_ids"]
        ids2 = batch["input_ids2"]
        with torch.enable_grad():
            # Get the text embedding for conditioning.
            # TODO support weighted captions
            ids1 = ids1.to(accelerator.device)
            ids2 = ids2.to(accelerator.device)
            te_dtype = weight_dtype if args.full_fp16 else None
            hidden1, hidden2, pooled = train_util.get_hidden_states_sdxl(
                args.max_token_length,
                ids1,
                ids2,
                tokenizers[0],
                tokenizers[1],
                text_encoders[0],
                text_encoders[1],
                te_dtype,
                accelerator=accelerator,
            )
        return hidden1, hidden2, pooled

    def call_unet(
        self,
        args,
        accelerator,
        unet,
        noisy_latents,
        timesteps,
        text_conds,
        batch,
        weight_dtype,
        indices: Optional[List[int]] = None,
    ):
        """Run the U-Net on the noisy latents and return its noise prediction.

        When ``indices`` is a non-empty sequence, the latents, timesteps and
        both embeddings are indexed with it before the U-Net call.
        """
        # TODO check why noisy_latents is not weight_dtype
        noisy_latents = noisy_latents.to(weight_dtype)

        # Size embeddings built from the original / crop / target sizes in the batch.
        size_embs = sdxl_train_util.get_size_embeddings(
            batch["original_sizes_hw"],
            batch["crop_top_lefts"],
            batch["target_sizes_hw"],
            accelerator.device,
        ).to(weight_dtype)

        # Concatenate the embeddings.
        hidden1, hidden2, pooled = text_conds
        vector_embedding = torch.cat([pooled, size_embs], dim=1).to(weight_dtype)
        text_embedding = torch.cat([hidden1, hidden2], dim=2).to(weight_dtype)

        if indices is not None and len(indices) > 0:
            noisy_latents = noisy_latents[indices]
            timesteps = timesteps[indices]
            text_embedding = text_embedding[indices]
            vector_embedding = vector_embedding[indices]

        return unet(noisy_latents, timesteps, text_embedding, vector_embedding)

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, validation_settings=None):
        """Delegate to ``sdxl_train_util.sample_images`` and return its result."""
        return sdxl_train_util.sample_images(
            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, validation_settings
        )


def setup_parser() -> argparse.ArgumentParser:
    """Return the network-training argument parser extended with the SDXL arguments."""
    sdxl_parser = train_network.setup_parser()
    sdxl_train_util.add_sdxl_training_arguments(sdxl_parser)
    return sdxl_parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = SdxlNetworkTrainer()
    trainer.train(args)