Configure your LoRA training settings.
+ """, elem_classes="group_padding") + lora_name = gr.Textbox( + label="The name of your LoRA", + info="This has to be a unique name", + placeholder="e.g.: Persian Miniature Painting style, Cat Toy", + ) + concept_sentence = gr.Textbox( + elem_id="--concept_sentence", + label="Trigger word/sentence", + info="Trigger word or sentence to be used", + placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'", + interactive=True, + ) + model_names = list(models.keys()) + print(f"model_names={model_names}") + base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0]) + vram = gr.Radio(["20G", "16G", "12G" ], value="20G", label="VRAM", interactive=True) + num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True) + max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True) + total_steps = gr.Number(0, interactive=False, label="Expected training steps") + sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True) + sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True) + resolution = gr.Number(value=512, precision=0, label="Resize dataset images") + with gr.Column(): + gr.Markdown( + """# Step 2. Dataset +Make sure the captions include the trigger word.
+ """, elem_classes="group_padding") + with gr.Group(): + images = gr.File( + file_types=["image", ".txt"], + label="Upload your images", + #info="If you want, you can also manually upload caption files that match the image names (example: img0.png => img0.txt)", + file_count="multiple", + interactive=True, + visible=True, + scale=1, + ) + with gr.Group(visible=False) as captioning_area: + do_captioning = gr.Button("Add AI captions with Florence-2") + output_components.append(captioning_area) + #output_components = [captioning_area] + caption_list = [] + for i in range(1, MAX_IMAGES + 1): + locals()[f"captioning_row_{i}"] = gr.Row(visible=False) + with locals()[f"captioning_row_{i}"]: + locals()[f"image_{i}"] = gr.Image( + type="filepath", + width=111, + height=111, + min_width=111, + interactive=False, + scale=2, + show_label=False, + show_share_button=False, + show_download_button=False, + ) + locals()[f"caption_{i}"] = gr.Textbox( + label=f"Caption {i}", scale=15, interactive=True + ) + + output_components.append(locals()[f"captioning_row_{i}"]) + output_components.append(locals()[f"image_{i}"]) + output_components.append(locals()[f"caption_{i}"]) + caption_list.append(locals()[f"caption_{i}"]) + with gr.Column(): + gr.Markdown( + """# Step 3. Train +Press start to start training.
+ """, elem_classes="group_padding") + refresh = gr.Button("Refresh", elem_id="refresh", visible=False) + start = gr.Button("Start training", visible=False, elem_id="start_training") + output_components.append(start) + train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True) + train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True) + with gr.Accordion("Advanced options", elem_id='advanced_options', open=False): + with gr.Row(): + with gr.Column(min_width=300): + seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True) + with gr.Column(min_width=300): + workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True) + with gr.Column(min_width=300): + learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True) + with gr.Column(min_width=300): + save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True) + with gr.Column(min_width=300): + guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True) + with gr.Column(min_width=300): + timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True) + with gr.Column(min_width=300): + network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True) + advanced_components, advanced_component_ids = init_advanced() + with gr.Row(): + terminal = LogsView(label="Train log", elem_id="terminal") + with gr.Row(): + gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6) + + with gr.TabItem("Publish") as publish_tab: + hf_token = gr.Textbox(label="Huggingface Token") + hf_login = gr.Button("Login") + hf_logout = gr.Button("Logout") + with gr.Row() as row: + gr.Markdown("**LoRA**") + gr.Markdown("**Upload**") + loras = get_loras() + with gr.Row(): + lora_rows = refresh_publish_tab() + with gr.Column(): + with gr.Row(): + repo_owner = gr.Textbox(label="Account", interactive=False) + repo_name = gr.Textbox(label="Repository Name") + repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public") + upload_button = gr.Button("Upload to HuggingFace") + upload_button.click( + fn=upload_hf, + inputs=[ + base_model, + lora_rows, + repo_owner, + repo_name, + repo_visibility, + hf_token, + ] + ) + hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner]) + hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner]) + + + publish_tab.select(refresh_publish_tab, outputs=lora_rows) + lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name]) + + dataset_folder = gr.State() + + listeners = [ + base_model, + lora_name, + resolution, + seed, + workers, + concept_sentence, + learning_rate, + network_dim, + max_train_epochs, + save_every_n_epochs, + timestep_sampling, + guidance_scale, + vram, + num_repeats, + sample_prompts, + sample_every_n_steps, + *advanced_components + ] + advanced_component_ids = [x.elem_id for x in advanced_components] + original_advanced_component_values = [comp.value for comp in advanced_components] + images.upload( + load_captioning, + inputs=[images, concept_sentence], + outputs=output_components + ) + images.delete( + load_captioning, + inputs=[images, concept_sentence], + outputs=output_components + ) + images.clear( + 
hide_captioning, + outputs=[captioning_area, start] + ) + max_train_epochs.change( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + num_repeats.change( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + images.upload( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + images.delete( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + images.clear( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts) + start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then( + fn=start_training, + inputs=[ + base_model, + lora_name, + train_script, + train_config, + sample_prompts, + ], + outputs=terminal, + ) + do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list) + demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner]) + refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder]) +if __name__ == "__main__": + cwd = os.path.dirname(os.path.abspath(__file__)) + demo.launch(debug=True, show_error=True, allowed_paths=[cwd]) diff --git a/datasets/1 b/datasets/1 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..998b18251a7854d2d03de682bb39dd2ede796b51 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,28 @@ +services: + + fluxgym: + build: + context: . 
+ # change the dockerfile to Dockerfile.cuda12.4 if you are running CUDA 12.4 drivers otherwise leave as is + dockerfile: Dockerfile + image: fluxgym + container_name: fluxgym + ports: + - 7860:7860 + environment: + - PUID=${PUID:-1000} + - PGID=${PGID:-1000} + volumes: + - /etc/localtime:/etc/localtime:ro + - /etc/timezone:/etc/timezone:ro + - ./:/app/fluxgym + stop_signal: SIGKILL + tty: true + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + restart: unless-stopped \ No newline at end of file diff --git a/fine_tune.py b/fine_tune.py new file mode 100644 index 0000000000000000000000000000000000000000..ae7d9c8aa1523da692f3dbe70ab2181694908fdf --- /dev/null +++ b/fine_tune.py @@ -0,0 +1,560 @@ +# training with captions +# XXX dropped option: hypernetwork training + +import argparse +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library import deepspeed_utils, strategy_base +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler + +from .utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, +) +import library.strategy_sd as strategy_sd + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
+ if cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + + # Diffusers版のxformers使用フラグを設定する関数 + def set_diffusers_xformers_flag(model, valid): + # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう + # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) + # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか + # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) + + # Recursively walk through all the children. 
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + fn_recursive_set_mem_eff(model) + + # モデルに xformers とか memory efficient attention を組み込む + if args.diffusers_xformers: + accelerator.print("Use xformers by Diffusers") + set_diffusers_xformers_flag(unet, True) + else: + # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある + accelerator.print("Disable Diffusers' xformers") + set_diffusers_xformers_flag(unet, False) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + training_models = [] + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) + + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) + else: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) # text encoderは学習しない + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing + else: + text_encoder.eval() + + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + for m in training_models: + m.requires_grad_(True) + + trainable_params = [] + if args.learning_rate_te is None or not args.train_text_encoder: + for m in training_models: + trainable_params.extend(m.parameters()) + else: + trainable_params = [ + {"params": list(unet.parameters()), "lr": args.learning_rate}, + {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te}, + ] + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + + if args.deepspeed: + if args.train_text_encoder: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + else: + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if 
args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # For --sample_at_first + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(*training_models): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype) + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + if args.weighted_captions: + # TODO move to strategy_sd.py + encoder_hidden_states = get_weighted_text_embeddings( + tokenize_strategy.tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: + # do not mean over batch dimension for snr weight or scale v-pred loss + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # mean over batch dimension + else: + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has 
performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), + vae, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), + vae, + ) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = accelerator.unwrap_model(unet) + text_encoder = accelerator.unwrap_model(text_encoder) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) + 
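The max_train_epochs → max_train_steps override earlier in train() and the web UI's "Expected training steps" field rest on the same arithmetic. A minimal sketch, assuming batch size 1, a single process, no gradient accumulation, and a dataloader length of num_images × num_repeats (the helper name is illustrative, not from this diff):

import math

def expected_total_steps(num_images, num_repeats, max_train_epochs,
                         num_processes=1, grad_accum_steps=1):
    # mirrors: max_train_steps = max_train_epochs * ceil(len(dataloader) / processes / accum)
    steps_per_epoch = math.ceil(num_images * num_repeats / num_processes / grad_accum_steps)
    return max_train_epochs * steps_per_epoch

print(expected_total_steps(num_images=10, num_repeats=10, max_train_epochs=16))  # 1600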
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--learning_rate_te", + type=float, + default=None, + help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flags.png b/flags.png new file mode 100644 index 0000000000000000000000000000000000000000..049b127bbdffd1318ec0b70ffc309d81d76a08c8 Binary files /dev/null and b/flags.png differ diff --git a/flow.gif b/flow.gif new file mode 100644 index 0000000000000000000000000000000000000000..6f54af79b26e1b01f1ce82b0252525947c5ad4b4 --- /dev/null +++ b/flow.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e502e5bcbfd25f5d7bad10e0b57a88c8f3b24006792d3a273d7bd964634a8fd9 +size 11349766 diff --git a/flux_extract_lora.py b/flux_extract_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..661cbba5b052d26fe4a09b72715f108171992110 --- /dev/null +++ b/flux_extract_lora.py @@ -0,0 +1,221 @@ +# extract approximating LoRA by svd from two FLUX models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import json +import os +import time +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from .library import flux_utils, sai_model_spec +from .library.utils import MemoryEfficientSafeOpen +from .library.utils import setup_logging +from .networks import lora_flux + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from comfy.utils import ProgressBar +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 + + +def save_to_file(file_name, state_dict, metadata, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + save_file(state_dict, file_name, metadata=metadata) + + +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + device=None, + store_device='cpu', + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, + mem_eff_safe_open=False, +): + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + calc_dtype = torch.float + save_dtype = str_to_dtype(save_precision) + + # open models + lora_weights = {} + if not mem_eff_safe_open: + # use original safetensors.safe_open + open_fn = lambda fn: safe_open(fn, framework="pt") + else: + logger.info("Using memory efficient safe_open") + open_fn = lambda fn: MemoryEfficientSafeOpen(fn) + + with open_fn(model_org) as fo: + # filter keys + keys = [] + for key in fo.keys(): + if not ("single_block" in key or "double_block" in key): + continue + if ".bias" in key: + continue + if "norm" in key: + continue + keys.append(key) + comfy_pbar = ProgressBar(len(keys)) + with open_fn(model_tuned) as ft: + for key in tqdm(keys): + # get tensors and calculate difference + value_o = 
fo.get_tensor(key) + value_t = ft.get_tensor(key) + mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) + del value_o, value_t + + # extract LoRA weights + if device: + mat = mat.to(device) + out_dim, in_dim = mat.size()[0:2] + rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + U = U.to(store_device, dtype=save_dtype).contiguous() + Vh = Vh.to(store_device, dtype=save_dtype).contiguous() + + print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + comfy_pbar.update(1) + lora_weights[key] = (U, Vh) + del mat, U, S, Vh + + # make state dict for LoRA + lora_sd = {} + for key, (up_weight, down_weight) in lora_weights.items(): + lora_name = key.replace(".weight", "").replace(".", "_") + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank + + # minimum metadata + net_kwargs = {} + metadata = { + "ss_v2": str(False), + "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1, + "ss_network_module": "networks.lora_flux", + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), + "ss_network_args": json.dumps(net_kwargs), + } + + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev") + metadata.update(sai_metadata) + + save_to_file(save_to, lora_sd, metadata, save_dtype) + + logger.info(f"LoRA weights saved to {save_to}") + return save_to + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", + ) + parser.add_argument( + "--model_org", + type=str, + default=None, + required=True, + help="Original model: safetensors file / 元モデル、safetensors", + ) + parser.add_argument( + "--model_tuned", + type=str, + default=None, + required=True, + help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", + ) + parser.add_argument( + "--mem_eff_safe_open", + action="store_true", + help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough." + " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: safetensors file / 保存先のファイル名、safetensors", + ) + parser.add_argument( + "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)" + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). 
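The loop above is the heart of the extraction: it approximates the weight delta between the tuned and original FLUX models with a low-rank factor pair, then clamps outliers by quantile. A stripped-down sketch for a single 2-D weight (function and variable names here are illustrative, not taken from this file):

import torch

def extract_lora_pair(w_tuned, w_org, rank=4, clamp_quantile=0.99):
    diff = w_tuned.float() - w_org.float()
    U, S, Vh = torch.linalg.svd(diff)
    up = U[:, :rank] @ torch.diag(S[:rank])    # becomes lora_up.weight  (out_dim x rank)
    down = Vh[:rank, :]                        # becomes lora_down.weight (rank x in_dim)
    hi = torch.quantile(torch.cat([up.flatten(), down.flatten()]), clamp_quantile).item()
    return up.clamp(-hi, hi), down.clamp(-hi, hi)  # up @ down approximates diff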
Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + # parser.add_argument( + # "--min_diff", + # type=float, + # default=0.01, + # help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + # + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + # ) + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + svd(**vars(args)) \ No newline at end of file diff --git a/flux_train_comfy.py b/flux_train_comfy.py new file mode 100644 index 0000000000000000000000000000000000000000..6251ccc473d40f762c657684ffb3a7f17980a352 --- /dev/null +++ b/flux_train_comfy.py @@ -0,0 +1,806 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +import copy +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from .library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from .library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from .library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +from .library import train_util as train_util + +from .library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from .library import config_util as config_util + +from .library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from .library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +class FluxTrainer: + def __init__(self): + self.sample_prompts_te_outputs = None + + def sample_images(self, epoch, global_step, validation_settings): + image_tensors = flux_train_utils.sample_images( + self.accelerator, self.args, epoch, global_step, self.unet, self.vae, self.text_encoder, self.sample_prompts_te_outputs, validation_settings) + return image_tensors + + def init_train(self, args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. 
remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # Prepare the dataset + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not 
None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = [] + for line in 
args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) + + # load FLUX + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + + if args.gradient_checkpointing: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + + flux.requires_grad_(True) + + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! 
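The block-swap comments above describe the general idea; enable_block_swap itself lives inside the FLUX model class. As a generic, forward-only illustration of hook-based offloading (not the code used by this repository; a real training setup also has to coordinate the backward pass and optimizer state):

import torch
import torch.nn as nn

def naive_block_swap(blocks: nn.ModuleList, device: torch.device):
    # Keep blocks on CPU and pull each one onto the GPU only for its own forward call.
    def pre_hook(module, args):
        module.to(device, non_blocking=True)
    def post_hook(module, args, output):
        module.to("cpu", non_blocking=True)
        return output
    for block in blocks:
        block.register_forward_pre_hook(pre_hook)
        block.register_forward_hook(post_hook)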
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(flux) + name_and_params = list(flux.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(flux.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + self.optimizer_train_fn = lambda: None # dummy function + self.optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + self.optimizer_train_fn, self.optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + if self.is_swapping_blocks: + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. 
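The fused-backward-pass and blockwise-fused-optimizer paths a few lines below both build on the PyTorch mechanism from the "optimizer step in backward" tutorial linked in the comments: register a post-accumulate gradient hook per parameter, step that parameter's optimizer inside the hook, and drop the gradient immediately so full-model gradients never need to be held at once. A minimal sketch with a toy model (the model and optimizer choice here are placeholders):

import torch

model = torch.nn.Linear(8, 8)
optimizers = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param: torch.Tensor):
    optimizers[param].step()
    optimizers[param].zero_grad(set_to_none=True)  # free the grad right away

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(4, 8)).sum()
loss.backward()  # optimizer steps run during backward; no separate optimizer.step() loop

The version in this diff additionally clips each gradient before stepping and, in the blockwise variant, groups FLUX double/single blocks so each block gets its own optimizer instance.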
+ train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + from .library import adafactor_fused + + adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training") + accelerator.print(f" num examples: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch: {len(train_dataloader)}") + accelerator.print(f" num epochs: {num_train_epochs}") + accelerator.print( + f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + self.global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, 
shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if self.is_swapping_blocks: + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + # For --sample_at_first + #flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + + self.loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + + self.tokens_and_masks = tokens_and_masks + self.num_train_epochs = num_train_epochs + self.current_epoch = current_epoch + self.args = args + self.accelerator = accelerator + self.unet = flux + self.vae = ae + self.text_encoder = [clip_l, t5xxl] + self.save_dtype = save_dtype + + def training_loop(break_at_steps, epoch): + global optimizer_hooked_count + steps_done = 0 + #accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + progress_bar.set_description(f"Epoch {epoch + 1}/{num_train_epochs} - steps") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = self.global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. 
images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(flux) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(flux) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or 
args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + self.global_step += 1 + + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=self.global_step) + + self.loss_recorder.add(epoch=epoch, step=step, loss=current_loss, global_step=self.global_step) + avr_loss: float = self.loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if self.global_step >= break_at_steps: + break + steps_done += 1 + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": self.loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + return steps_done + + return training_loop + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) + + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser diff --git a/flux_train_network_comfy.py b/flux_train_network_comfy.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbbd9edccd3dcec4bda5bbae0fb4dd82e90816e --- /dev/null +++ b/flux_train_network_comfy.py @@ -0,0 +1,500 @@ +import torch +import copy +import math +from typing import Any, Dict, List, Optional, Tuple, Union +import argparse +from .library import flux_models, flux_train_utils, 
flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util +from .train_network import NetworkTrainer, clean_memory_on_device, setup_parser + +from accelerate import Accelerator + + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class FluxNetworkTrainer(NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn or model.dtype == torch.float8_e5m2: + logger.info(f"Loaded {model.dtype} FLUX model") + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. 
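+ # only part of the transformer blocks stays resident on the GPU; the remaining blocks_to_swap blocks
+ # are moved between CPU and GPU around the forward/backward passes, trading speed for lower peak VRAM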
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + clip_l.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip_l and not self.train_t5xxl: + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip_l, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip_l or self.train_t5xxl, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + 
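+ # outputs are not cached: the text encoders run at every training step, so no caching strategy is needed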
return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # reduce memory consumption + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. 
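+ # keep only what sampling needs: "enum" later names the sample image files, and any leftover
+ # "subset" entry is treated as extra data and dropped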
+ prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoder + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) + + def sample_images(self, epoch, global_step, validation_settings): + text_encoders = self.get_models_for_text_encoding(self.args, self.accelerator, self.text_encoder) + + image_tensors = flux_train_utils.sample_images( + self.accelerator, self.args, epoch, global_step, self.unet, self.vae, text_encoders, self.sample_prompts_te_outputs, validation_settings) + clean_memory_on_device(self.accelerator.device) + return image_tensors + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds 
+ if not args.apply_t5_attn_mask: + t5_attn_mask = None + + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ + + return model_pred + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(unet) + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: #for flex + flux_utils.restore_flux_guidance(unet) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, 
packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # 
reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + return flux + + +def setup_parser() -> argparse.ArgumentParser: + parser = setup_parser() + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) diff --git a/hf_token.json b/hf_token.json new file mode 100644 index 0000000000000000000000000000000000000000..ab7d5c76a4ead5351ee7e95329a48677a3501605 --- /dev/null +++ b/hf_token.json @@ -0,0 +1,3 @@ +{ + "hf_token": "your_token_here" +} \ No newline at end of file diff --git a/icon.png b/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..c763bed52203d1c9f56607c01dc28b3a7f1eee2d Binary files /dev/null and b/icon.png differ diff --git a/install.js b/install.js new file mode 100644 index 0000000000000000000000000000000000000000..d3d911125808d634cf5a8794e93f3f6d7791c46b --- /dev/null +++ b/install.js @@ -0,0 +1,96 @@ +module.exports = { + run: [ + { + method: "shell.run", + params: { + venv: "env", + message: [ + "git config --global --add safe.directory '*'", + "git clone -b sd3 https://github.com/kohya-ss/sd-scripts" + ] + } + }, + { + method: "shell.run", + params: { + path: "sd-scripts", + venv: "../env", + message: [ + "uv pip install -r requirements.txt", + ] + } + }, + { + method: "shell.run", + params: { + venv: "env", + message: [ + "pip uninstall -y diffusers[torch] torch torchaudio torchvision", + "uv pip install -r requirements.txt", + ] + } + }, + { + method: "script.start", + params: { + uri: "torch.js", + params: { + venv: "env", + // xformers: true // uncomment this line if your project requires xformers + } + } + }, + { + method: "fs.link", + params: { + drive: { + vae: "models/vae", + clip: "models/clip", + unet: "models/unet", + loras: "outputs", + }, + peers: [ + "https://github.com/pinokiofactory/stable-diffusion-webui-forge.git", + "https://github.com/pinokiofactory/comfy.git", + "https://github.com/cocktailpeanutlabs/comfyui.git", + "https://github.com/cocktailpeanutlabs/fooocus.git", + "https://github.com/cocktailpeanutlabs/automatic1111.git", + ] + } + }, +// { +// method: "fs.download", +// params: { +// uri: [ +// "https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors?download=true", +// "https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp16.safetensors?download=true", +// ], +// dir: "models/clip" +// } +// }, +// { +// method: "fs.download", +// params: { +// uri: [ +// "https://huggingface.co/cocktailpeanut/xulf-dev/resolve/main/ae.sft?download=true", +// ], +// dir: "models/vae" +// } +// }, +// { +// method: "fs.download", +// params: { +// uri: [ +// "https://huggingface.co/cocktailpeanut/xulf-dev/resolve/main/flux1-dev.sft?download=true", +// ], +// dir: "models/unet" +// } +// }, + { + method: "fs.link", + params: { + venv: "env" + } + } + ] +} diff --git a/library/__init__.py b/library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/library/__pycache__/__init__.cpython-310.pyc b/library/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4434c91f0781a8c54a5ec928892548bef8dd3555 Binary files /dev/null and b/library/__pycache__/__init__.cpython-310.pyc differ diff --git a/library/__pycache__/config_util.cpython-310.pyc b/library/__pycache__/config_util.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7af25c53b8bddd90ce54163c1e3d3b56354a5f4d Binary files /dev/null and b/library/__pycache__/config_util.cpython-310.pyc differ diff --git a/library/__pycache__/custom_offloading_utils.cpython-310.pyc b/library/__pycache__/custom_offloading_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d5415a1e497b01444439652630b5f4b7a49c8cd Binary files /dev/null and b/library/__pycache__/custom_offloading_utils.cpython-310.pyc differ diff --git a/library/__pycache__/custom_train_functions.cpython-310.pyc b/library/__pycache__/custom_train_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ec9879a42194bd300ddca5f5b5b6360272b587b Binary files /dev/null and b/library/__pycache__/custom_train_functions.cpython-310.pyc differ diff --git a/library/__pycache__/deepspeed_utils.cpython-310.pyc b/library/__pycache__/deepspeed_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba3ef9a24c69ab1ba0731314ea8da98d8fa75d0 Binary files /dev/null and b/library/__pycache__/deepspeed_utils.cpython-310.pyc differ diff --git a/library/__pycache__/device_utils.cpython-310.pyc b/library/__pycache__/device_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..972b8250302570e149242ea63616cdd3efb558bc Binary files /dev/null and b/library/__pycache__/device_utils.cpython-310.pyc differ diff --git a/library/__pycache__/flux_models.cpython-310.pyc b/library/__pycache__/flux_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56ce1d4ea96448680e9c47a306efe1b420a01e7e Binary files /dev/null and b/library/__pycache__/flux_models.cpython-310.pyc differ diff --git a/library/__pycache__/flux_train_utils.cpython-310.pyc b/library/__pycache__/flux_train_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8232f6bb95aef764011a49a6b3c37850d32ea8da Binary files /dev/null and b/library/__pycache__/flux_train_utils.cpython-310.pyc differ diff --git a/library/__pycache__/flux_utils.cpython-310.pyc b/library/__pycache__/flux_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf8997d908fcc5aa7dd2c97cf15daf72f2a8dcbd Binary files /dev/null and b/library/__pycache__/flux_utils.cpython-310.pyc differ diff --git a/library/__pycache__/huggingface_util.cpython-310.pyc b/library/__pycache__/huggingface_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf0c380c4a08fd6a19cc64d05794479e4e93be51 Binary files /dev/null and b/library/__pycache__/huggingface_util.cpython-310.pyc differ diff --git a/library/__pycache__/model_util.cpython-310.pyc b/library/__pycache__/model_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..614429b5176628baee564a8fe07a6c11487b0f04 Binary files /dev/null and b/library/__pycache__/model_util.cpython-310.pyc differ diff --git a/library/__pycache__/original_unet.cpython-310.pyc b/library/__pycache__/original_unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40324b150ca486e128b08b1006fd117b15a302e1 Binary files /dev/null and b/library/__pycache__/original_unet.cpython-310.pyc differ diff --git a/library/__pycache__/sai_model_spec.cpython-310.pyc b/library/__pycache__/sai_model_spec.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ea3d3d1c582b76697080f6314c2d9ff5850fe4b8 Binary files /dev/null and b/library/__pycache__/sai_model_spec.cpython-310.pyc differ diff --git a/library/__pycache__/sd3_models.cpython-310.pyc b/library/__pycache__/sd3_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80484c7229bd39c4f072bc6fa04b3288a34943ca Binary files /dev/null and b/library/__pycache__/sd3_models.cpython-310.pyc differ diff --git a/library/__pycache__/sd3_utils.cpython-310.pyc b/library/__pycache__/sd3_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2007e8adf9309e66d0b16b84c684099618b99612 Binary files /dev/null and b/library/__pycache__/sd3_utils.cpython-310.pyc differ diff --git a/library/__pycache__/strategy_base.cpython-310.pyc b/library/__pycache__/strategy_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b000e3252e6ae0f8a06a8a86604be57399f799fe Binary files /dev/null and b/library/__pycache__/strategy_base.cpython-310.pyc differ diff --git a/library/__pycache__/strategy_sd.cpython-310.pyc b/library/__pycache__/strategy_sd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67f8c7430846716a0c47828f8d1a155c6772f68f Binary files /dev/null and b/library/__pycache__/strategy_sd.cpython-310.pyc differ diff --git a/library/__pycache__/train_util.cpython-310.pyc b/library/__pycache__/train_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156c9a9a8e2cdac3e52c8ad85df43259f0a285f4 --- /dev/null +++ b/library/__pycache__/train_util.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa71a44895d0a006e41ba9fadbd0177a9ad5499cc89aeb2266aa1c7a9597e82e +size 164434 diff --git a/library/__pycache__/utils.cpython-310.pyc b/library/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62837bec9f40aacaacc0075a2f70745ce7c84049 Binary files /dev/null and b/library/__pycache__/utils.cpython-310.pyc differ diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..b5afa236bc8d76779d5f90e6db9d52cf2261bdb3 --- /dev/null +++ b/library/adafactor_fused.py @@ -0,0 +1,138 @@ +import math +import torch +from transformers import Adafactor + +# stochastic rounding for bfloat16 +# The implementation was provided by 2kpr. Thank you very much! 
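+# Stochastic rounding: before truncating an fp32 value to bfloat16, a random 16-bit integer is added to
+# the bits that will be discarded, so the result rounds up with probability proportional to the dropped
+# fraction. On average this preserves tiny updates that plain truncation would lose.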
+ +def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): + """ + copies source into target using stochastic rounding + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + """ + # create a random 16 bit integer + result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16)) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + del result + + +@torch.no_grad() +def adafactor_step_param(self, p, group): + if p.grad is None: + return + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = Adafactor._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = Adafactor._rms(p_data_fp32) + lr = Adafactor._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + # if p.dtype in {torch.float16, torch.bfloat16}: + # p.copy_(p_data_fp32) + + if p.dtype == torch.bfloat16: + copy_stochastic_(p, p_data_fp32) + elif p.dtype == torch.float16: + p.copy_(p_data_fp32) + + +@torch.no_grad() +def adafactor_step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + adafactor_step_param(self, p, group) + + return loss + + +def patch_adafactor_fused(optimizer: Adafactor): + optimizer.step_param = adafactor_step_param.__get__(optimizer) + optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/attention_processors.py b/library/attention_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..310c2cb1c63955f8f03296c54fd47c21f1a981c9 --- /dev/null +++ b/library/attention_processors.py @@ -0,0 +1,227 @@ +import math +from typing import Any +from einops import rearrange +import torch +from diffusers.models.attention_processor import Attention + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + +EPSILON = 1e-6 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full( + (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device + ) + + scale = q.shape[-1] ** -0.5 + + if mask is None: + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if row_mask is not None: + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if row_mask is not None: + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( + min=EPSILON + ) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum( + "... i j, ... j d -> ... 
i d", exp_weights, vc + ) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = ( + exp_row_max_diff * row_sums + + exp_block_row_max_diff * block_row_sums + ) + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( + (exp_block_row_max_diff / new_row_sums) * exp_values + ) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if row_mask is not None: + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... 
j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +class FlashAttnProcessor: + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ) -> Any: + q_bucket_size = 512 + k_bucket_size = 1024 + + h = attn.heads + q = attn.to_q(hidden_states) + + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) + + if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: + context_k, context_v = attn.hypernetwork.forward( + hidden_states, encoder_hidden_states + ) + context_k = context_k.to(hidden_states.dtype) + context_v = context_v.to(hidden_states.dtype) + else: + context_k = encoder_hidden_states + context_v = encoder_hidden_states + + k = attn.to_k(context_k) + v = attn.to_v(context_v) + del encoder_hidden_states, hidden_states + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = FlashAttentionFunction.apply( + q, k, v, attention_mask, False, q_bucket_size, k_bucket_size + ) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = attn.to_out[0](out) + out = attn.to_out[1](out) + return out diff --git a/library/config_util.py b/library/config_util.py new file mode 100644 index 0000000000000000000000000000000000000000..458e2fe81ab78b3dcc02bee09a74ededa9895de2 --- /dev/null +++ b/library/config_util.py @@ -0,0 +1,717 @@ +import argparse +from dataclasses import ( + asdict, + dataclass, +) +import functools +import random +from textwrap import dedent, indent +import json +from pathlib import Path + +# from toolz import curry +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import toml +import voluptuous +from voluptuous import ( + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, +) + + +from . 
import train_util +from .train_util import ( + DreamBoothSubset, + FineTuningSubset, + ControlNetSubset, + DreamBoothDataset, + FineTuningDataset, + ControlNetDataset, + DatasetGroup, +) +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) + + +# TODO: inherit Params class in Subset, Dataset + + +@dataclass +class BaseSubsetParams: + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + caption_separator: str = (",",) + keep_tokens: int = 0 + keep_tokens_separator: str = (None,) + secondary_separator: Optional[str] = None + enable_wildcard: bool = False + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_prefix: Optional[str] = None + caption_suffix: Optional[str] = None + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: float = 0 + custom_attributes: Optional[Dict[str, Any]] = None + + +@dataclass +class DreamBoothSubsetParams(BaseSubsetParams): + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + cache_info: bool = False + alpha_mask: bool = False + + +@dataclass +class FineTuningSubsetParams(BaseSubsetParams): + metadata_file: Optional[str] = None + alpha_mask: bool = False + + +@dataclass +class ControlNetSubsetParams(BaseSubsetParams): + conditioning_data_dir: str = None + caption_extension: str = ".caption" + cache_info: bool = False + + +@dataclass +class BaseDatasetParams: + resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 + debug_dataset: bool = False + + +@dataclass +class DreamBoothDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + + +@dataclass +class FineTuningDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + + +@dataclass +class ControlNetDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + + +@dataclass +class SubsetBlueprint: + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + + +@dataclass +class DatasetBlueprint: + is_dreambooth: bool + is_controlnet: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + + +@dataclass +class DatasetGroupBlueprint: + datasets: Sequence[DatasetBlueprint] + + +@dataclass +class Blueprint: + dataset_group: DatasetGroupBlueprint + + +class ConfigSanitizer: + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + 
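+ # a bare scalar validates here and is expanded to (value, value); a [min, max] pair raises,
+ # falling through to the two-dimensional validator in the except branch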
Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + "keep_tokens_separator": str, + "secondary_separator": str, + "caption_separator": str, + "enable_wildcard": bool, + "token_warmup_min": int, + "token_warmup_step": Any(float, int), + "caption_prefix": str, + "caption_suffix": str, + "custom_attributes": dict, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + "cache_info": bool, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + "alpha_mask": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + "alpha_mask": bool, + } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "cache_info": bool, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } + + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning or support_controlnet, ( + "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more." 
+ + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。" + ) + + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) + + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) + + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) + + if support_dreambooth and support_finetuning: + + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) + + if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + elif all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid( + "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. 
/ DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。" + ) + + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + if support_controlnet: + self.dataset_schema = self.cn_dataset_schema + else: + self.dataset_schema = self.db_dataset_schema + elif support_finetuning: + self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.user_config_validator = Schema( + { + "general": self.general_schema, + "datasets": [self.dataset_schema], + } + ) + + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) + + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. 
/ コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged + + +class BlueprintGenerator: + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = { + optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items() + } + + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams + + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks( + subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params] + ) + subset_blueprints.append(SubsetBlueprint(params)) + + params = self.generate_params_by_fallbacks( + dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] + ) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value=None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value + + +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + 
dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_separator: {subset.caption_separator} + secondary_separator: {subset.secondary_separator} + enable_wildcard: {subset.enable_wildcard} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min} + token_warmup_step: {subset.token_warmup_step} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) + + logger.info(f"{info}") + + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + logger.info(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return DatasetGroup(datasets) + + +def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split("_") + try: + n_repeats = int(tokens[0]) + except ValueError as e: + logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + return 0, "" + caption_by_folder = "_".join(tokens[1:]) + return n_repeats, caption_by_folder + + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if 
not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) + + return subsets_config + + +def generate_controlnet_subsets_config_by_subdirs( + train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt" +): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + subset_config = { + "image_dir": train_data_dir, + "conditioning_data_dir": conditioning_data_dir, + "caption_extension": caption_extension, + "num_repeats": 1, + } + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir) + + return subsets_config + + +def load_user_config(file: str) -> dict: + file_path: Path = Path(file) + if not file_path.is_file(): + #raise ValueError(f"file not found / ファイルが見つかりません: {file}") + return toml.loads(file) + + if file_path.name.lower().endswith(".json"): + try: + with open(file, "r") as f: + config = json.load(f) + except Exception: + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + elif file_path.name.lower().endswith(".toml"): + try: + config = toml.load(file_path) + except Exception: + logger.error( + f"Error on parsing TOML config file. Please check the format. 
/ TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file_path}") + + return config + + +# for config test +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments( + parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout + ) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + + logger.info("[argparse_namespace]") + logger.info(f"{vars(argparse_namespace)}") + + user_config = load_user_config(config_args.dataset_config) + + logger.info("") + logger.info("[user_config]") + logger.info(f"{user_config}") + + sanitizer = ConfigSanitizer( + config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout + ) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) + + logger.info("") + logger.info("[sanitized_user_config]") + logger.info(f"{sanitized_user_config}") + + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + + logger.info("") + logger.info("[blueprint]") + logger.info(f"{blueprint}") \ No newline at end of file diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0565ea4cc1c2b114414b5cc09089caf1c4a0bbcc --- /dev/null +++ b/library/custom_offloading_utils.py @@ -0,0 +1,227 @@ +from concurrent.futures import ThreadPoolExecutor +import time +from typing import Optional +import torch +import torch.nn as nn + +from .device_utils import clean_memory_on_device + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + + # This is not working for all cases (e.g. 
SD3), so we need to find the corresponding modules + # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + # print(module_to_cpu.__class__, module_to_cuda.__class__) + # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} + for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): + if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: + module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) + if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + else: + if module_to_cuda.weight.data.device.type != device.type: + # print( + # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" + # ) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + synchronize_device() + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + synchronize_device() + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, 
block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def __del__(self): + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module, grad_input, grad_output): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + if self.debug: + print("Prepare block devices before forward") + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first + weighs_to_device(b, "cpu") # make sure weights are on cpu + + synchronize_device(self.device) + clean_memory_on_device(self.device) + + def 
wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + if block_idx >= self.blocks_to_swap: + return + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..41036d7083f4b053c5299bf44ce64cd999dd7123 --- /dev/null +++ b/library/custom_train_functions.py @@ -0,0 +1,556 @@ +import torch +import argparse +import random +import re +from typing import List, Optional, Union +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def prepare_scheduler_for_custom_training(noise_scheduler, device): + if hasattr(noise_scheduler, "all_snr"): + return + + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + + noise_scheduler.all_snr = all_snr.to(device) + + +def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): + # fix beta: zero terminal SNR + logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + + def enforce_zero_terminal_snr(betas): + # Convert betas to alphas_bar_sqrt + alphas = 1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + + betas = noise_scheduler.betas + betas = enforce_zero_terminal_snr(betas) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + # logger.info(f"original: {noise_scheduler.betas}") + # logger.info(f"fixed: {betas}") + + noise_scheduler.betas = betas + noise_scheduler.alphas = alphas + noise_scheduler.alphas_cumprod = alphas_cumprod + + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) + min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) + if v_prediction: + snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) + else: + snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) + loss = loss * snr_weight + return loss + + +def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): + scale = get_snr_scale(timesteps, noise_scheduler) + loss = loss * scale + return loss + + +def get_snr_scale(timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + scale = snr_t / (snr_t + 1) + # # show debug info + # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + return scale + + +def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): + scale = get_snr_scale(timesteps, noise_scheduler) + # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + loss = loss + loss / scale * v_pred_like_loss + return loss + + +def apply_debiased_estimation(loss, timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + weight = 1 / torch.sqrt(snr_t) + loss = weight * loss + return loss + + +# TODO train_utilと分散しているのでどちらかに寄せる + + +def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True): + parser.add_argument( + "--min_snr_gamma", + type=float, + default=None, + help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", + ) + parser.add_argument( + "--scale_v_pred_loss_like_noise_pred", + action="store_true", + help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする", + ) + parser.add_argument( + "--v_pred_like_loss", + type=float, + default=None, + help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", + ) + parser.add_argument( + "--debiased_estimation_loss", + action="store_true", + help="debiased estimation loss / debiased estimation loss", + ) + if support_weighted_captions: + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. 
/ 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", + ) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. 
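+ For example, "an (important) word" yields a single token list covering the whole phrase,
+ with weight 1.1 attached to the token(s) of "important" and 1.0 on the rest.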
+ """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + tokenizer, + text_encoder, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
+ """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + if clip_skip is None or clip_skip == 1: + text_embedding = text_encoder(text_input_chunk)[0] + else: + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + if clip_skip is None or clip_skip == 1: + text_embeddings = text_encoder(text_input)[0] + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) + return text_embeddings + + +def get_weighted_text_embeddings( + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + device, + max_embeddings_multiples: Optional[int] = 3, + no_boseos_middle: Optional[bool] = False, + clip_skip=None, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. 
+ """ + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2) + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + tokenizer, + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + return text_embeddings + + +# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 +def pyramid_noise_like(noise, device, iterations=6, discount=0.4): + b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! + u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) + for i in range(iterations): + r = random.random() * 2 + 2 # Rather than always going 2x, + wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i))) + noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i + if wn == 1 or hn == 1: + break # Lowest resolution is 1x1 + return noise / noise.std() # Scaled back to roughly unit variance + + +# https://www.crosslabs.org//blog/diffusion-with-offset-noise +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): + if noise_offset is None: + return noise + if adaptive_noise_scale is not None: + # latent shape: (batch_size, channels, height, width) + # abs mean value for each channel + latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True)) + + # multiply adaptive noise scale to the mean value and add it to the noise offset + noise_offset = noise_offset + adaptive_noise_scale * latent_mean + noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative + + noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + return noise + + +def apply_masked_loss(loss, batch): + if "conditioning_images" in batch: + # conditioning image is -1 to 1. 
we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image / 2 + 0.5 + # print(f"conditioning_image: {mask_image.shape}") + elif "alpha_masks" in batch and batch["alpha_masks"] is not None: + # alpha mask is 0 to 1 + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") + else: + return loss + + # resize to the same size as the loss + mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + loss = loss * mask_image + return loss + + +""" +########################################## +# Perlin Noise +def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + + grid = ( + torch.stack( + torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)), + dim=-1, + ) + % 1 + ) + angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device) + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) + + tile_grads = ( + lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] + .repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + ) + dot = lambda grad, shift: ( + torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1) + * grad[: shape[0], : shape[1]] + ).sum(dim=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[: shape[0], : shape[1]]) + return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) + + +def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5): + noise = torch.zeros(shape, device=device) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1])) + frequency *= 2 + amplitude *= persistence + return noise + + +def perlin_noise(noise, device, octaves): + _, c, w, h = noise.shape + perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves) + noise_perlin = [] + for _ in range(c): + noise_perlin.append(perlin()) + noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h) + noise += noise_perlin # broadcast for each batch + return noise / noise.std() # Scaled back to roughly unit variance +""" diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39583c7ac4838ee7cdad6be9de1f3dcf3ae12836 --- /dev/null +++ b/library/deepspeed_utils.py @@ -0,0 +1,139 @@ +import os +import argparse +import torch +from accelerate import DeepSpeedPlugin, Accelerator + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def add_deepspeed_arguments(parser: argparse.ArgumentParser): + # DeepSpeed Arguments. 
https://huggingface.co/docs/accelerate/usage_guides/deepspeed + parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training") + parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.") + parser.add_argument( + "--offload_optimizer_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.", + ) + parser.add_argument( + "--offload_optimizer_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--zero3_init_flag", + action="store_true", + help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." + "Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--zero3_save_16bit_model", + action="store_true", + help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--fp16_master_weights_and_gradients", + action="store_true", + help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.", + ) + + +def prepare_deepspeed_args(args: argparse.Namespace): + if not args.deepspeed: + return + + # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. + args.max_data_loader_n_workers = 1 + + +def prepare_deepspeed_plugin(args: argparse.Namespace): + if not args.deepspeed: + return None + + try: + import deepspeed + except ImportError as e: + logger.error( + "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed" + ) + exit(1) + + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=args.zero_stage, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_clipping=args.max_grad_norm, + offload_optimizer_device=args.offload_optimizer_device, + offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, + offload_param_device=args.offload_param_device, + offload_param_nvme_path=args.offload_param_nvme_path, + zero3_init_flag=args.zero3_init_flag, + zero3_save_16bit_model=args.zero3_save_16bit_model, + ) + deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + deepspeed_plugin.deepspeed_config["train_batch_size"] = 1#( + # args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) + #) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) + if args.mixed_precision.lower() == "fp16": + deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. 
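+ # note: initial_scale_power=0 starts the fp16 dynamic loss scale at 2**0 = 1,
+ # so the first optimizer steps do not overflow before the loss scaler has adapted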
+ if args.full_fp16 or args.fp16_master_weights_and_gradients: + if args.offload_optimizer_device == "cpu" and args.zero_stage == 2: + deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True + logger.info("[DeepSpeed] full fp16 enable.") + else: + logger.info( + "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage." + ) + + if args.offload_optimizer_device is not None: + logger.info("[DeepSpeed] start to manually build cpu_adam.") + deepspeed.ops.op_builder.CPUAdamBuilder().load() + logger.info("[DeepSpeed] building cpu_adam done.") + + return deepspeed_plugin + + +# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model. +def prepare_deepspeed_model(args: argparse.Namespace, **models): + # remove None from models + models = {k: v for k, v in models.items() if v is not None} + + class DeepSpeedWrapper(torch.nn.Module): + def __init__(self, **kw_models) -> None: + super().__init__() + self.models = torch.nn.ModuleDict() + + for key, model in kw_models.items(): + if isinstance(model, list): + model = torch.nn.ModuleList(model) + assert isinstance( + model, torch.nn.Module + ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update(torch.nn.ModuleDict({key: model})) + + def get_models(self): + return self.models + + ds_model = DeepSpeedWrapper(**models) + return ds_model diff --git a/library/device_utils.py b/library/device_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e45ec7ada4082ab1c1273c4d5b419325e5132c50 --- /dev/null +++ b/library/device_utils.py @@ -0,0 +1,84 @@ +import functools +import gc + +import torch + +try: + HAS_CUDA = torch.cuda.is_available() +except Exception: + HAS_CUDA = False + +try: + HAS_MPS = torch.backends.mps.is_available() +except Exception: + HAS_MPS = False + +try: + import intel_extension_for_pytorch as ipex # noqa + + HAS_XPU = torch.xpu.is_available() +except Exception: + HAS_XPU = False + + +def clean_memory(): + gc.collect() + if HAS_CUDA: + torch.cuda.empty_cache() + if HAS_XPU: + torch.xpu.empty_cache() + if HAS_MPS: + torch.mps.empty_cache() + + +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +@functools.lru_cache(maxsize=None) +def get_preferred_device() -> torch.device: + r""" + Do not call this function from training scripts. Use accelerator.device instead. + """ + if HAS_CUDA: + device = torch.device("cuda") + elif HAS_XPU: + device = torch.device("xpu") + elif HAS_MPS: + device = torch.device("mps") + else: + device = torch.device("cpu") + print(f"get_preferred_device() -> {device}") + return device + + +def init_ipex(): + """ + Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. + + This function should run right after importing torch and before doing anything else. + + If IPEX is not available, this function does nothing. 
+ """ + try: + if HAS_XPU: + from .ipex import ipex_init + + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + else: + return + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/flux_models.py b/library/flux_models.py new file mode 100644 index 0000000000000000000000000000000000000000..b2235d5888aa79183a5241da8bc03614d6eff826 --- /dev/null +++ b/library/flux_models.py @@ -0,0 +1,1060 @@ +# copy from FLUX repo: https://github.com/black-forest-labs/flux +# license: Apache-2.0 License + +from dataclasses import dataclass +import math +from typing import Dict, List, Optional, Union + +from .device_utils import init_ipex +from .custom_offloading_utils import ModelOffloader +init_ipex() + +import torch +from einops import rearrange +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint + +# USE_REENTRANT = True + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +# region autoencoder + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): 
+ def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + 
self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +# endregion +# region config + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + # repo_id: str | None + # repo_flow: str | None + # repo_ae: str | None + + +configs = { + "dev": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-dev", + # repo_flow="flux1-dev.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + 
hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "schnell": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-schnell", + # repo_flow="flux1-schnell.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +# endregion + +# region math + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +# endregion + + +# region layers + + +# for cpu_offload_checkpointing + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. 
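+ :param time_factor: t is multiplied by this factor (default 1000) before the embedding is computed.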
+ :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + # return (x * rrms).to(dtype=x_dtype) * self.scale + return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if 
self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 + ) # b, seq_len + img_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img blocks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt blocks + 
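# same gated residual pattern as the img stream: attention projection first, then the modulated MLP +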
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + if self.training and self.gradient_checkpointing: + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False + ) + + else: + return self._forward(img, txt, vec, pe, txt_attention_mask) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + ( + attn_mask, + torch.ones( + attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + ), + ), + dim=1, + ) # b, seq_len + img_len = x_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + # compute attention + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + def forward(self, x: Tensor, 
vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + if self.training and self.gradient_checkpointing: + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) + + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + ) + else: + return self._forward(x, vec, pe, txt_attention_mask) + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +# endregion + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + 
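# vector_in, guidance_in and all double/single stream blocks are switched below as well +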
self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. 
do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + if not self.blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + img = img[:, txt.shape[1] :, ...] + + if self.training and self.cpu_offload_checkpointing: + img = img.to(self.device) + vec = vec.to(self.device) + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + + return img diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f69c61d6fa8478e635415734d5f247ec130d476a --- /dev/null +++ b/library/flux_train_utils.py @@ -0,0 +1,585 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import CLIPTextModel +from tqdm import tqdm +from PIL import Image + +from safetensors.torch import save_file +from . 
import flux_models, flux_utils, strategy_base, train_util +from .device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .utils import setup_logging, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) +# from comfy.utils import ProgressBar + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + flux, + ae, + text_encoders, + sample_prompts_te_outputs, + validation_settings=None, + prompt_replacement=None, +): + + logger.info("") + logger.info(f"generating sample images at step: {steps}") + + #distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + flux = accelerator.unwrap_model(flux) + if text_encoders is not None: + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + with torch.no_grad(), accelerator.autocast(): + image_tensor_list = [] + for prompt_dict in prompts: + image_tensor = sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + validation_settings + ) + image_tensor_list.append(image_tensor) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + return torch.cat(image_tensor_list, dim=0) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + flux: flux_models.Flux, + text_encoders: Optional[List[CLIPTextModel]], + ae: flux_models.AutoEncoder, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + validation_settings=None +): + assert isinstance(prompt_dict, dict) + # negative_prompt = prompt_dict.get("negative_prompt") + if validation_settings is not None: + sample_steps = validation_settings["steps"] + width = validation_settings["width"] + height = validation_settings["height"] + scale = validation_settings["guidance_scale"] + seed = validation_settings["seed"] + base_shift = validation_settings["base_shift"] + max_shift = validation_settings["max_shift"] + shift = validation_settings["shift"] + else: + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + base_shift = 0.5 + max_shift = 
1.15 + shift = True + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + # if negative_prompt is not None: + # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + # if negative_prompt is None: + # negative_prompt = "" + + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + text_encoder_conds = [] + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prompt] + print(f"Using cached text encoder outputs for prompt: {prompt}") + if text_encoders is not None: + print(f"Encoding prompt: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], base_shift=base_shift, max_shift=max_shift, shift=shift) # FLUX.1 dev -> shift=True + #print("TIMESTEPS: ", timesteps) + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + + with accelerator.autocast(), torch.no_grad(): + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) + + x = x.float() + x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with 
accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + return x + + # wandb有効時のみログを送信 + # try: + # wandb_tracker = accelerator.get_tracker("wandb") + # try: + # import wandb + # except ImportError: # 事前に一度確認するのでここはエラー出ないはず + # raise ImportError("No wandb / wandb がインストールされていないようです") + + # wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + # except: # wandb 無効時 + # pass + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + # comfy_pbar = ProgressBar(total=len(timesteps)) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + model.prepare_block_swap_before_forward() + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + img = img + (t_prev - t_curr) * pred + # comfy_pbar.update(1) + model.prepare_block_swap_before_forward() + return img + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = 
None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, _, H, W = latents.shape + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "flux_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + timesteps = time_shift(mu, 1.0, timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. 
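The interpolation below gives noisy_model_input = sigmas * noise + (1 - sigmas) * latents.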
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models( + ckpt_path: str, + flux: flux_models.Flux, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None and v.dtype != save_dtype: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", flux.state_dict()) + + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_flux_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_flux_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + flux: flux_models.Flux, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_flux_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( + "--t5xxl", + type=str, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. 
if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) + parser.add_argument( + "--bypass_flux_guidance" + , action="store_true" + , help="bypass flux guidance module for Flex.1-Alpha Training" + ) diff --git a/library/flux_utils.py b/library/flux_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62c58bd10119b78779097fdd6e4dbce43deba9c5 --- /dev/null +++ b/library/flux_utils.py @@ -0,0 +1,474 @@ +from dataclasses import replace +import json +import os +from typing import List, Optional, Tuple, Union +import einops +import torch + +from safetensors import safe_open +from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config + +from .flux_models import Flux, AutoEncoder, configs +from .utils import setup_logging, load_safetensors + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +MODEL_VERSION_FLUX_V1 = "flux1" +MODEL_NAME_DEV = "dev" +MODEL_NAME_SCHNELL = "schnell" + +# bypass guidance +def bypass_flux_guidance(transformer): + transformer.params.guidance_embed = False + +# restore the forward function +def restore_flux_guidance(transformer): + transformer.params.guidance_embed = True + +def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: + """ + チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 + + Args: + ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。 + + Returns: + Tuple[bool, bool, Tuple[int, int], List[str]]: + - bool: Diffusersかどうかを示すフラグ。 + - bool: Schnellかどうかを示すフラグ。 + - Tuple[int, int]: ダブルブロックとシングルブロックの数。 + - List[str]: チェックポイントに含まれるキーのリスト。 + """ + # check the state dict: Diffusers or BFL, dev or schnell, number of blocks + logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") + + if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers + ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + if 
"00001-of-00003" in ckpt_path: + ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)] + else: + ckpt_paths = [ckpt_path] + + keys = [] + for ckpt_path in ckpt_paths: + with safe_open(ckpt_path, framework="pt") as f: + keys.extend(f.keys()) + + if keys[0].startswith("model.diffusion_model."): + keys = [key.replace("model.diffusion_model.", "") for key in keys] + + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys + is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + + # check number of double and single blocks + if not is_diffusers: + max_double_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")] + ) + max_single_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")] + ) + else: + max_double_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias") + ] + ) + max_single_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias") + ] + ) + + num_double_blocks = max_double_block_index + 1 + num_single_blocks = max_single_block_index + 1 + + return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths + + +def load_flow_model( + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> Tuple[bool, Flux]: + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") + with torch.device("meta"): + params = configs[name].params + + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) + + model = Flux(params) + if dtype is not None: + model = model.to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) + logger.info("Converted Diffusers to BFL") + + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break + sd[new_key] = sd.pop(key) + + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return is_schnell, model + + +def load_ae( + ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = 
AutoEncoder(configs[MODEL_NAME_DEV].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_clip_l( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> CLIPTextModel: + logger.info("Building CLIP-L") + CLIPL_CONFIG = { + "_name_or_path": "clip-vit-large-patch14/", + "architectures": ["CLIPModel"], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + # "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0.0, + "bad_words_ids": None, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0.0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 1, + "prefix": None, + "problem_type": None, + "projection_dim": 768, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "task_specific_params": None, + "temperature": 1.0, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "transformers_version": "4.16.0.dev0", + "use_bfloat16": False, + "vocab_size": 49408, + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + # }, + # "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "projection_dim": 768, + # }, + # "torch_dtype": "float32", + # "transformers_version": None, + } + config = CLIPConfig(**CLIPL_CONFIG) + with init_empty_weights(): + clip = CLIPTextModel._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip + + +def load_t5xxl( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> T5EncoderModel: + T5_CONFIG_JSON = 
""" +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +""" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + t5xxl = T5EncoderModel._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = t5xxl.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded T5xxl: {info}") + return t5xxl + + +def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype + + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + + +# region Diffusers + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + 
"txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]: + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(num_double_blocks): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(num_single_blocks): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." 
+ for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + return diffusers_to_bfl_map + + +def convert_diffusers_sd_to_bfl( + diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS +) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for diffusers_key, tensor in diffusers_sd.items(): + if diffusers_key in diffusers_to_bfl_map: + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}") + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + return flux_sd + + +# endregion \ No newline at end of file diff --git a/library/huggingface_util.py b/library/huggingface_util.py new file mode 100644 index 0000000000000000000000000000000000000000..07a97606ef935a88a77756c287cfb57495afc5b0 --- /dev/null +++ b/library/huggingface_util.py @@ -0,0 +1,84 @@ +from typing import Union, BinaryIO +from huggingface_hub import HfApi +from pathlib import Path +import argparse +import os +from .utils import fire_in_thread +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): + api = HfApi( + token=token, + ) + try: + api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) + return True + except: + return False + + +def upload( + args: argparse.Namespace, + src: Union[str, Path, bytes, BinaryIO], + dest_suffix: str = "", + force_sync_upload: bool = False, +): + repo_id = args.huggingface_repo_id + repo_type = args.huggingface_repo_type + token = args.huggingface_token + path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None + private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" + api = HfApi(token=token) + if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): + try: + api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + except Exception as e: # 
とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + logger.error("===========================================") + + is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) + + def uploader(): + try: + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=src, + path_in_repo=path_in_repo, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=src, + path_in_repo=path_in_repo, + ) + except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") + + if args.async_upload and not force_sync_upload: + fire_in_thread(uploader) + else: + uploader() + + +def list_dir( + repo_id: str, + subfolder: str, + repo_type: str, + revision: str = "main", + token: str = None, +): + api = HfApi( + token=token, + ) + repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) + file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] + return file_list diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5aba693c50450393872be2456dff8f1accabb3d --- /dev/null +++ b/library/ipex/__init__.py @@ -0,0 +1,180 @@ +import os +import sys +import contextlib +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from .hijacks import ipex_hijacks + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +def ipex_init(): # pylint: disable=too-many-statements + try: + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: + return True, "Skipping IPEX hijack" + else: + # Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.nn.Module.cuda = torch.nn.Module.xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + 
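# --- Illustrative example, not part of the committed change ------------------------------
# Hedged usage sketch: once ipex_init() has installed the torch.cuda -> torch.xpu aliases
# being assigned here, CUDA-style calls resolve to their XPU equivalents on Intel GPUs.
# Requires a machine with intel_extension_for_pytorch installed.
import torch
from library.ipex import ipex_init  # module added by this diff

ok, err = ipex_init()
if ok and torch.cuda.is_available():      # backed by torch.xpu.is_available
    print(torch.cuda.get_device_name(0))  # backed by torch.xpu.get_device_name
    x = torch.ones(2, 2).cuda()           # Tensor.cuda is aliased to Tensor.xpu
    print(x.device)                       # reports an xpu device
else:
    print(f"IPEX hijack not active: {err}")
# ------------------------------------------------------------------------------------------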
torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + + # Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = 
torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + + # RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + # AMP: + torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False + + try: + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + + # C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count + ipex._C._DeviceProperties.major = 2024 + ipex._C._DeviceProperties.minor = 0 + + # Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 + + ipex_hijacks() + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + torch.cuda.is_xpu_hijacked = True + except Exception as e: + return False, e + return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc62f65c3b3bde29559814ee0c4a92d71a306f8 --- /dev/null +++ b/library/ipex/attention.py @@ -0,0 +1,177 @@ +import os +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from functools import cache + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers + +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) +attention_slice_rate = 
float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +# Find something divisible with the input_tokens +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +# Find slice sizes for SDPA +@cache +def find_sdpa_slice_sizes(query_shape, query_element_size): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > sdpa_slice_trigger_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +# Find slice sizes for BMM +@cache +def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): + batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] + slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = input_tokens + split_3_slice_size = mat2_atten_shape + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + +original_torch_bmm = torch.bmm +def torch_bmm_32_bit(input, mat2, *, out=None): + if input.device.type != "xpu": + return original_torch_bmm(input, mat2, out=out) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) + + # Slice BMM + if do_split: + batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] + hidden_states = torch.zeros(input.shape[0], 
input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + out=out + ) + else: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) + else: + hidden_states[start_idx:end_idx] = original_torch_bmm( + input[start_idx:end_idx], + mat2[start_idx:end_idx], + out=out + ) + torch.xpu.synchronize(input.device) + else: + return original_torch_bmm(input, mat2, out=out) + return hidden_states + +original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): + if query.device.type != "xpu": + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + + # Slice SDPA + if do_split: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal, **kwargs + ) + else: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2], + key[start_idx:end_idx, start_idx_2:end_idx_2], + value[start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal, **kwargs + ) + else: + hidden_states[start_idx:end_idx] = 
original_scaled_dot_product_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal, **kwargs + ) + torch.xpu.synchronize(query.device) + else: + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..732a185689e9882d63be68b5c5d6ee6d82c74f71 --- /dev/null +++ b/library/ipex/diffusers.py @@ -0,0 +1,312 @@ +import os +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import diffusers #0.24.0 # pylint: disable=import-error +from diffusers.models.attention_processor import Attention +from diffusers.utils import USE_PEFT_BACKEND +from functools import cache + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +@cache +def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + if slice_size is not None: + batch_size_attention = slice_size + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if query_device_type != "xpu": + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +class SlicedAttnProcessor: # pylint: disable=too-few-public-methods + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. 
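# --- Illustrative example, not part of the committed change ------------------------------
# Hedged worked example of the slicing heuristic above: for a hypothetical fp16 query of
# shape [16, 4096, 64] on XPU, with the default IPEX_ATTENTION_SLICE_RATE of 4,
#   slice_block_size = 4096 * 64 * 1 / 1024 / 1024 * 2 = 0.5   (approx. MiB per batch element)
#   block_size       = 16 * 0.5 = 8 > 4                        -> slice along the batch dimension
#   find_slice_size(16, 0.5) halves 16 to 8, since 8 * 0.5 = 4 no longer exceeds the rate.
do_split, do_split_2, do_split_3, s1, s2, s3 = find_attention_slice_sizes(
    (16, 4096, 64), query_element_size=2, query_device_type="xpu"
)
assert (do_split, do_split_2, do_split_3) == (True, False, False)
assert (s1, s2, s3) == (8, 4096, 64)
# ------------------------------------------------------------------------------------------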
+ """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, shape_three = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) + + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, 
start_idx_2:end_idx_2] = attn_slice + del attn_slice + torch.xpu.synchronize(query.device) + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + #################################################################### + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None, + temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) + + if do_split: + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 
= i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + torch.xpu.synchronize(query.device) + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + #################################################################### + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +def ipex_diffusers(): + #ARC GPUs can't allocate more than 4GB to a single block: + diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor + diffusers.models.attention_processor.AttnProcessor = AttnProcessor diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb56bc2b821e8530557f517ebeaafa141b763a6 --- /dev/null +++ b/library/ipex/gradscaler.py @@ -0,0 +1,183 @@ +from collections import defaultdict +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +device_supports_fp64 = torch.xpu.has_fp64_dtype() +OptState = ipex.cpu.autocast._grad_scaler.OptState +_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator +_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state + +def 
_unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + # sync grad to master weight + if hasattr(optimizer, "sync_grad"): + optimizer.sync_grad() + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # -: is there a way to split by device and dtype without appending in the inner loop? + to_unscale = to_unscale.to("cpu") + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + core._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get("cpu"), + per_device_inv_scale.get("cpu"), + ) + + return per_device_found_inf._per_device_tensors + +def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. 
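# --- Illustrative example, not part of the committed change ------------------------------
# Hedged sketch of the standard AMP training pattern that exercises this patched unscale
# path once gradscaler_init()/ipex_init() have wired it up. `model`, `optimizer`, `loader`
# and `loss_fn` are placeholders, not names from this repository.
import torch

scaler = torch.cuda.amp.GradScaler()  # resolves to the XPU GradScaler after the hijack
for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):  # torch.cuda.amp -> torch.xpu.amp
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # runs the CPU-side unscale implemented in this module
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
# ------------------------------------------------------------------------------------------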
+ :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradientsstr: + return ( + f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " + + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " + + f"unsharp_target_x={self.unsharp_target_x})" + ) + + def apply_unshark_mask(self, x: torch.Tensor): + if self.gaussian_blur_ksize is None: + return x + blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma) + # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength) + mask = (x - blurred) * self.gaussian_blur_strength + sharpened = x + mask + return sharpened + + def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.float() + + x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if unsharp and self.gaussian_blur_ksize: + x = self.apply_unshark_mask(x) + + return x + + +class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resized_size = None + self.gradual_latent = None + + def set_gradual_latent_params(self, size, gradual_latent: GradualLatent): + self.resized_size = size + self.gradual_latent = gradual_latent + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + # logger.warning( + print( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." 
+ ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + device = model_output.device + if self.resized_size is None: + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + model_output.shape, dtype=model_output.dtype, device=device, generator=generator + ) + s_noise = 1.0 + else: + print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) + s_noise = self.gradual_latent.s_noise + + if self.gradual_latent.unsharp_target_x: + prev_sample = sample + derivative * dt + prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) + else: + sample = self.gradual_latent.interpolate(sample, self.resized_size) + derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), + dtype=model_output.dtype, + device=device, + generator=generator, + ) + + prev_sample = prev_sample + noise * sigma_up * s_noise + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + +# endregion + +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. 
Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack("b i ...", bi, inp) + inp = inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + + if rescale is not None: + inp = inp * rescale + + return inp - org + + +def bypass_forward_diff(org_out, *weights, constraint=None, need_transpose=False): + """### boft_bypass_forward_diff + + Args: + x (torch.Tensor): the input tensor for original model + org_out (torch.Tensor): the output tensor from original model + weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight]) + constraint (float, optional): constraint for oft + need_transpose (bool, optional): + whether to transpose the input and output, + set to `True` if the original model have "dim" not in the last axis. + For example: Convolution layers + + Returns: + torch.Tensor: output tensor + """ + oft_blocks, rescale = weights + m, num, b, _ = oft_blocks.shape + r_b = b // 2 + I = torch.eye(b, device=oft_blocks.device) + r = get_r(oft_blocks, I, constraint) + inp = org = org_out.to(dtype=r.dtype) + if need_transpose: + inp = org = inp.transpose(1, -1) + + for i in range(m): + bi = r[i] # b_num, b_size, b_size + g = 2 + k = 2**i * r_b + # ... (c g k) ->... (c k g) + # ... (d b) -> ... d b + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, b)) + ) + inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp) + # ... d b -> ... (d b) + # ... (c k g) -> ... 
(c g k) + inp = inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + + if rescale is not None: + inp = inp * rescale.transpose(0, -1) + + inp = inp - org + if need_transpose: + inp = inp.transpose(1, -1) + return inp diff --git a/lycoris/functional/diag_oft.py b/lycoris/functional/diag_oft.py new file mode 100644 index 0000000000000000000000000000000000000000..233a6096eaec8d6d87cccd8c774aa84fd62daf8b --- /dev/null +++ b/lycoris/functional/diag_oft.py @@ -0,0 +1,112 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .general import factorization, FUNC_LIST + + +def get_r(oft_blocks, I=None, constraint=0): + if I is None: + I = torch.eye(oft_blocks.shape[-1], device=oft_blocks.device) + if I.ndim < oft_blocks.ndim: + for _ in range(oft_blocks.ndim - I.ndim): + I = I.unsqueeze(0) + # for Q = -Q^T + q = oft_blocks - oft_blocks.transpose(-1, -2) + normed_q = q + if constraint is not None and constraint > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > constraint: + normed_q = q * constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r + + +def weight_gen(org_weight, max_block_size=-1, rescale=False): + """### weight_gen + + Args: + org_weight (torch.Tensor): the weight tensor + max_block_size (int): max block size + rescale (bool, optional): whether to rescale the weight. Defaults to False. + + Returns: + torch.Tensor: oft_blocks[, rescale_weight] + """ + out_dim, *rest = org_weight.shape + block_size, block_num = factorization(out_dim, max_block_size) + oft_blocks = torch.zeros(block_num, block_size, block_size) + if rescale: + return oft_blocks, torch.ones(out_dim, *[1] * len(rest)) + else: + return oft_blocks, None + + +def diff_weight(org_weight, *weights, constraint=None): + """### diff_weight + + Args: + org_weight (torch.Tensor): the weight tensor of original model + weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight]) + constraint (float, optional): constraint for oft + + Returns: + torch.Tensor: ΔW + """ + oft_blocks, rescale = weights + I = torch.eye(oft_blocks.shape[1], device=oft_blocks.device) + r = get_r(oft_blocks, I, constraint) + + block_num, block_size, _ = oft_blocks.shape + _, *shape = org_weight.shape + org_weight = org_weight.to(dtype=r.dtype) + org_weight = org_weight.view(block_num, block_size, *shape) + # Init R=0, so add I on it to ensure the output of step0 is original model output + weight = torch.einsum( + "k n m, k n ... -> k m ...", + r - I, + org_weight, + ).view(-1, *shape) + if rescale is not None: + weight = rescale * weight + weight = weight + (rescale - 1) * org_weight + return weight + + +def bypass_forward_diff(x, org_out, *weights, constraint=None, need_transpose=False): + """### bypass_forward_diff + + Args: + x (torch.Tensor): the input tensor for original model + org_out (torch.Tensor): the output tensor from original model + weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight]) + constraint (float, optional): constraint for oft + need_transpose (bool, optional): + whether to transpose the input and output, + set to `True` if the original model have "dim" not in the last axis. 
+ For example: Convolution layers + + Returns: + torch.Tensor: output tensor + """ + oft_blocks, rescale = weights + block_num, block_size, _ = oft_blocks.shape + I = torch.eye(block_size, device=oft_blocks.device) + r = get_r(oft_blocks, I, constraint) + if need_transpose: + org_out = org_out.transpose(1, -1) + org_out = org_out.to(dtype=r.dtype) + *shape, _ = org_out.shape + oft_out = torch.einsum( + "k n m, ... k n -> ... k m", r - I, org_out.view(*shape, block_num, block_size) + ) + out = oft_out.view(*shape, -1) + if rescale is not None: + out = rescale.transpose(-1, 0) * out + out = out + (rescale - 1).transpose(-1, 0) * org_out + if need_transpose: + out = out.transpose(1, -1) + return out diff --git a/lycoris/functional/general.py b/lycoris/functional/general.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0f8a04dfa6d692b518b5f1a210a25bd1098ee0 --- /dev/null +++ b/lycoris/functional/general.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + +def rebuild_tucker(t, wa, wb): + rebuild2 = torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb) + return rebuild2 + + +def factorization(dimension: int, factor: int = -1) -> tuple[int, int]: + """ + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + second value is a value for weight. + + Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + """ + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m < n: + new_m = m + 1 + while dimension % new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m > factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + + +def power2factorization(dimension: int, factor: int = -1) -> tuple[int, int]: + """ + m = 2k + n = 2**p + m*n = dim + """ + if factor == -1: + factor = dimension + + # Find the first solution and check if it is even doable + m = n = 0 + while m <= factor: + m += 2 + while dimension % m != 0 and m < dimension: + m += 2 + if m > factor: + break + if sum(int(i) for i in f"{dimension//m:b}") == 1: + n = dimension // m + + if n == 0: + return None, n + return dimension // n, n + + +def tucker_weight_from_conv(up, down, mid): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down) + + +def tucker_weight(wa, wb, t): + temp = torch.einsum("i j ..., j r -> i r ...", t, wb) + return torch.einsum("i j ..., i r -> r j ...", temp, wa) + + +def apply_dora_scale(org_weight, rebuild, dora_scale, scale): + dora_norm_dims = org_weight.dim() - 1 + weight = org_weight + rebuild + 
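# --- Illustrative example, not part of the committed change ------------------------------
# Hedged check of the factorization() helper defined above; the expected pairs follow the
# table in its docstring.
assert factorization(128) == (8, 16)     # default factor (-1) picks the most balanced split
assert factorization(128, 4) == (4, 32)  # factor=4 caps the smaller term at 4
assert factorization(250) == (10, 25)
assert factorization(127) == (1, 127)    # primes only admit the trivial split
# ------------------------------------------------------------------------------------------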
weight = weight.to(dora_scale.dtype) + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * dora_norm_dims) + .transpose(0, 1) + ) + merged_scale1 = weight / weight_norm * dora_scale + diff_weight = merged_scale1 - org_weight + return org_weight + diff_weight * scale diff --git a/lycoris/functional/locon.py b/lycoris/functional/locon.py new file mode 100644 index 0000000000000000000000000000000000000000..756bbd82d7dff35d2ebc150db7a666ec423673df --- /dev/null +++ b/lycoris/functional/locon.py @@ -0,0 +1,85 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .general import rebuild_tucker, FUNC_LIST + + +def weight_gen(org_weight, rank, tucker=True): + """### weight_gen + + Args: + org_weight (torch.Tensor): the weight tensor + rank (int): low rank + + Returns: + torch.Tensor: down, up[, mid] + """ + out_dim, in_dim, *k = org_weight.shape + if k and tucker: + down = torch.empty(rank, in_dim, *(1 for _ in k)) + up = torch.empty(out_dim, rank, *(1 for _ in k)) + mid = torch.empty(rank, rank, *k) + nn.init.kaiming_uniform_(down, a=math.sqrt(5)) + nn.init.constant_(up, 0) + nn.init.kaiming_uniform_(mid, a=math.sqrt(5)) + return down, up, mid + else: + down = torch.empty(rank, in_dim) + up = torch.empty(out_dim, rank) + nn.init.kaiming_uniform_(down, a=math.sqrt(5)) + nn.init.constant_(up, 0) + return down, up, None + + +def diff_weight(*weights: tuple[torch.Tensor], gamma=1.0): + """### diff_weight + + Get ΔW = BA, where BA is low rank decomposition + + Args: + weights (tuple[torch.Tensor]): (down, up[, mid]) + gamma (float, optional): scale factor, normally alpha/rank here + + Returns: + torch.Tensor: ΔW + """ + d, u, m = weights + R, I, *k = d.shape + O, R, *_ = u.shape + u = u * gamma + + if m is None: + result = u.reshape(-1, u.size(1)) @ d.reshape(d.size(0), -1) + else: + R, R, *k = m.shape + u = u.reshape(u.size(0), -1).transpose(0, 1) + d = d.reshape(d.size(0), -1) + result = rebuild_tucker(m, u, d) + return result.reshape(O, I, *k) + + +def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args={}): + """### bypass_forward_diff + + Args: + x (torch.Tensor): input tensor + weights (tuple[torch.Tensor]): (down, up[, mid]) + gamma (float, optional): scale factor, normally alpha/rank here + extra_args (dict, optional): extra args for forward func, \ + e.g. 
padding, stride for Conv1/2/3d + + Returns: + torch.Tensor: output tensor + """ + d, u, m = weights + if m is not None: + down = FUNC_LIST[d.dim()](x, d) + mid = FUNC_LIST[d.dim()](down, m, **extra_args) + up = FUNC_LIST[d.dim()](mid, u) + else: + down = FUNC_LIST[d.dim()](x, d, **extra_args) + up = FUNC_LIST[d.dim()](down, u) + return up * gamma diff --git a/lycoris/functional/loha.py b/lycoris/functional/loha.py new file mode 100644 index 0000000000000000000000000000000000000000..042e56bdeedb84592949a112c109b777d5f8f27c --- /dev/null +++ b/lycoris/functional/loha.py @@ -0,0 +1,165 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .general import FUNC_LIST + + +class HadaWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, w1d, w1u, w2d, w2u, scale=torch.tensor(1)): + ctx.save_for_backward(w1d, w1u, w2d, w2u, scale) + diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale + return diff_weight + + @staticmethod + def backward(ctx, grad_out): + (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + temp = grad_out * (w2u @ w2d) + grad_w1u = temp @ w1d.T + grad_w1d = w1u.T @ temp + + temp = grad_out * (w1u @ w1d) + grad_w2u = temp @ w2d.T + grad_w2d = w2u.T @ temp + + del temp + return grad_w1d, grad_w1u, grad_w2d, grad_w2u, None + + +class HadaWeightTucker(torch.autograd.Function): + @staticmethod + def forward(ctx, t1, w1d, w1u, t2, w2d, w2u, scale=torch.tensor(1)): + ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale) + + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T) + del grad_w, temp + + grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T) + del grad_temp + + temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T) + del grad_w, temp + + grad_w2d = torch.einsum("i r ..., i j ... 
-> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T) + del grad_temp + return grad_t1, grad_w1d, grad_w1u, grad_t2, grad_w2d, grad_w2u, None + + +def make_weight(w1d, w1u, w2d, w2u, scale): + return HadaWeight.apply(w1d, w1u, w2d, w2u, scale) + + +def make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, scale): + return HadaWeightTucker.apply(t1, w1d, w1u, t2, w2d, w2u, scale) + + +def weight_gen(org_weight, rank, tucker=True): + """### weight_gen + + Args: + org_weight (torch.Tensor): the weight tensor + rank (int): low rank + + Returns: + torch.Tensor: w1d, w2d, w1u, w2u[, t1, t2] + """ + out_dim, in_dim, *k = org_weight.shape + if k and tucker: + w1d = torch.empty(rank, in_dim) + w1u = torch.empty(rank, out_dim) + t1 = torch.empty(rank, rank, *k) + w2d = torch.empty(rank, in_dim) + w2u = torch.empty(rank, out_dim) + t2 = torch.empty(rank, rank, *k) + nn.init.normal_(t1, std=0.1) + nn.init.normal_(t2, std=0.1) + else: + w1d = torch.empty(rank, in_dim) + w1u = torch.empty(out_dim, rank) + w2d = torch.empty(rank, in_dim) + w2u = torch.empty(out_dim, rank) + t1 = t2 = None + nn.init.normal_(w1d, std=1) + nn.init.constant_(w1u, 0) + nn.init.normal_(w2d, std=1) + nn.init.normal_(w2u, std=0.1) + return w1d, w1u, w2d, w2u, t1, t2 + + +def diff_weight(*weights, gamma=1.0): + """### diff_weight + + Get ΔW = BA, where BA is low rank decomposition + + Args: + wegihts (tuple[torch.Tensor]): (w1d, w2d, w1u, w2u[, t1, t2]) + gamma (float, optional): scale factor, normally alpha/rank here + + Returns: + torch.Tensor: ΔW + """ + w1d, w1u, w2d, w2u, t1, t2 = weights + if t1 is not None and t2 is not None: + R, I = w1d.shape + R, O = w1u.shape + R, R, *k = t1.shape + result = make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, gamma) + else: + R, I, *k = w1d.shape + O, R, *_ = w1u.shape + w1d = w1d.reshape(w1d.size(0), -1) + w1u = w1u.reshape(-1, w1u.size(1)) + w2d = w2d.reshape(w2d.size(0), -1) + w2u = w2u.reshape(-1, w2u.size(1)) + result = make_weight(w1d, w1u, w2d, w2u, gamma) + + result = result.reshape(O, I, *k) + return result + + +def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args={}): + """### bypass_forward_diff + + Args: + x (torch.Tensor): input tensor + weights (tuple[torch.Tensor]): (w1d, w2d, w1u, w2u[, t1, t2]) + gamma (float, optional): scale factor, normally alpha/rank here + extra_args (dict, optional): extra args for forward func, \ + e.g. 
padding, stride for Conv1/2/3d + + Returns: + torch.Tensor: output tensor + """ + w1d, w1u, w2d, w2u, t1, t2 = weights + diff_w = diff_weight(w1d, w1u, w2d, w2u, t1, t2, gamma) + return FUNC_LIST[w1d.dim() if t1 is None else t1.dim()](x, diff_w, **extra_args) diff --git a/lycoris/functional/lokr.py b/lycoris/functional/lokr.py new file mode 100644 index 0000000000000000000000000000000000000000..75720ed61a9aa5bf0a45008fc9bca71a7373fd2b --- /dev/null +++ b/lycoris/functional/lokr.py @@ -0,0 +1,247 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .general import rebuild_tucker, FUNC_LIST +from .general import factorization + + +def make_kron(w1, w2, scale): + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + w2 = w2.contiguous() + rebuild = torch.kron(w1, w2) + + if scale != 1: + rebuild = rebuild * scale + + return rebuild + + +def weight_gen( + org_weight, + rank, + tucker=True, + factor=-1, + decompose_both=False, + full_matrix=False, + unbalanced_factorization=False, +): + """### weight_gen + + Args: + org_weight (torch.Tensor): the weight tensor + rank (int): low rank + + Returns: + torch.Tensor | None: w1, w1a, w1b, w2, w2a, w2b, t2 + """ + out_dim, in_dim, *k = org_weight.shape + w1 = w1a = w1b = None + w2 = w2a = w2b = None + t2 = None + use_w1 = use_w2 = False + + if k: + k_size = k + shape = (out_dim, in_dim, *k_size) + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + if unbalanced_factorization: + out_l, out_k = out_k, out_l + shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size) + tucker = tucker and any(i != 1 for i in k_size) + if ( + decompose_both + and rank < max(shape[0][0], shape[1][0]) / 2 + and not full_matrix + ): + w1a = torch.empty(shape[0][0], rank) + w1b = torch.empty(rank, shape[1][0]) + else: + use_w1 = True + w1 = torch.empty(shape[0][0], shape[1][0]) # a*c, 1-mode + + if rank >= max(shape[0][1], shape[1][1]) / 2 or full_matrix: + use_w2 = True + w2 = torch.empty(shape[0][1], shape[1][1], *k_size) + elif tucker: + t2 = torch.empty(rank, rank, *shape[2:]) + w2a = torch.empty(rank, shape[0][1]) # b, 1-mode + w2b = torch.empty(rank, shape[1][1]) # d, 2-mode + else: # Conv2d not tucker + # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2] + w2a = torch.empty(shape[0][1], rank) + w2b = torch.empty(rank, shape[1][1], *shape[2:]) + # w1 ⊗ (w2a x w2b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2) + else: # Linear + shape = (out_dim, in_dim) + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + if unbalanced_factorization: + out_l, out_k = out_k, out_l + shape = ( + (out_l, out_k), + (in_m, in_n), + ) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d + # smaller part. weight scale + if decompose_both and rank < max(shape[0][0], shape[1][0]) / 2: + w1a = torch.empty(shape[0][0], rank) + w1b = torch.empty(rank, shape[1][0]) + else: + use_w1 = True + w1 = torch.empty(shape[0][0], shape[1][0]) # a*c, 1-mode + if rank < max(shape[0][1], shape[1][1]) / 2: + # bigger part. weight and LoRA. 
[b, dim] x [dim, d] + w2a = torch.empty(shape[0][1], rank) + w2b = torch.empty(rank, shape[1][1]) + # w1 ⊗ (w2a x w2b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd) + else: + use_w2 = True + w2 = torch.empty(shape[0][1], shape[1][1]) + + if use_w2: + torch.nn.init.constant_(w2, 1) + else: + if tucker: + torch.nn.init.kaiming_uniform_(t2, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(w2a, a=math.sqrt(5)) + torch.nn.init.constant_(w2b, 1) + + if use_w1: + torch.nn.init.kaiming_uniform_(w1, a=math.sqrt(5)) + else: + torch.nn.init.kaiming_uniform_(w1a, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(w1b, a=math.sqrt(5)) + + return w1, w1a, w1b, w2, w2a, w2b, t2 + + +def diff_weight(*weights, gamma=1.0): + """### diff_weight + + Args: + weights (tuple[torch.Tensor]): (w1, w1a, w1b, w2, w2a, w2b, t) + gamma (float, optional): scale factor, normally alpha/rank here + + Returns: + torch.Tensor: ΔW + """ + w1, w1a, w1b, w2, w2a, w2b, t = weights + if w1a is not None: + rank = w1a.shape[1] + elif w2a is not None: + rank = w2a.shape[1] + else: + rank = gamma + scale = gamma / rank + if w1 is None: + w1 = w1a @ w1b + if w2 is None: + if t is None: + r, o, *k = w2b.shape + w2 = w2a @ w2b.view(r, -1) + w2 = w2.view(-1, o, *k) + else: + w2 = rebuild_tucker(t, w2a, w2b) + return make_kron(w1, w2, scale) + + +def bypass_forward_diff(h, org_out, *weights, gamma=1.0, extra_args={}): + """### bypass_forward_diff + + Args: + weights (tuple[torch.Tensor]): (w1, w1a, w1b, w2, w2a, w2b, t) + gamma (float, optional): scale factor, normally alpha/rank here + extra_args (dict, optional): extra args for forward func, \ + e.g. padding, stride for Conv1/2/3d + + Returns: + torch.Tensor: output tensor + """ + w1, w1a, w1b, w2, w2a, w2b, t = weights + use_w1 = w1 is not None + use_w2 = w2 is not None + tucker = t is not None + dim = t.dim() if tucker else w2.dim() if w2 is not None else w2b.dim() + rank = w1b.size(0) if not use_w1 else w2b.size(0) if not use_w2 else gamma + scale = gamma / rank + is_conv = dim > 2 + op = FUNC_LIST[dim] + + if is_conv: + kw_dict = extra_args + else: + kw_dict = {} + + if use_w2: + ba = w2 + else: + a = w2b + b = w2a + + if t is not None: + a = a.view(*a.shape, *[1] * (dim - 2)) + b = b.view(*b.shape, *[1] * (dim - 2)) + elif is_conv: + b = b.view(*b.shape, *[1] * (dim - 2)) + + if use_w1: + c = w1 + else: + c = w1a @ w1b + uq = c.size(1) + + if is_conv: + # (b, uq), vq, ... + B, _, *rest = h.shape + h_in_group = h.reshape(B * uq, -1, *rest) + else: + # b, ..., uq, vq + h_in_group = h.reshape(*h.shape[:-1], uq, -1) + + if use_w2: + hb = op(h_in_group, ba, **kw_dict) + else: + if is_conv: + if tucker: + ha = op(h_in_group, a) + ht = op(ha, t, **kw_dict) + hb = op(ht, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + + if is_conv: + # (b, uq), vp, ..., f + # -> b, uq, vp, ..., f + # -> b, f, vp, ..., uq + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) + else: + # b, ..., uq, vq + # -> b, ..., vq, uq + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + if is_conv: + # b, f, vp, ..., up + # -> b, up, vp, ... 
,f + # -> b, c, ..., f + hc = hc.transpose(1, -1) + h = hc.reshape(B, -1, *hc.shape[3:]) + else: + # b, ..., vp, up + # -> b, ..., up, vp + # -> b, ..., c + hc = hc.transpose(-1, -2) + h = hc.reshape(*hc.shape[:-2], -1) + + return h * scale diff --git a/lycoris/kohya.py b/lycoris/kohya.py new file mode 100644 index 0000000000000000000000000000000000000000..3f31872b0fd6fe944df7fd30718618a4675e0290 --- /dev/null +++ b/lycoris/kohya.py @@ -0,0 +1,676 @@ +import os +import fnmatch +import re +import logging + +from typing import Any, List + +import torch + +from .utils import precalculate_safetensors_hashes +from .wrapper import LycorisNetwork, network_module_dict, deprecated_arg_dict +from .modules.locon import LoConModule +from .modules.loha import LohaModule +from .modules.ia3 import IA3Module +from .modules.lokr import LokrModule +from .modules.dylora import DyLoraModule +from .modules.glora import GLoRAModule +from .modules.norms import NormModule +from .modules.full import FullModule +from .modules.diag_oft import DiagOFTModule +from .modules.boft import ButterflyOFTModule +from .modules import make_module, get_module + +from .config import PRESET +from .utils.preset import read_preset +from .utils import str_bool +from .logging import logger + + +def create_network( + multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs +): + for key, value in list(kwargs.items()): + if key in deprecated_arg_dict: + logger.warning( + f"{key} is deprecated. Please use {deprecated_arg_dict[key]} instead.", + stacklevel=2, + ) + kwargs[deprecated_arg_dict[key]] = value + if network_dim is None: + network_dim = 4 # default + conv_dim = int(kwargs.get("conv_dim", network_dim) or network_dim) + conv_alpha = float(kwargs.get("conv_alpha", network_alpha) or network_alpha) + dropout = float(kwargs.get("dropout", 0.0) or 0.0) + rank_dropout = float(kwargs.get("rank_dropout", 0.0) or 0.0) + module_dropout = float(kwargs.get("module_dropout", 0.0) or 0.0) + algo = (kwargs.get("algo", "lora") or "lora").lower() + use_tucker = str_bool( + not kwargs.get("disable_conv_cp", True) + or kwargs.get("use_conv_cp", False) + or kwargs.get("use_cp", False) + or kwargs.get("use_tucker", False) + ) + use_scalar = str_bool(kwargs.get("use_scalar", False)) + block_size = int(kwargs.get("block_size", None) or 4) + train_norm = str_bool(kwargs.get("train_norm", False)) + constraint = float(kwargs.get("constraint", None) or 0) + rescaled = str_bool(kwargs.get("rescaled", False)) + weight_decompose = str_bool(kwargs.get("dora_wd", False)) + wd_on_output = str_bool(kwargs.get("wd_on_output", False)) + full_matrix = str_bool(kwargs.get("full_matrix", False)) + bypass_mode = str_bool(kwargs.get("bypass_mode", None)) + rs_lora = str_bool(kwargs.get("rs_lora", False)) + unbalanced_factorization = str_bool(kwargs.get("unbalanced_factorization", False)) + train_t5xxl = str_bool(kwargs.get("train_t5xxl", False)) + + if unbalanced_factorization: + logger.info("Unbalanced factorization for LoKr is enabled") + + if bypass_mode: + logger.info("Bypass mode is enabled") + + if weight_decompose: + logger.info("Weight decomposition is enabled") + + if full_matrix: + logger.info("Full matrix mode for LoKr is enabled") + + preset_str = kwargs.get("preset", "full") + if preset_str not in PRESET: + preset = read_preset(preset_str) + else: + preset = PRESET[preset_str] + assert preset is not None + LycorisNetworkKohya.apply_preset(preset) + + logger.info(f"Using rank adaptation algo: {algo}") + + if algo == "ia3" and preset_str != 
"ia3": + logger.warning("It is recommended to use preset ia3 for IA^3 algorithm") + + network = LycorisNetworkKohya( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + conv_lora_dim=conv_dim, + alpha=network_alpha, + conv_alpha=conv_alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + use_tucker=use_tucker, + use_scalar=use_scalar, + network_module=algo, + train_norm=train_norm, + decompose_both=kwargs.get("decompose_both", False), + factor=kwargs.get("factor", -1), + block_size=block_size, + constraint=constraint, + rescaled=rescaled, + weight_decompose=weight_decompose, + wd_on_out=wd_on_output, + full_matrix=full_matrix, + bypass_mode=bypass_mode, + rs_lora=rs_lora, + unbalanced_factorization=unbalanced_factorization, + train_t5xxl=train_t5xxl, + ) + + return network + + +def create_network_from_weights( + multiplier, + file, + vae, + text_encoder, + unet, + weights_sd=None, + for_inference=False, + **kwargs, +): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + unet_loras = {} + te_loras = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if lora_name.startswith(LycorisNetworkKohya.LORA_PREFIX_UNET): + unet_loras[lora_name] = None + elif lora_name.startswith(LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER): + te_loras[lora_name] = None + + for name, modules in unet.named_modules(): + lora_name = f"{LycorisNetworkKohya.LORA_PREFIX_UNET}_{name}".replace(".", "_") + if lora_name in unet_loras: + unet_loras[lora_name] = modules + + if isinstance(text_encoder, list): + text_encoders = text_encoder + use_index = True + else: + text_encoders = [text_encoder] + use_index = False + + for idx, te in enumerate(text_encoders): + if use_index: + prefix = f"{LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER}{idx+1}" + else: + prefix = LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER + for name, modules in te.named_modules(): + lora_name = f"{prefix}_{name}".replace(".", "_") + if lora_name in te_loras: + te_loras[lora_name] = modules + + original_level = logger.level + logger.setLevel(logging.ERROR) + network = LycorisNetworkKohya(text_encoder, unet) + network.unet_loras = [] + network.text_encoder_loras = [] + logger.setLevel(original_level) + + logger.info("Loading UNet Modules from state dict...") + for lora_name, orig_modules in unet_loras.items(): + if orig_modules is None: + continue + lyco_type, params = get_module(weights_sd, lora_name) + module = make_module(lyco_type, params, lora_name, orig_modules) + if module is not None: + network.unet_loras.append(module) + logger.info(f"{len(network.unet_loras)} Modules Loaded") + + logger.info("Loading TE Modules from state dict...") + for lora_name, orig_modules in te_loras.items(): + if orig_modules is None: + continue + lyco_type, params = get_module(weights_sd, lora_name) + module = make_module(lyco_type, params, lora_name, orig_modules) + if module is not None: + network.text_encoder_loras.append(module) + logger.info(f"{len(network.text_encoder_loras)} Modules Loaded") + + for lora in network.unet_loras + network.text_encoder_loras: + lora.multiplier = multiplier + + return network, weights_sd + + +class LycorisNetworkKohya(LycorisNetwork): + """ + LoRA + LoCon + """ + + # Ignore proj_in or proj_out, their channels is only a few. 
+ ENABLE_CONV = True + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D", + "HunYuanDiTBlock", + "DoubleStreamBlock", + "SingleStreamBlock", + "SingleDiTBlock", + "MMDoubleStreamBlock", #HunYuanVideo + "MMSingleStreamBlock", #HunYuanVideo + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = [ + "CLIPAttention", + "CLIPSdpaAttention", + "CLIPMLP", + "MT5Block", + "BertLayer", + ] + TEXT_ENCODER_TARGET_REPLACE_NAME = [] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + MODULE_ALGO_MAP = {} + NAME_ALGO_MAP = {} + USE_FNMATCH = False + + @classmethod + def apply_preset(cls, preset): + if "enable_conv" in preset: + cls.ENABLE_CONV = preset["enable_conv"] + if "unet_target_module" in preset: + cls.UNET_TARGET_REPLACE_MODULE = preset["unet_target_module"] + if "unet_target_name" in preset: + cls.UNET_TARGET_REPLACE_NAME = preset["unet_target_name"] + if "text_encoder_target_module" in preset: + cls.TEXT_ENCODER_TARGET_REPLACE_MODULE = preset[ + "text_encoder_target_module" + ] + if "text_encoder_target_name" in preset: + cls.TEXT_ENCODER_TARGET_REPLACE_NAME = preset["text_encoder_target_name"] + if "module_algo_map" in preset: + cls.MODULE_ALGO_MAP = preset["module_algo_map"] + if "name_algo_map" in preset: + cls.NAME_ALGO_MAP = preset["name_algo_map"] + if "use_fnmatch" in preset: + cls.USE_FNMATCH = preset["use_fnmatch"] + return cls + + def __init__( + self, + text_encoder, + unet, + multiplier=1.0, + lora_dim=4, + conv_lora_dim=4, + alpha=1, + conv_alpha=1, + use_tucker=False, + dropout=0, + rank_dropout=0, + module_dropout=0, + network_module: str = "locon", + norm_modules=NormModule, + train_norm=False, + train_t5xxl=False, + **kwargs, + ) -> None: + torch.nn.Module.__init__(self) + root_kwargs = kwargs + self.multiplier = multiplier + self.lora_dim = lora_dim + self.train_t5xxl = train_t5xxl + + if not self.ENABLE_CONV: + conv_lora_dim = 0 + + self.conv_lora_dim = int(conv_lora_dim) + if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim: + logger.info("Apply different lora dim for conv layer") + logger.info(f"Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}") + elif self.conv_lora_dim == 0: + logger.info("Disable conv layer") + + self.alpha = alpha + self.conv_alpha = float(conv_alpha) + if self.conv_lora_dim and self.alpha != self.conv_alpha: + logger.info("Apply different alpha value for conv layer") + logger.info(f"Conv alpha: {conv_alpha}, Linear alpha: {alpha}") + + if 1 >= dropout >= 0: + logger.info(f"Use Dropout value: {dropout}") + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.use_tucker = use_tucker + + def create_single_module( + lora_name: str, + module: torch.nn.Module, + algo_name, + dim=None, + alpha=None, + use_tucker=self.use_tucker, + **kwargs, + ): + for k, v in root_kwargs.items(): + if k in kwargs: + continue + kwargs[k] = v + + if train_norm and "Norm" in module.__class__.__name__: + return norm_modules( + lora_name, + module, + self.multiplier, + self.rank_dropout, + self.module_dropout, + **kwargs, + ) + lora = None + if isinstance(module, torch.nn.Linear) and lora_dim > 0: + dim = dim or lora_dim + alpha = alpha or self.alpha + elif isinstance( + module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d) + ): + k_size, *_ = module.kernel_size + if k_size == 1 and lora_dim > 0: + dim = dim or lora_dim + alpha 
= alpha or self.alpha + elif conv_lora_dim > 0 or dim: + dim = dim or conv_lora_dim + alpha = alpha or self.conv_alpha + else: + return None + else: + return None + lora = network_module_dict[algo_name]( + lora_name, + module, + self.multiplier, + dim, + alpha, + self.dropout, + self.rank_dropout, + self.module_dropout, + use_tucker, + **kwargs, + ) + return lora + + def create_modules_( + prefix: str, + root_module: torch.nn.Module, + algo, + configs={}, + ): + loras = {} + lora_names = [] + for name, module in root_module.named_modules(): + module_name = module.__class__.__name__ + if module_name in self.MODULE_ALGO_MAP and module is not root_module: + next_config = self.MODULE_ALGO_MAP[module_name] + next_algo = next_config.get("algo", algo) + new_loras, new_lora_names = create_modules_( + f"{prefix}_{name}", module, next_algo, next_config + ) + for lora_name, lora in zip(new_lora_names, new_loras): + if lora_name not in loras: + loras[lora_name] = lora + lora_names.append(lora_name) + continue + if name: + lora_name = prefix + "." + name + else: + lora_name = prefix + lora_name = lora_name.replace(".", "_") + if lora_name in loras: + continue + + lora = create_single_module(lora_name, module, algo, **configs) + if lora is not None: + loras[lora_name] = lora + lora_names.append(lora_name) + return [loras[lora_name] for lora_name in lora_names], lora_names + + # create module instances + def create_modules( + prefix, + root_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[], + ) -> List: + logger.info("Create LyCORIS Module") + loras = [] + next_config = {} + for name, module in root_module.named_modules(): + module_name = module.__class__.__name__ + if module_name in target_replace_modules and not any( + self.match_fn(t, name) for t in target_replace_names + ): + if module_name in self.MODULE_ALGO_MAP: + next_config = self.MODULE_ALGO_MAP[module_name] + algo = next_config.get("algo", network_module) + else: + algo = network_module + loras.extend( + create_modules_(f"{prefix}_{name}", module, algo, next_config)[ + 0 + ] + ) + next_config = {} + elif name in target_replace_names or any( + self.match_fn(t, name) for t in target_replace_names + ): + conf_from_name = self.find_conf_for_name(name) + if conf_from_name is not None: + next_config = conf_from_name + algo = next_config.get("algo", network_module) + elif module_name in self.MODULE_ALGO_MAP: + next_config = self.MODULE_ALGO_MAP[module_name] + algo = next_config.get("algo", network_module) + else: + algo = network_module + lora_name = prefix + "." 
+ name + lora_name = lora_name.replace(".", "_") + lora = create_single_module(lora_name, module, algo, **next_config) + next_config = {} + if lora is not None: + loras.append(lora) + return loras + + if network_module == GLoRAModule: + logger.info("GLoRA enabled, only train transformer") + # only train transformer (for GLoRA) + LycorisNetworkKohya.UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + ] + LycorisNetworkKohya.UNET_TARGET_REPLACE_NAME = [] + + self.text_encoder_loras = [] + if text_encoder: + if isinstance(text_encoder, list): + text_encoders = text_encoder + use_index = True + else: + text_encoders = [text_encoder] + use_index = False + + for i, te in enumerate(text_encoders): + self.text_encoder_loras.extend( + create_modules( + LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER + + (f"{i+1}" if use_index else ""), + te, + LycorisNetworkKohya.TEXT_ENCODER_TARGET_REPLACE_MODULE, + LycorisNetworkKohya.TEXT_ENCODER_TARGET_REPLACE_NAME, + ) + ) + logger.info( + f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules." + ) + + self.unet_loras = create_modules( + LycorisNetworkKohya.LORA_PREFIX_UNET, + unet, + LycorisNetworkKohya.UNET_TARGET_REPLACE_MODULE, + LycorisNetworkKohya.UNET_TARGET_REPLACE_NAME, + ) + logger.info(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.") + + algo_table = {} + for lora in self.text_encoder_loras + self.unet_loras: + algo_table[lora.__class__.__name__] = ( + algo_table.get(lora.__class__.__name__, 0) + 1 + ) + logger.info(f"module type table: {algo_table}") + + self.weights_sd = None + + self.loras = self.text_encoder_loras + self.unet_loras + # assertion + names = set() + for lora in self.loras: + assert ( + lora.lora_name not in names + ), f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def match_fn(self, pattern: str, name: str) -> bool: + if self.USE_FNMATCH: + return fnmatch.fnmatch(name, pattern) + return re.match(pattern, name) + + def find_conf_for_name( + self, + name: str, + ) -> dict[str, Any]: + if name in self.NAME_ALGO_MAP.keys(): + return self.NAME_ALGO_MAP[name] + + for key, value in self.NAME_ALGO_MAP.items(): + if self.match_fn(key, name): + return value + + return None + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + self.weights_sd = load_file(file) + else: + self.weights_sd = torch.load(file, map_location="cpu") + missing, unexpected = self.load_state_dict(self.weights_sd, strict=False) + state = {} + if missing: + state["missing keys"] = missing + if unexpected: + state["unexpected keys"] = unexpected + return state + + def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): + assert ( + apply_text_encoder is not None and apply_unet is not None + ), f"internal error: flag not set" + + if apply_text_encoder: + logger.info("enable LyCORIS for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LyCORIS for U-Net") + else: + self.unet_loras = [] + + self.loras = self.text_encoder_loras + self.unet_loras + + for lora in self.loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + if self.weights_sd: + # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) + info = self.load_state_dict(self.weights_sd, False) + logger.info(f"weights are loaded: {info}") + + # TODO refactor to common function with apply_to + def merge_to(self, 
text_encoder, unet, weights_sd, dtype, device): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LycorisNetworkKohya.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + self.loras = self.text_encoder_loras + self.unet_loras + super().merge_to(1) + + def apply_max_norm_regularization(self, max_norm_value, device): + key_scaled = 0 + norms = [] + for module in self.unet_loras + self.text_encoder_loras: + scaled, norm = module.apply_max_norm(max_norm_value, device) + if scaled is None: + continue + norms.append(norm) + key_scaled += scaled + + if key_scaled == 0: + return 0, 0, 0 + + return key_scaled, sum(norms) / len(norms), max(norms) + + def prepare_optimizer_params(self, text_encoder_lr=None, unet_lr: float = 1e-4, learning_rate=None): + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + self.requires_grad_(True) + all_params = [] + lr_descriptions = [] + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + lr_descriptions.append("text_encoder") + + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + lr_descriptions.append("unet") + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + #def on_step_start(self): + # pass + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash = precalculate_safetensors_hashes(state_dict) + metadata["sshs_model_hash"] = model_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) diff --git a/lycoris/logging.py b/lycoris/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..6c51ccca3ab40b9040f2900c9dc7cf3cde5c1d2a --- /dev/null +++ b/lycoris/logging.py @@ -0,0 +1,52 @@ +import sys +import copy +import logging +from functools import cache + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = 
f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +logger = logging.getLogger("LyCORIS") +logger.propagate = False +logger.setLevel(logging.INFO) + + +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter( + "%(asctime)s|[%(name)s]-%(levelname)s: %(message)s", "%Y-%m-%d %H:%M:%S" + ) + ) + logger.addHandler(handler) + + +@cache +def info_once(msg): + logger.info(msg) + + +@cache +def warning_once(msg): + logger.warning(msg) + + +@cache +def error_once(msg): + logger.error(msg) diff --git a/lycoris/modules/__init__.py b/lycoris/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0c8bde2ef50c0f9499e92198135e64d772c349 --- /dev/null +++ b/lycoris/modules/__init__.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from .locon import LoConModule +from .loha import LohaModule +from .lokr import LokrModule +from .full import FullModule +from .norms import NormModule +from .diag_oft import DiagOFTModule +from .boft import ButterflyOFTModule +from .glora import GLoRAModule +from .dylora import DyLoraModule +from .ia3 import IA3Module + +from ..functional.general import factorization + + +MODULE_LIST = [ + LoConModule, + LohaModule, + IA3Module, + LokrModule, + FullModule, + NormModule, + DiagOFTModule, + ButterflyOFTModule, + GLoRAModule, + DyLoraModule, +] + + +def get_module(lyco_state_dict, lora_name): + for module in MODULE_LIST: + if module.algo_check(lyco_state_dict, lora_name): + return module, tuple(module.extract_state_dict(lyco_state_dict, lora_name)) + return None, None + + +@torch.no_grad() +def make_module(lyco_type: LycorisBaseModule, params, lora_name, orig_module): + try: + module = lyco_type.make_module_from_state_dict(lora_name, orig_module, *params) + except NotImplementedError: + module = None + return module diff --git a/lycoris/modules/base.py b/lycoris/modules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e8dec35cadde64848e9f3dad0929550268352561 --- /dev/null +++ b/lycoris/modules/base.py @@ -0,0 +1,315 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.parametrize as parametrize + +from ..utils.quant import QuantLinears, log_bypass, log_suspect + + +class ModuleCustomSD(nn.Module): + def __init__(self): + super().__init__() + self._register_load_state_dict_pre_hook(self.load_weight_prehook) + self.register_load_state_dict_post_hook(self.load_weight_hook) + + def load_weight_prehook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + pass + + def load_weight_hook(self, module, incompatible_keys): + pass + + def custom_state_dict(self): + return None + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + # TODO: Remove `args` and the parsing logic when BC allows. 
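+ # Positional `args`, if provided, follow the legacy nn.Module.state_dict
+ # signature (destination, prefix, keep_vars) and are unpacked below.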
+ if len(args) > 0: + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == "": + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + # DeprecationWarning is ignored by default + + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + + if (custom_sd := self.custom_state_dict()) is not None: + for k, v in custom_sd.items(): + destination[f"{prefix}{k}"] = v + return destination + else: + return super().state_dict( + *args, destination=destination, prefix=prefix, keep_vars=keep_vars + ) + + +class LycorisBaseModule(ModuleCustomSD): + name: str + dtype_tensor: torch.Tensor + support_module = {} + weight_list = [] + weight_list_det = [] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + rank_dropout_scale=False, + bypass_mode=None, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + self.not_supported = False + + self.module = type(org_module) + if isinstance(org_module, nn.Linear): + self.module_type = "linear" + self.shape = (org_module.out_features, org_module.in_features) + self.op = F.linear + self.dim = org_module.out_features + self.kw_dict = {} + elif isinstance(org_module, nn.Conv1d): + self.module_type = "conv1d" + self.shape = ( + org_module.out_channels, + org_module.in_channels, + *org_module.kernel_size, + ) + self.op = F.conv1d + self.dim = org_module.out_channels + self.kw_dict = { + "stride": org_module.stride, + "padding": org_module.padding, + "dilation": org_module.dilation, + "groups": org_module.groups, + } + elif isinstance(org_module, nn.Conv2d): + self.module_type = "conv2d" + self.shape = ( + org_module.out_channels, + org_module.in_channels, + *org_module.kernel_size, + ) + self.op = F.conv2d + self.dim = org_module.out_channels + self.kw_dict = { + "stride": org_module.stride, + "padding": org_module.padding, + "dilation": org_module.dilation, + "groups": org_module.groups, + } + elif isinstance(org_module, nn.Conv3d): + self.module_type = "conv3d" + self.shape = ( + org_module.out_channels, + org_module.in_channels, + *org_module.kernel_size, + ) + self.op = F.conv3d + self.dim = org_module.out_channels + self.kw_dict = { + "stride": org_module.stride, + "padding": org_module.padding, + "dilation": org_module.dilation, + "groups": org_module.groups, + } + elif isinstance(org_module, nn.LayerNorm): + self.module_type = "layernorm" + self.shape = tuple(org_module.normalized_shape) + self.op = F.layer_norm + self.dim = org_module.normalized_shape[0] + self.kw_dict = { + "normalized_shape": org_module.normalized_shape, + "eps": org_module.eps, + } + elif isinstance(org_module, nn.GroupNorm): + self.module_type = "groupnorm" + self.shape = (org_module.num_channels,) + self.op = F.group_norm + self.group_num = org_module.num_groups + self.dim = org_module.num_channels + self.kw_dict = {"num_groups": org_module.num_groups, "eps": org_module.eps} + else: + self.not_supported = True + self.module_type = "unknown" + + self.register_buffer("dtype_tensor", torch.tensor(0.0), persistent=False) + + self.is_quant = False + if isinstance(org_module, QuantLinears): + if not bypass_mode: + log_bypass() + self.is_quant = True + bypass_mode = True + if ( + isinstance(org_module, 
nn.Linear) + and org_module.__class__.__name__ != "Linear" + ): + if bypass_mode is None: + log_suspect() + bypass_mode = True + if bypass_mode == True: + self.is_quant = True + self.bypass_mode = bypass_mode + self.dropout = dropout + self.rank_dropout = rank_dropout + self.rank_dropout_scale = rank_dropout_scale + self.module_dropout = module_dropout + + ## Dropout things + # Since LoKr/LoHa/OFT/BOFT are hard to follow the rank_dropout definition from kohya + # We redefine the dropout procedure here. + # g(x) = WX + drop(Brank_drop(AX)) for LoCon(lora), bypass + # g(x) = WX + drop(ΔWX) for any algo except LoCon(lora), bypass + # g(x) = (W + Brank_drop(A))X for LoCon(lora), rebuid + # g(x) = (W + rank_drop(ΔW))X for any algo except LoCon(lora), rebuild + self.drop = nn.Identity() if dropout == 0 else nn.Dropout(dropout) + self.rank_drop = ( + nn.Identity() if rank_dropout == 0 else nn.Dropout(rank_dropout) + ) + + self.multiplier = multiplier + self.org_forward = org_module.forward + self.org_module = [org_module] + + @classmethod + def parametrize(cls, org_module, attr, *args, **kwargs): + from .full import FullModule + + if cls is FullModule: + raise RuntimeError("FullModule cannot be used for parametrize.") + target_param = getattr(org_module, attr) + kwargs["bypass_mode"] = False + if target_param.dim() == 2: + proxy_module = nn.Linear( + target_param.shape[0], target_param.shape[1], bias=False + ) + proxy_module.weight = target_param + elif target_param.dim() > 2: + module_type = [ + None, + None, + None, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + None, + None, + ][target_param.dim()] + proxy_module = module_type( + target_param.shape[0], + target_param.shape[1], + *target_param.shape[2:], + bias=False, + ) + proxy_module.weight = target_param + module_obj = cls("", proxy_module, *args, **kwargs) + module_obj.forward = module_obj.parametrize_forward + module_obj.to(target_param) + parametrize.register_parametrization(org_module, attr, module_obj) + return module_obj + + @classmethod + def algo_check(cls, state_dict, lora_name): + return any(f"{lora_name}.{k}" in state_dict for k in cls.weight_list_det) + + @classmethod + def extract_state_dict(cls, state_dict, lora_name): + return [state_dict.get(f"{lora_name}.{k}", None) for k in cls.weight_list] + + @classmethod + def make_module_from_state_dict(cls, lora_name, orig_module, *weights): + raise NotImplementedError + + @property + def dtype(self): + return self.dtype_tensor.dtype + + @property + def device(self): + return self.dtype_tensor.device + + @property + def org_weight(self): + return self.org_module[0].weight + + @org_weight.setter + def org_weight(self, value): + self.org_module[0].weight.data.copy_(value) + + def apply_to(self, **kwargs): + if self.not_supported: + return + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def restore(self): + if self.not_supported: + return + self.org_module[0].forward = self.org_forward + + def merge_to(self, multiplier=1.0): + if self.not_supported: + return + self_device = next(self.parameters()).device + self_dtype = next(self.parameters()).dtype + self.to(self.org_weight) + weight, bias = self.get_merged_weight( + multiplier, self.org_weight.shape, self.org_weight.device + ) + self.org_weight = weight.to(self.org_weight) + if bias is not None: + bias = bias.to(self.org_weight) + if self.org_module[0].bias is not None: + self.org_module[0].bias.data.copy_(bias) + else: + self.org_module[0].bias = nn.Parameter(bias) + self.to(self_device, 
self_dtype) + + def get_diff_weight(self, multiplier=1.0, shape=None, device=None): + raise NotImplementedError + + def get_merged_weight(self, multiplier=1.0, shape=None, device=None): + raise NotImplementedError + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + return None, None + + def bypass_forward_diff(self, x, scale=1): + raise NotImplementedError + + def bypass_forward(self, x, scale=1): + raise NotImplementedError + + def parametrize_forward(self, x: torch.Tensor, *args, **kwargs): + return self.get_merged_weight( + multiplier=self.multiplier, shape=x.shape, device=x.device + )[0].to(x.dtype) + + def forward(self, *args, **kwargs): + raise NotImplementedError diff --git a/lycoris/modules/boft.py b/lycoris/modules/boft.py new file mode 100644 index 0000000000000000000000000000000000000000..c547527a5c711934d3fb1cc427a15fd19caacb16 --- /dev/null +++ b/lycoris/modules/boft.py @@ -0,0 +1,255 @@ +from functools import cache +from math import log2 + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +from .base import LycorisBaseModule +from ..functional import power2factorization +from ..logging import logger + + +@cache +def log_butterfly_factorize(dim, factor, result): + logger.info( + f"Use BOFT({int(log2(result[1]))}, {result[0]//2})" + f" (equivalent to factor={result[0]}) " + f"for {dim=} and {factor=}" + ) + + +def butterfly_factor(dimension: int, factor: int = -1) -> tuple[int, int]: + m, n = power2factorization(dimension, factor) + + if n == 0: + raise ValueError( + f"It is impossible to decompose {dimension} with factor {factor} under BOFT constraints." + ) + + log_butterfly_factorize(dimension, factor, (m, n)) + return m, n + + +class ButterflyOFTModule(LycorisBaseModule): + name = "boft" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "oft_blocks", + "rescale", + "alpha", + ] + weight_list_det = ["oft_blocks"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + constraint=0, + rescaled=False, + bypass_mode=None, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in BOFT algo.") + + out_dim = self.dim + b, m_exp = butterfly_factor(out_dim, lora_dim) + self.block_size = b + self.block_num = m_exp + # BOFT(m, b) + self.boft_b = b + self.boft_m = sum(int(i) for i in f"{m_exp-1:b}") + 1 + # block_num > block_size + self.rescaled = rescaled + self.constraint = constraint * out_dim + self.register_buffer("alpha", torch.tensor(constraint)) + self.oft_blocks = nn.Parameter( + torch.zeros(self.boft_m, self.block_num, self.block_size, self.block_size) + ) + if rescaled: + self.rescale = nn.Parameter( + torch.ones(out_dim, *(1 for _ in range(org_module.weight.dim() - 1))) + ) + + @classmethod + def algo_check(cls, state_dict, lora_name): + if f"{lora_name}.oft_blocks" in state_dict: + oft_blocks = state_dict[f"{lora_name}.oft_blocks"] + if oft_blocks.ndim == 4: + return True + return False + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, oft_blocks, rescale, alpha + ): + m, n, s, _ = oft_blocks.shape + module = cls( + lora_name, + 
orig_module, + 1, + lora_dim=s, + constraint=float(alpha), + rescaled=rescale is not None, + ) + module.oft_blocks.copy_(oft_blocks) + if rescale is not None: + module.rescale.copy_(rescale) + return module + + @property + def I(self): + return torch.eye(self.block_size, device=self.device) + + def get_r(self): + I = self.I + # for Q = -Q^T + q = self.oft_blocks - self.oft_blocks.transpose(-1, -2) + normed_q = q + # Diag OFT style constrain + if self.constraint > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r + + def make_weight(self, scale=1, device=None, diff=False): + m = self.boft_m + b = self.boft_b + r_b = b // 2 + r = self.get_r() + inp = org = self.org_weight.to(device, dtype=r.dtype) + + for i in range(m): + bi = r[i] # b_num, b_size, b_size + g = 2 + k = 2**i * r_b + if scale != 1: + bi = bi * scale + (1 - scale) * self.I + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, b)) + ) + inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + if self.rescaled: + inp = inp * self.rescale + + if diff: + inp = inp - org + + return inp.to(self.oft_blocks.dtype) + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(scale=multiplier, device=device, diff=True) + if shape is not None: + diff = diff.view(shape) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(scale=multiplier, device=device) + if shape is not None: + diff = diff.view(shape) + return diff, None + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = self.oft_blocks.to(device).norm() + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired / norm + + scaled = norm != desired + if scaled: + self.oft_blocks *= ratio + + return scaled, orig_norm * ratio + + def _bypass_forward(self, x, scale=1, diff=False): + m = self.boft_m + b = self.boft_b + r_b = b // 2 + r = self.get_r() + inp = org = self.org_forward(x) + if self.op in {F.conv2d, F.conv1d, F.conv3d}: + inp = inp.transpose(1, -1) + + for i in range(m): + bi = r[i] # b_num, b_size, b_size + g = 2 + k = 2**i * r_b + if scale != 1: + bi = bi * scale + (1 - scale) * self.I + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, b)) + ) + inp = torch.einsum("b i j, ... b j -> ... 
b i", bi, inp) + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + if self.rescaled: + inp = inp * self.rescale.transpose(0, -1) + + if self.op in {F.conv2d, F.conv1d, F.conv3d}: + inp = inp.transpose(1, -1) + + if diff: + inp = inp - org + return inp + + def bypass_forward_diff(self, x, scale=1): + return self._bypass_forward(x, scale, diff=True) + + def bypass_forward(self, x, scale=1): + return self._bypass_forward(x, scale, diff=False) + + def forward(self, x, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + scale = self.multiplier + + if self.bypass_mode: + return self.bypass_forward(x, scale) + else: + w = self.make_weight(scale, x.device) + kw_dict = self.kw_dict | {"weight": w, "bias": self.org_module[0].bias} + return self.op(x, **kw_dict) diff --git a/lycoris/modules/diag_oft.py b/lycoris/modules/diag_oft.py new file mode 100644 index 0000000000000000000000000000000000000000..3805995ac7c8be57ce5c3036392c441d3cb670dc --- /dev/null +++ b/lycoris/modules/diag_oft.py @@ -0,0 +1,217 @@ +from functools import cache + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import LycorisBaseModule +from ..functional import factorization +from ..logging import logger + + +@cache +def log_oft_factorize(dim, factor, num, bdim): + logger.info( + f"Use OFT(block num: {num}, block dim: {bdim})" + f" (equivalent to lora_dim={num}) " + f"for {dim=} and lora_dim={factor=}" + ) + + +class DiagOFTModule(LycorisBaseModule): + name = "diag-oft" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "oft_blocks", + "rescale", + "alpha", + ] + weight_list_det = ["oft_blocks"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + constraint=0, + rescaled=False, + bypass_mode=None, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in Diag-OFT algo.") + + out_dim = self.dim + self.block_size, self.block_num = factorization(out_dim, lora_dim) + # block_num > block_size + self.rescaled = rescaled + self.constraint = constraint * out_dim + self.register_buffer("alpha", torch.tensor(constraint)) + self.oft_blocks = nn.Parameter( + torch.zeros(self.block_num, self.block_size, self.block_size) + ) + if rescaled: + self.rescale = nn.Parameter( + torch.ones(out_dim, *(1 for _ in range(org_module.weight.dim() - 1))) + ) + + log_oft_factorize( + dim=out_dim, + factor=lora_dim, + num=self.block_num, + bdim=self.block_size, + ) + + @classmethod + def algo_check(cls, state_dict, lora_name): + if f"{lora_name}.oft_blocks" in state_dict: + oft_blocks = state_dict[f"{lora_name}.oft_blocks"] + if oft_blocks.ndim == 3: + return True + return False + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, oft_blocks, rescale, alpha + ): + n, s, _ = oft_blocks.shape + module = cls( + lora_name, + orig_module, + 1, + lora_dim=s, + constraint=float(alpha), + rescaled=rescale is not None, + ) + module.oft_blocks.copy_(oft_blocks) + if rescale is not None: + module.rescale.copy_(rescale) + return module + + @property + 
def I(self): + return torch.eye(self.block_size, device=self.device) + + def get_r(self): + I = self.I + # for Q = -Q^T + q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + normed_q = q + if self.constraint > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r + + def make_weight(self, scale=1, device=None, diff=False): + r = self.get_r() + _, *shape = self.org_weight.shape + org_weight = self.org_weight.to(device, dtype=r.dtype) + org_weight = org_weight.view(self.block_num, self.block_size, *shape) + # Init R=0, so add I on it to ensure the output of step0 is original model output + weight = torch.einsum( + "k n m, k n ... -> k m ...", + self.rank_drop(r * scale) - scale * self.I + (0 if diff else self.I), + org_weight, + ).view(-1, *shape) + if self.rescaled: + weight = self.rescale * weight + if diff: + weight = weight + (self.rescale - 1) * org_weight + return weight.to(self.oft_blocks.dtype) + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(scale=multiplier, device=device, diff=True) + if shape is not None: + diff = diff.view(shape) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(scale=multiplier, device=device) + if shape is not None: + diff = diff.view(shape) + return diff, None + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = self.oft_blocks.to(device).norm() + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired / norm + + scaled = norm != desired + if scaled: + self.oft_blocks *= ratio + + return scaled, orig_norm * ratio + + def _bypass_forward(self, x, scale=1, diff=False): + r = self.get_r() + org_out = self.org_forward(x) + if self.op in {F.conv2d, F.conv1d, F.conv3d}: + org_out = org_out.transpose(1, -1) + *shape, _ = org_out.shape + org_out = org_out.view(*shape, self.block_num, self.block_size) + mask = neg_mask = 1 + if self.dropout != 0 and self.training: + mask = torch.ones_like(org_out) + mask = self.drop(mask) + neg_mask = torch.max(mask) - mask + oft_out = torch.einsum( + "k n m, ... k n -> ... 
k m", + r * scale * mask + (1 - scale) * self.I * neg_mask, + org_out, + ) + if diff: + out = out - org_out + out = oft_out.view(*shape, -1) + if self.rescaled: + out = self.rescale.transpose(-1, 0) * out + out = out + (self.rescale.transpose(-1, 0) - 1) * org_out + if self.op in {F.conv2d, F.conv1d, F.conv3d}: + out = out.transpose(1, -1) + return out + + def bypass_forward_diff(self, x, scale=1): + return self._bypass_forward(x, scale, diff=True) + + def bypass_forward(self, x, scale=1): + return self._bypass_forward(x, scale, diff=False) + + def forward(self, x: torch.Tensor, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + scale = self.multiplier + + if self.bypass_mode: + return self.bypass_forward(x, scale) + else: + w = self.make_weight(scale, x.device) + kw_dict = self.kw_dict | {"weight": w, "bias": self.org_module[0].bias} + return self.op(x, **kw_dict) diff --git a/lycoris/modules/dylora.py b/lycoris/modules/dylora.py new file mode 100644 index 0000000000000000000000000000000000000000..4138142e819f6a7146baa73bf36efec22926b166 --- /dev/null +++ b/lycoris/modules/dylora.py @@ -0,0 +1,156 @@ +import math +import random + +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from ..utils import product + + +class DyLoraModule(LycorisBaseModule): + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + block_size=4, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + bypass_mode=None, + rs_lora=False, + train_on_input=False, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in IA^3 algo.") + assert lora_dim % block_size == 0, "lora_dim must be a multiple of block_size" + self.block_count = lora_dim // block_size + self.block_size = block_size + + shape = ( + self.shape[0], + product(self.shape[1:]), + ) + + self.lora_dim = lora_dim + self.up_list = nn.ParameterList( + [torch.empty(shape[0], self.block_size) for i in range(self.block_count)] + ) + self.down_list = nn.ParameterList( + [torch.empty(self.block_size, shape[1]) for i in range(self.block_count)] + ) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # Need more experiences on init method + for v in self.down_list: + torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5)) + for v in self.up_list: + torch.nn.init.zeros_(v) + + def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): + return + + def custom_state_dict(self): + destination = {} + destination["alpha"] = self.alpha + destination["lora_up.weight"] = nn.Parameter( + torch.concat(list(self.up_list), dim=1) + ) + destination["lora_down.weight"] = nn.Parameter( + torch.concat(list(self.down_list)).reshape( + self.lora_dim, -1, *self.shape[2:] + ) + ) + return destination + + def get_weight(self, rank): + b = 
math.ceil(rank / self.block_size) + down = torch.concat( + list(i.data for i in self.down_list[:b]) + list(self.down_list[b : (b + 1)]) + ) + up = torch.concat( + list(i.data for i in self.up_list[:b]) + list(self.up_list[b : (b + 1)]), + dim=1, + ) + return down, up, self.alpha / (b + 1) + + def get_random_rank_weight(self): + b = random.randint(0, self.block_count - 1) + return self.get_weight(b * self.block_size) + + def get_diff_weight(self, multiplier=1, shape=None, device=None, rank=None): + if rank is None: + down, up, scale = self.get_random_rank_weight() + else: + down, up, scale = self.get_weight(rank) + w = up @ (down * (scale * multiplier)) + if device is not None: + w = w.to(device) + if shape is not None: + w = w.view(shape) + else: + w = w.view(self.shape) + return w, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None, rank=None): + diff, _ = self.get_diff_weight(multiplier, shape, device, rank) + return diff + self.org_weight, None + + def bypass_forward_diff(self, x, scale=1, rank=None): + if rank is None: + down, up, gamma = self.get_random_rank_weight() + else: + down, up, scale = self.get_weight(rank) + down = down.view(self.lora_dim, -1, *self.shape[2:]) + up = up.view(-1, self.lora_dim, *(1 for _ in self.shape[2:])) + scale = scale * gamma + return self.op(self.op(x, down, **self.kw_dict), up) + + def bypass_forward(self, x, scale=1, rank=None): + return self.org_forward(x) + self.bypass_forward_diff(x, scale, rank) + + def forward(self, x, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + if self.bypass_mode: + return self.bypass_forward(x, self.multiplier) + else: + weight = self.get_merged_weight(multiplier=self.multiplier)[0] + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) diff --git a/lycoris/modules/full.py b/lycoris/modules/full.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b0dc27fc47dec6759d086ba8b8e4726acfe875 --- /dev/null +++ b/lycoris/modules/full.py @@ -0,0 +1,214 @@ +from functools import cache + +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from ..logging import logger + + +@cache +def log_bypass_override(): + return logger.warning( + "Automatic Bypass-Mode detected in algo=full, " + "override with bypass_mode=False since algo=full not support bypass mode. " + "If you are using quantized model which require bypass mode, please don't use algo=full. 
" + ) + + +class FullModule(LycorisBaseModule): + name = "full" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = ["diff", "diff_b"] + weight_list_det = ["diff"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + bypass_mode=None, + **kwargs, + ): + org_bypass = bypass_mode + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if bypass_mode and org_bypass is None: + self.bypass_mode = False + log_bypass_override() + + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in Full algo.") + + if self.is_quant: + raise ValueError( + "Quant Linear is not supported and meaningless in Full algo." + ) + + if self.bypass_mode: + raise ValueError("bypass mode is not supported in Full algo.") + + self.weight = nn.Parameter(torch.zeros_like(org_module.weight)) + if org_module.bias is not None: + self.bias = nn.Parameter(torch.zeros_like(org_module.bias)) + else: + self.bias = None + self.is_diff = True + self._org_weight = [self.org_module[0].weight.data.cpu().clone()] + if self.org_module[0].bias is not None: + self.org_bias = [self.org_module[0].bias.data.cpu().clone()] + else: + self.org_bias = None + + @classmethod + def make_module_from_state_dict(cls, lora_name, orig_module, diff, diff_b): + module = cls( + lora_name, + orig_module, + 1, + ) + module.weight.copy_(diff) + if diff_b is not None: + if orig_module.bias is not None: + module.bias.copy_(diff_b) + else: + module.bias = nn.Parameter(diff_b) + module.is_diff = True + return module + + @property + def org_weight(self): + return self._org_weight[0] + + @org_weight.setter + def org_weight(self, value): + self.org_module[0].weight.data.copy_(value) + + def apply_to(self, **kwargs): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + self.weight.data.add_(self.org_module[0].weight.data) + self._org_weight = [self.org_module[0].weight.data.cpu().clone()] + delattr(self.org_module[0], "weight") + if self.org_module[0].bias is not None: + self.bias.data.add_(self.org_module[0].bias.data) + self.org_bias = [self.org_module[0].bias.data.cpu().clone()] + delattr(self.org_module[0], "bias") + else: + self.org_bias = None + self.is_diff = False + + def restore(self): + self.org_module[0].forward = self.org_forward + self.org_module[0].weight = nn.Parameter(self._org_weight[0]) + if self.org_bias is not None: + self.org_module[0].bias = nn.Parameter(self.org_bias[0]) + + def custom_state_dict(self): + sd = {"diff": self.weight.data.cpu() - self._org_weight[0]} + if self.bias is not None: + sd["diff_b"] = self.bias.data.cpu() - self.org_bias[0] + return sd + + def load_weight_prehook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + diff_weight = state_dict.pop(f"{prefix}diff") + state_dict[f"{prefix}weight"] = diff_weight + self.weight.data.to(diff_weight) + if f"{prefix}diff_b" in state_dict: + diff_bias = state_dict.pop(f"{prefix}diff_b") + state_dict[f"{prefix}bias"] = diff_bias + self.bias.data.to(diff_bias) + + def make_weight(self, scale=1, device=None): + drop = ( + torch.rand(self.dim, device=device) > self.rank_dropout + if self.rank_dropout and self.training + else 1 + ) 
+ if drop != 1 or scale != 1 or self.is_diff: + diff_w, diff_b = self.get_diff_weight(scale, device=device) + weight = self.org_weight + diff_w * drop + if self.org_bias is not None: + bias = self.org_bias + diff_b * drop + else: + bias = None + else: + weight = self.weight + bias = self.bias + return weight, bias + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + if self.is_diff: + diff_b = None + if self.bias is not None: + diff_b = self.bias * multiplier + return self.weight * multiplier, diff_b + org_weight = self.org_module[0].weight.to(device, dtype=self.weight.dtype) + diff = self.weight.to(device) - org_weight + diff_b = None + if shape: + diff = diff.view(shape) + if self.bias is not None: + org_bias = self.org_module[0].bias.to(device, dtype=self.bias.dtype) + diff_b = self.bias.to(device) - org_bias + if device is not None: + diff = diff.to(device) + if self.bias is not None: + diff_b = diff_b.to(device) + if multiplier != 1: + diff = diff * multiplier + if diff_b is not None: + diff_b = diff_b * multiplier + return diff * multiplier, diff_b + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + weight, bias = self.make_weight(multiplier, device) + if shape is not None: + weight = weight.view(shape) + if bias is not None: + bias = bias.view(shape[0]) + return weight, bias + + def forward(self, x: torch.Tensor, *args, **kwargs): + if ( + self.module_dropout + and self.training + and torch.rand(1) < self.module_dropout + ): + original = True + else: + original = False + if original: + return self.org_forward(x) + scale = self.multiplier + weight, bias = self.make_weight(scale, x.device) + kw_dict = self.kw_dict | {"weight": weight, "bias": bias} + return self.op(x, **kw_dict) diff --git a/lycoris/modules/glora.py b/lycoris/modules/glora.py new file mode 100644 index 0000000000000000000000000000000000000000..45a47d8de6749dd333594d4e008c47e42d873eaf --- /dev/null +++ b/lycoris/modules/glora.py @@ -0,0 +1,262 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import LycorisBaseModule +from ..functional import tucker_weight_from_conv + + +class GLoRAModule(LycorisBaseModule): + name = "glora" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "a1.weight", + "a2.weight", + "b1.weight", + "b2.weight", + "bm.weight", + "alpha", + ] + weight_list_det = ["a1.weight"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + bypass_mode=None, + rs_lora=False, + **kwargs, + ): + """ + f(x) = WX + WAX + BX, where A and B are low-rank matrices + bypass_forward(x) = W(X+A(X)) + B(X) + bypass_forward_diff(x) = W(A(X)) + B(X) + get_merged_weight() = W + WA + B + get_diff_weight() = WA + B + """ + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in GLoRA algo.") + self.lora_dim = lora_dim + self.tucker = False + self.rs_lora = rs_lora + + if self.module_type.startswith("conv"): + self.isconv = True + # For general LoCon + in_dim = org_module.in_channels + k_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + out_dim = 
org_module.out_channels + use_tucker = use_tucker and all(i == 1 for i in k_size) + self.down_op = self.op + self.up_op = self.op + + # A + self.a2 = self.module(in_dim, lora_dim, 1, bias=False) + self.a1 = self.module(lora_dim, in_dim, 1, bias=False) + + # B + if use_tucker and any(i != 1 for i in k_size): + self.b2 = self.module(in_dim, lora_dim, 1, bias=False) + self.bm = self.module( + lora_dim, lora_dim, k_size, stride, padding, bias=False + ) + self.tucker = True + else: + self.b2 = self.module( + in_dim, lora_dim, k_size, stride, padding, bias=False + ) + self.b1 = self.module(lora_dim, out_dim, 1, bias=False) + else: + self.isconv = False + self.down_op = F.linear + self.up_op = F.linear + in_dim = org_module.in_features + out_dim = org_module.out_features + self.a2 = nn.Linear(in_dim, lora_dim, bias=False) + self.a1 = nn.Linear(lora_dim, in_dim, bias=False) + self.b2 = nn.Linear(in_dim, lora_dim, bias=False) + self.b1 = nn.Linear(lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + + r_factor = lora_dim + if self.rs_lora: + r_factor = math.sqrt(r_factor) + + self.scale = alpha / r_factor + + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + if use_scalar: + self.scalar = nn.Parameter(torch.tensor(0.0)) + else: + self.register_buffer("scalar", torch.tensor(1.0), persistent=False) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.a1.weight, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.b1.weight, a=math.sqrt(5)) + if use_scalar: + torch.nn.init.kaiming_uniform_(self.a2.weight, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.b2.weight, a=math.sqrt(5)) + else: + torch.nn.init.zeros_(self.a2.weight) + torch.nn.init.zeros_(self.b2.weight) + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, a1, a2, b1, b2, bm, alpha + ): + module = cls( + lora_name, + orig_module, + 1, + a2.size(0), + float(alpha), + use_tucker=bm is not None, + ) + module.a1.weight.data.copy_(a1) + module.a2.weight.data.copy_(a2) + module.b1.weight.data.copy_(b1) + module.b2.weight.data.copy_(b2) + if bm is not None: + module.bm.weight.data.copy_(bm) + return module + + def custom_state_dict(self): + destination = {} + destination["alpha"] = self.alpha + destination["a1.weight"] = self.a1.weight + destination["a2.weight"] = self.a2.weight * self.scalar + destination["b1.weight"] = self.b1.weight + destination["b2.weight"] = self.b2.weight * self.scalar + if self.tucker: + destination["bm.weight"] = self.bm.weight + return destination + + def load_weight_hook(self, module: nn.Module, incompatible_keys): + missing_keys = incompatible_keys.missing_keys + for key in missing_keys: + if "scalar" in key: + del missing_keys[missing_keys.index(key)] + if isinstance(self.scalar, nn.Parameter): + self.scalar.data.copy_(torch.ones_like(self.scalar)) + elif getattr(self, "scalar", None) is not None: + self.scalar.copy_(torch.ones_like(self.scalar)) + else: + self.register_buffer( + "scalar", torch.ones_like(self.scalar), persistent=False + ) + + def make_weight(self, device=None): + wa1 = self.a1.weight.view(self.a1.weight.size(0), -1) + wa2 = self.a2.weight.view(self.a2.weight.size(0), -1) + orig = self.org_weight + + if self.tucker: + wb = tucker_weight_from_conv(self.b1.weight, self.b2.weight, self.bm.weight) + else: + wb1 = self.b1.weight.view(self.b1.weight.size(0), -1) + wb2 = 
self.b2.weight.view(self.b2.weight.size(0), -1) + wb = wb1 @ wb2 + wb = wb.view(*orig.shape) + if orig.dim() > 2: + w_wa1 = torch.einsum("o i ..., i j -> o j ...", orig, wa1) + w_wa2 = torch.einsum("o i ..., i j -> o j ...", w_wa1, wa2) + else: + w_wa2 = (orig @ wa1) @ wa2 + return (wb + w_wa2) * self.scale * self.scalar + + def get_diff_weight(self, multiplier=1.0, shape=None, device=None): + weight = self.make_weight(device) * multiplier + if shape is not None: + weight = weight.view(shape) + return weight, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff_w, _ = self.get_diff_weight(multiplier, shape, device) + return self.org_weight + diff_w, None + + def _bypass_forward(self, x, scale=1, diff=False): + scale = self.scale * scale + ax_mid = self.a2(x) * scale + bx_mid = self.b2(x) * scale + + if self.rank_dropout and self.training: + drop_a = ( + torch.rand(self.lora_dim, device=ax_mid.device) < self.rank_dropout + ).to(ax_mid.dtype) + drop_b = ( + torch.rand(self.lora_dim, device=bx_mid.device) < self.rank_dropout + ).to(bx_mid.dtype) + if self.rank_dropout_scale: + drop_a /= drop_a.mean() + drop_b /= drop_b.mean() + if (dims := len(x.shape)) == 4: + drop_a = drop_a.view(1, -1, 1, 1) + drop_b = drop_b.view(1, -1, 1, 1) + else: + drop_a = drop_a.view(*[1] * (dims - 1), -1) + drop_b = drop_b.view(*[1] * (dims - 1), -1) + ax_mid = ax_mid * drop_a + bx_mid = bx_mid * drop_b + return ( + self.org_forward( + (0 if diff else x) + self.drop(self.a1(ax_mid)) * self.scale + ) + + self.drop(self.b1(bx_mid)) * self.scale + ) + + def bypass_forward_diff(self, x, scale=1): + return self._bypass_forward(x, scale=scale, diff=True) + + def bypass_forward(self, x, scale=1): + return self._bypass_forward(x, scale=scale, diff=False) + + def forward(self, x, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + if self.bypass_mode: + return self.bypass_forward(x, self.multiplier) + else: + weight = ( + self.org_module[0].weight.data.to(self.dtype) + + self.get_diff_weight(multiplier=self.multiplier)[0] + ) + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) diff --git a/lycoris/modules/ia3.py b/lycoris/modules/ia3.py new file mode 100644 index 0000000000000000000000000000000000000000..eeeaa1beae902fdfa870a429f03f805c51e13929 --- /dev/null +++ b/lycoris/modules/ia3.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn + +from .base import LycorisBaseModule + + +class IA3Module(LycorisBaseModule): + name = "ia3" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = ["weight", "on_input"] + weight_list_det = ["on_input"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + bypass_mode=None, + rs_lora=False, + train_on_input=False, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in IA^3 algo.") + + if self.module_type.startswith("conv"): + self.isconv = True + in_dim = 
org_module.in_channels + out_dim = org_module.out_channels + if train_on_input: + train_dim = in_dim + else: + train_dim = out_dim + self.weight = nn.Parameter( + torch.empty(1, train_dim, *(1 for _ in self.shape[2:])) + ) + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + if train_on_input: + train_dim = in_dim + else: + train_dim = out_dim + + self.weight = nn.Parameter(torch.empty(train_dim)) + + # Need more experiences on init method + torch.nn.init.constant_(self.weight, 0) + self.train_input = train_on_input + self.register_buffer("on_input", torch.tensor(int(train_on_input))) + + @classmethod + def make_module_from_state_dict(cls, lora_name, orig_module, weight): + module = cls( + lora_name, + orig_module, + 1, + ) + module.weight.data.copy_(weight) + return module + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def make_weight(self, multiplier=1, shape=None, device=None, diff=False): + weight = self.weight * multiplier + int(not diff) + if self.train_input: + diff = self.org_weight * weight + else: + diff = self.org_weight.transpose(0, 1) * weight + diff = diff.transpose(0, 1) + if shape is not None: + diff = diff.view(shape) + if device is not None: + diff = diff.to(device) + return diff + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight( + multiplier=multiplier, shape=shape, device=device, diff=True + ) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(multiplier=multiplier, shape=shape, device=device) + return diff, None + + def _bypass_forward(self, x, scale=1, diff=False): + weight = self.weight * scale + int(not diff) + if self.train_input: + x = x * weight + out = self.org_forward(x) + if not self.train_input: + out = out * weight + return out + + def bypass_forward_diff(self, x, scale=1): + return self._bypass_forward(x, scale, diff=True) + + def bypass_forward(self, x, scale=1): + return self._bypass_forward(x, scale, diff=False) + + def forward(self, x, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + if self.bypass_mode: + return self.bypass_forward(x, self.multiplier) + else: + weight = self.get_merged_weight(multiplier=self.multiplier)[0] + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) diff --git a/lycoris/modules/locon.py b/lycoris/modules/locon.py new file mode 100644 index 0000000000000000000000000000000000000000..0338684a28ce7f0b648a471914fcad0744f62cab --- /dev/null +++ b/lycoris/modules/locon.py @@ -0,0 +1,332 @@ +import math +from functools import cache + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import LycorisBaseModule +from ..functional.general import rebuild_tucker +from ..logging import logger + + +@cache +def log_wd(): + return logger.warning( + "Using weight_decompose=True with LoRA (DoRA) will ignore network_dropout." 
+ "Only rank dropout and module dropout will be applied" + ) + + +class LoConModule(LycorisBaseModule): + name = "locon" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "lora_up.weight", + "lora_down.weight", + "lora_mid.weight", + "alpha", + "dora_scale", + ] + weight_list_det = ["lora_up.weight"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + wd_on_out=False, + bypass_mode=None, + rs_lora=False, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in LoRA/LoCon algo.") + self.lora_dim = lora_dim + self.tucker = False + self.rs_lora = rs_lora + + if self.module_type.startswith("conv"): + self.isconv = True + # For general LoCon + in_dim = org_module.in_channels + k_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + out_dim = org_module.out_channels + use_tucker = use_tucker and any(i != 1 for i in k_size) + self.down_op = self.op + self.up_op = self.op + if use_tucker and any(i != 1 for i in k_size): + self.lora_down = self.module(in_dim, lora_dim, 1, bias=False) + self.lora_mid = self.module( + lora_dim, lora_dim, k_size, stride, padding, bias=False + ) + self.tucker = True + else: + self.lora_down = self.module( + in_dim, lora_dim, k_size, stride, padding, bias=False + ) + self.lora_up = self.module(lora_dim, out_dim, 1, bias=False) + elif isinstance(org_module, nn.Linear): + self.isconv = False + self.down_op = F.linear + self.up_op = F.linear + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) + else: + raise NotImplementedError + + self.wd = weight_decompose + self.wd_on_out = wd_on_out + if self.wd: + org_weight = org_module.weight.cpu().clone().float() + self.dora_norm_dims = org_weight.dim() - 1 + if self.wd_on_out: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.reshape(org_weight.shape[0], -1), + dim=1, + keepdim=True, + ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims) + ).float() + else: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1), + dim=1, + keepdim=True, + ) + .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(1, 0) + ).float() + + if dropout: + self.dropout = nn.Dropout(dropout) + if self.wd: + log_wd() + else: + self.dropout = nn.Identity() + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + + r_factor = lora_dim + if self.rs_lora: + r_factor = math.sqrt(r_factor) + + self.scale = alpha / r_factor + + self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor))) + + if use_scalar: + self.scalar = nn.Parameter(torch.tensor(0.0)) + else: + self.register_buffer("scalar", torch.tensor(1.0), persistent=False) + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + if use_scalar: + 
torch.nn.init.kaiming_uniform_(self.lora_up.weight, a=math.sqrt(5)) + else: + torch.nn.init.constant_(self.lora_up.weight, 0) + if self.tucker: + torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5)) + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, up, down, mid, alpha, dora_scale + ): + module = cls( + lora_name, + orig_module, + 1, + down.size(0), + float(alpha), + use_tucker=mid is not None, + weight_decompose=dora_scale is not None, + ) + module.lora_up.weight.data.copy_(up) + module.lora_down.weight.data.copy_(down) + if mid is not None: + module.lora_mid.weight.data.copy_(mid) + if dora_scale is not None: + module.dora_scale.copy_(dora_scale) + return module + + def load_weight_hook(self, module: nn.Module, incompatible_keys): + missing_keys = incompatible_keys.missing_keys + for key in missing_keys: + if "scalar" in key: + del missing_keys[missing_keys.index(key)] + if isinstance(self.scalar, nn.Parameter): + self.scalar.data.copy_(torch.ones_like(self.scalar)) + elif getattr(self, "scalar", None) is not None: + self.scalar.copy_(torch.ones_like(self.scalar)) + else: + self.register_buffer( + "scalar", torch.ones_like(self.scalar), persistent=False + ) + + def make_weight(self, device=None): + wa = self.lora_up.weight.to(device) + wb = self.lora_down.weight.to(device) + if self.tucker: + t = self.lora_mid.weight + wa = wa.view(wa.size(0), -1).transpose(0, 1) + wb = wb.view(wb.size(0), -1) + weight = rebuild_tucker(t, wa, wb) + else: + weight = wa.view(wa.size(0), -1) @ wb.view(wb.size(0), -1) + + weight = weight.view(self.shape) + if self.training and self.rank_dropout: + drop = (torch.rand(weight.size(0), device=device) > self.rank_dropout).to( + weight.dtype + ) + drop = drop.view(-1, *[1] * len(weight.shape[1:])) + if self.rank_dropout_scale: + drop /= drop.mean() + weight *= drop + + return weight * self.scalar.to(device) + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + scale = self.scale * multiplier + diff = self.make_weight(device=device) * scale + if shape is not None: + diff = diff.view(shape) + if device is not None: + diff = diff.to(device) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0] + weight = self.org_weight + if self.wd: + merged = self.apply_weight_decompose(weight + diff, multiplier) + else: + merged = weight + diff * multiplier + return merged, None + + def apply_weight_decompose(self, weight, multiplier=1): + weight = weight.to(self.dora_scale.dtype) + if self.wd_on_out: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1) + .reshape(weight.shape[0], *[1] * self.dora_norm_dims) + ) + torch.finfo(weight.dtype).eps + else: + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + torch.finfo(weight.dtype).eps + + scale = self.dora_scale.to(weight.device) / weight_norm + if multiplier != 1: + scale = multiplier * (scale - 1) + 1 + + return weight * scale + + def custom_state_dict(self): + destination = {} + if self.wd: + destination["dora_scale"] = self.dora_scale + destination["alpha"] = self.alpha + destination["lora_up.weight"] = self.lora_up.weight * self.scalar + destination["lora_down.weight"] = self.lora_down.weight + if self.tucker: + destination["lora_mid.weight"] = self.lora_mid.weight + return destination + + 
@torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = self.make_weight(device).norm() * self.scale + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired.cpu() / norm.cpu() + + scaled = norm != desired + if scaled: + self.scalar *= ratio + + return scaled, orig_norm * ratio + + def bypass_forward_diff(self, x, scale=1): + if self.tucker: + mid = self.lora_mid(self.lora_down(x)) + else: + mid = self.lora_down(x) + + if self.rank_dropout and self.training: + drop = ( + torch.rand(self.lora_dim, device=mid.device) > self.rank_dropout + ).to(mid.dtype) + if self.rank_dropout_scale: + drop /= drop.mean() + if (dims := len(x.shape)) == 4: + drop = drop.view(1, -1, 1, 1) + else: + drop = drop.view(*[1] * (dims - 1), -1) + mid = mid * drop + + return self.dropout(self.lora_up(mid) * self.scalar * self.scale * scale) + + def bypass_forward(self, x, scale=1): + return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale) + + def forward(self, x): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + scale = self.scale + + dtype = self.dtype + if not self.bypass_mode: + diff_weight = self.make_weight(x.device).to(dtype) * scale + weight = self.org_module[0].weight.data.to(dtype) + if self.wd: + weight = self.apply_weight_decompose( + weight + diff_weight, self.multiplier + ) + else: + weight = weight + diff_weight * self.multiplier + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) + else: + return self.bypass_forward(x, scale=self.multiplier) diff --git a/lycoris/modules/loha.py b/lycoris/modules/loha.py new file mode 100644 index 0000000000000000000000000000000000000000..54f8e52eb0b3837ec01e412dd562806653f8a3a8 --- /dev/null +++ b/lycoris/modules/loha.py @@ -0,0 +1,329 @@ +import math + +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from ..functional.loha import diff_weight as loha_diff_weight + + +class LohaModule(LycorisBaseModule): + name = "loha" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "hada_w1_a", + "hada_w1_b", + "hada_w2_a", + "hada_w2_b", + "hada_t1", + "hada_t2", + "alpha", + "dora_scale", + ] + weight_list_det = ["hada_w1_a"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + wd_on_out=False, + bypass_mode=None, + rs_lora=False, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in LoHa algo.") + self.lora_name = lora_name + self.lora_dim = lora_dim + self.tucker = False + self.rs_lora = rs_lora + + w_shape = self.shape + if self.module_type.startswith("conv"): + in_dim = org_module.in_channels + k_size = org_module.kernel_size + out_dim = org_module.out_channels + self.shape = (out_dim, in_dim, *k_size) + self.tucker = use_tucker and any(i != 1 for i in k_size) + if self.tucker: + w_shape = (out_dim, in_dim, *k_size) + else: + w_shape = (out_dim, in_dim * torch.tensor(k_size).prod().item()) + + if self.tucker: + self.hada_t1 = 
nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:])) + self.hada_w1_a = nn.Parameter( + torch.empty(lora_dim, w_shape[0]) + ) # out_dim, 1-mode + self.hada_w1_b = nn.Parameter( + torch.empty(lora_dim, w_shape[1]) + ) # in_dim , 2-mode + + self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:])) + self.hada_w2_a = nn.Parameter( + torch.empty(lora_dim, w_shape[0]) + ) # out_dim, 1-mode + self.hada_w2_b = nn.Parameter( + torch.empty(lora_dim, w_shape[1]) + ) # in_dim , 2-mode + else: + self.hada_w1_a = nn.Parameter(torch.empty(w_shape[0], lora_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, w_shape[1])) + + self.hada_w2_a = nn.Parameter(torch.empty(w_shape[0], lora_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, w_shape[1])) + + self.wd = weight_decompose + self.wd_on_out = wd_on_out + if self.wd: + org_weight = org_module.weight.cpu().clone().float() + self.dora_norm_dims = org_weight.dim() - 1 + if self.wd_on_out: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.reshape(org_weight.shape[0], -1), + dim=1, + keepdim=True, + ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims) + ).float() + else: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1), + dim=1, + keepdim=True, + ) + .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(1, 0) + ).float() + + if self.dropout: + print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.") + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + + r_factor = lora_dim + if self.rs_lora: + r_factor = math.sqrt(r_factor) + + self.scale = alpha / r_factor + + self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor))) + + if use_scalar: + self.scalar = nn.Parameter(torch.tensor(0.0)) + else: + self.register_buffer("scalar", torch.tensor(1.0), persistent=False) + # Need more experiments on init method + if self.tucker: + torch.nn.init.normal_(self.hada_t1, std=0.1) + torch.nn.init.normal_(self.hada_t2, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1) + torch.nn.init.normal_(self.hada_w1_a, std=0.1) + torch.nn.init.normal_(self.hada_w2_b, std=1) + if use_scalar: + torch.nn.init.normal_(self.hada_w2_a, std=0.1) + else: + torch.nn.init.constant_(self.hada_w2_a, 0) + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, w1a, w1b, w2a, w2b, t1, t2, alpha, dora_scale + ): + module = cls( + lora_name, + orig_module, + 1, + w1b.size(0), + float(alpha), + use_tucker=t1 is not None, + weight_decompose=dora_scale is not None, + ) + module.hada_w1_a.copy_(w1a) + module.hada_w1_b.copy_(w1b) + module.hada_w2_a.copy_(w2a) + module.hada_w2_b.copy_(w2b) + if t1 is not None: + module.hada_t1.copy_(t1) + module.hada_t2.copy_(t2) + if dora_scale is not None: + module.dora_scale.copy_(dora_scale) + return module + + def load_weight_hook(self, module: nn.Module, incompatible_keys): + missing_keys = incompatible_keys.missing_keys + for key in missing_keys: + if "scalar" in key: + del missing_keys[missing_keys.index(key)] + if isinstance(self.scalar, nn.Parameter): + self.scalar.data.copy_(torch.ones_like(self.scalar)) + elif getattr(self, "scalar", None) is not None: + self.scalar.copy_(torch.ones_like(self.scalar)) + else: + self.register_buffer( + "scalar", torch.ones_like(self.scalar), persistent=False + ) + + def get_weight(self, shape): + scale = torch.tensor( + 
self.scale, dtype=self.hada_w1_b.dtype, device=self.hada_w1_b.device + ) + if self.tucker: + weight = loha_diff_weight( + self.hada_w1_b, + self.hada_w1_a, + self.hada_w2_b, + self.hada_w2_a, + self.hada_t1, + self.hada_t2, + gamma=scale, + ) + else: + weight = loha_diff_weight( + self.hada_w1_b, + self.hada_w1_a, + self.hada_w2_b, + self.hada_w2_a, + None, + None, + gamma=scale, + ) + if shape is not None: + weight = weight.reshape(shape) + if self.training and self.rank_dropout: + drop = (torch.rand(weight.size(0)) > self.rank_dropout).to(weight.dtype) + drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device) + if self.rank_dropout_scale: + drop /= drop.mean() + weight *= drop + return weight + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + scale = self.scale * multiplier + diff = self.get_weight(shape) * scale + if device is not None: + diff = diff.to(device) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0] + weight = self.org_weight + if self.wd: + merged = self.apply_weight_decompose(weight + diff, multiplier) + else: + merged = weight + diff * multiplier + return merged, None + + def apply_weight_decompose(self, weight, multiplier=1): + weight = weight.to(self.dora_scale.dtype) + if self.wd_on_out: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1) + .reshape(weight.shape[0], *[1] * self.dora_norm_dims) + ) + torch.finfo(weight.dtype).eps + else: + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + torch.finfo(weight.dtype).eps + + scale = self.dora_scale.to(weight.device) / weight_norm + if multiplier != 1: + scale = multiplier * (scale - 1) + 1 + + return weight * scale + + def custom_state_dict(self): + destination = {} + destination["alpha"] = self.alpha + if self.wd: + destination["dora_scale"] = self.dora_scale + destination["hada_w1_a"] = self.hada_w1_a * self.scalar + destination["hada_w1_b"] = self.hada_w1_b + destination["hada_w2_a"] = self.hada_w2_a + destination["hada_w2_b"] = self.hada_w2_b + if self.tucker: + destination["hada_t1"] = self.hada_t1 + destination["hada_t2"] = self.hada_t2 + return destination + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = (self.get_weight(self.shape) * self.scalar).norm() + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired.cpu() / norm.cpu() + + scaled = norm != desired + if scaled: + self.scalar *= ratio + + return scaled, orig_norm * ratio + + def bypass_forward_diff(self, x, scale=1): + diff_weight = self.get_weight(self.shape) * self.scalar * scale + return self.drop(self.op(x, diff_weight, **self.kw_dict)) + + def bypass_forward(self, x, scale=1): + return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale) + + def forward(self, x: torch.Tensor, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.op( + x, + self.org_module[0].weight.data, + ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ), + ) + if self.bypass_mode: + return self.bypass_forward(x, scale=self.multiplier) + else: + diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar + weight = self.org_module[0].weight.data.to(self.dtype) + if self.wd: + weight = 
self.apply_weight_decompose( + weight + diff_weight, self.multiplier + ) + else: + weight = weight + diff_weight * self.multiplier + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) diff --git a/lycoris/modules/lokr.py b/lycoris/modules/lokr.py new file mode 100644 index 0000000000000000000000000000000000000000..12ecf42e97f3b377bc04b047dc6b623c8a6992d9 --- /dev/null +++ b/lycoris/modules/lokr.py @@ -0,0 +1,609 @@ +import math +from functools import cache + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import LycorisBaseModule +from ..functional import factorization, rebuild_tucker +from ..functional.lokr import make_kron +from ..logging import logger + + +@cache +def logging_force_full_matrix(lora_dim, dim, factor): + logger.warning( + f"lora_dim {lora_dim} is too large for" + f" dim={dim} and {factor=}" + ", using full matrix mode." + ) + + +class LokrModule(LycorisBaseModule): + name = "kron" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "lokr_w1", + "lokr_w1_a", + "lokr_w1_b", + "lokr_w2", + "lokr_w2_a", + "lokr_w2_b", + "lokr_t1", + "lokr_t2", + "alpha", + "dora_scale", + ] + weight_list_det = ["lokr_w1", "lokr_w1_a"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + decompose_both=False, + factor: int = -1, # factorization factor + rank_dropout_scale=False, + weight_decompose=False, + wd_on_out=False, + full_matrix=False, + bypass_mode=None, + rs_lora=False, + unbalanced_factorization=False, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in LoKr algo.") + + factor = int(factor) + self.lora_dim = lora_dim + self.tucker = False + self.use_w1 = False + self.use_w2 = False + self.full_matrix = full_matrix + self.rs_lora = rs_lora + + if self.module_type.startswith("conv"): + in_dim = org_module.in_channels + k_size = org_module.kernel_size + out_dim = org_module.out_channels + self.shape = (out_dim, in_dim, *k_size) + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + if unbalanced_factorization: + out_l, out_k = out_k, out_l + shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size) + self.tucker = use_tucker and any(i != 1 for i in k_size) + if ( + decompose_both + and lora_dim < max(shape[0][0], shape[1][0]) / 2 + and not self.full_matrix + ): + self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim)) + self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0])) + else: + self.use_w1 = True + self.lokr_w1 = nn.Parameter( + torch.empty(shape[0][0], shape[1][0]) + ) # a*c, 1-mode + + if lora_dim >= max(shape[0][1], shape[1][1]) / 2 or self.full_matrix: + if not self.full_matrix: + logging_force_full_matrix(lora_dim, max(in_dim, out_dim), factor) + self.use_w2 = True + self.lokr_w2 = nn.Parameter( + torch.empty(shape[0][1], shape[1][1], *k_size) + ) + elif self.tucker: + self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *shape[2:])) + self.lokr_w2_a = nn.Parameter( + torch.empty(lora_dim, shape[0][1]) + ) # b, 1-mode + self.lokr_w2_b = nn.Parameter( + 
torch.empty(lora_dim, shape[1][1]) + ) # d, 2-mode + else: # Conv2d not tucker + # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2] + self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim)) + self.lokr_w2_b = nn.Parameter( + torch.empty( + lora_dim, shape[1][1] * torch.tensor(shape[2:]).prod().item() + ) + ) + # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2) + else: # Linear + in_dim = org_module.in_features + out_dim = org_module.out_features + self.shape = (out_dim, in_dim) + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + if unbalanced_factorization: + out_l, out_k = out_k, out_l + shape = ( + (out_l, out_k), + (in_m, in_n), + ) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d + # smaller part. weight scale + if ( + decompose_both + and lora_dim < max(shape[0][0], shape[1][0]) / 2 + and not self.full_matrix + ): + self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim)) + self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0])) + else: + self.use_w1 = True + self.lokr_w1 = nn.Parameter( + torch.empty(shape[0][0], shape[1][0]) + ) # a*c, 1-mode + if lora_dim < max(shape[0][1], shape[1][1]) / 2 and not self.full_matrix: + # bigger part. weight and LoRA. [b, dim] x [dim, d] + self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) + # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd) + else: + if not self.full_matrix: + logging_force_full_matrix(lora_dim, max(in_dim, out_dim), factor) + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1])) + + self.wd = weight_decompose + self.wd_on_out = wd_on_out + if self.wd: + org_weight = org_module.weight.cpu().clone().float() + self.dora_norm_dims = org_weight.dim() - 1 + if self.wd_on_out: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.reshape(org_weight.shape[0], -1), + dim=1, + keepdim=True, + ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims) + ).float() + else: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1), + dim=1, + keepdim=True, + ) + .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(1, 0) + ).float() + + self.dropout = dropout + if dropout: + print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.") + self.rank_dropout = rank_dropout + self.rank_dropout_scale = rank_dropout_scale + self.module_dropout = module_dropout + + if isinstance(alpha, torch.Tensor): + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + if self.use_w2 and self.use_w1: + # use scale = 1 + alpha = lora_dim + + r_factor = lora_dim + if self.rs_lora: + r_factor = math.sqrt(r_factor) + + self.scale = alpha / r_factor + + self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor))) + + if use_scalar: + self.scalar = nn.Parameter(torch.tensor(0.0)) + else: + self.register_buffer("scalar", torch.tensor(1.0), persistent=False) + + if self.use_w2: + if use_scalar: + torch.nn.init.kaiming_uniform_(self.lokr_w2, a=math.sqrt(5)) + else: + torch.nn.init.constant_(self.lokr_w2, 0) + else: + if self.tucker: + torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5)) + if use_scalar: + torch.nn.init.kaiming_uniform_(self.lokr_w2_b, a=math.sqrt(5)) 
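+ # with use_scalar the initial delta is still zero because `scalar` starts at 0; in the else branch below lokr_w2_b is zero-initialised instead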
+ else: + torch.nn.init.constant_(self.lokr_w2_b, 0) + + if self.use_w1: + torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5)) + else: + torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5)) + + @classmethod + def make_module_from_state_dict( + cls, + lora_name, + orig_module, + w1, + w1a, + w1b, + w2, + w2a, + w2b, + _, + t2, + alpha, + dora_scale, + ): + full_matrix = False + if w1a is not None: + lora_dim = w1a.size(1) + elif w2a is not None: + lora_dim = w2a.size(1) + else: + full_matrix = True + lora_dim = 1 + + if w1 is None: + out_dim = w1a.size(0) + in_dim = w1b.size(1) + else: + out_dim, in_dim = w1.shape + + shape_s = [out_dim, in_dim] + + if w2 is None: + out_dim *= w2a.size(0) + in_dim *= w2b.size(1) + else: + out_dim *= w2.size(0) + in_dim *= w2.size(1) + + if ( + shape_s[0] == factorization(out_dim, -1)[0] + and shape_s[1] == factorization(in_dim, -1)[0] + ): + factor = -1 + else: + w1_shape = w1.shape if w1 is not None else (w1a.size(0), w1b.size(1)) + w2_shape = w2.shape if w2 is not None else (w2a.size(0), w2b.size(1)) + shape_group_1 = (w1_shape[0], w2_shape[0]) + shape_group_2 = (w1_shape[1], w2_shape[1]) + w_shape = (w1_shape[0] * w2_shape[0], w1_shape[1] * w2_shape[1]) + factor1 = max(w1.shape) if w1 is not None else max(w1a.size(0), w1b.size(1)) + factor2 = max(w2.shape) if w2 is not None else max(w2a.size(0), w2b.size(1)) + if ( + w_shape[0] % factor1 == 0 + and w_shape[1] % factor1 == 0 + and factor1 in shape_group_1 + and factor1 in shape_group_2 + ): + factor = factor1 + elif ( + w_shape[0] % factor2 == 0 + and w_shape[1] % factor2 == 0 + and factor2 in shape_group_1 + and factor2 in shape_group_2 + ): + factor = factor2 + else: + factor = min(factor1, factor2) + + module = cls( + lora_name, + orig_module, + 1, + lora_dim, + float(alpha), + use_tucker=t2 is not None, + decompose_both=w1 is None and w2 is None, + factor=factor, + weight_decompose=dora_scale is not None, + full_matrix=full_matrix, + ) + if w1 is not None: + module.lokr_w1.copy_(w1) + else: + module.lokr_w1_a.copy_(w1a) + module.lokr_w1_b.copy_(w1b) + if w2 is not None: + module.lokr_w2.copy_(w2) + else: + module.lokr_w2_a.copy_(w2a) + module.lokr_w2_b.copy_(w2b) + if t2 is not None: + module.lokr_t2.copy_(t2) + if dora_scale is not None: + module.dora_scale.copy_(dora_scale) + return module + + def load_weight_hook(self, module: nn.Module, incompatible_keys): + missing_keys = incompatible_keys.missing_keys + for key in missing_keys: + if "scalar" in key: + del missing_keys[missing_keys.index(key)] + if isinstance(self.scalar, nn.Parameter): + self.scalar.data.copy_(torch.ones_like(self.scalar)) + elif getattr(self, "scalar", None) is not None: + self.scalar.copy_(torch.ones_like(self.scalar)) + else: + self.register_buffer( + "scalar", torch.ones_like(self.scalar), persistent=False + ) + + def get_weight(self, shape): + weight = make_kron( + self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b, + ( + self.lokr_w2 + if self.use_w2 + else ( + rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) + if self.tucker + else self.lokr_w2_a @ self.lokr_w2_b + ) + ), + self.scale, + ) + dtype = weight.dtype + if shape is not None: + weight = weight.view(shape) + if self.training and self.rank_dropout: + drop = (torch.rand(weight.size(0)) > self.rank_dropout).to(dtype) + drop = drop.view(-1, *[1] * len(weight.shape[1:])) + if self.rank_dropout_scale: + drop /= drop.mean() + weight *= drop + return weight + + 
def get_diff_weight(self, multiplier=1, shape=None, device=None): + scale = self.scale * multiplier + diff = self.get_weight(shape) * scale + if device is not None: + diff = diff.to(device) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0] + weight = self.org_weight + if self.wd: + merged = self.apply_weight_decompose(weight + diff, multiplier) + else: + merged = weight + diff * multiplier + return merged, None + + def apply_weight_decompose(self, weight, multiplier=1): + weight = weight.to(self.dora_scale.dtype) + if self.wd_on_out: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1) + .reshape(weight.shape[0], *[1] * self.dora_norm_dims) + ) + torch.finfo(weight.dtype).eps + else: + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + torch.finfo(weight.dtype).eps + + scale = self.dora_scale.to(weight.device) / weight_norm + if multiplier != 1: + scale = multiplier * (scale - 1) + 1 + + return weight * scale + + def custom_state_dict(self): + destination = {} + destination["alpha"] = self.alpha + if self.wd: + destination["dora_scale"] = self.dora_scale + if self.use_w1: + destination["lokr_w1"] = self.lokr_w1 * self.scalar + else: + destination["lokr_w1_a"] = self.lokr_w1_a * self.scalar + destination["lokr_w1_b"] = self.lokr_w1_b + + if self.use_w2: + destination["lokr_w2"] = self.lokr_w2 + else: + destination["lokr_w2_a"] = self.lokr_w2_a + destination["lokr_w2_b"] = self.lokr_w2_b + if self.tucker: + destination["lokr_t2"] = self.lokr_t2 + return destination + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = self.get_weight(self.shape).norm() + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired.cpu() / norm.cpu() + + scaled = norm != desired + if scaled: + modules = 4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.tucker) + if self.use_w1: + self.lokr_w1 *= ratio ** (1 / modules) + else: + self.lokr_w1_a *= ratio ** (1 / modules) + self.lokr_w1_b *= ratio ** (1 / modules) + + if self.use_w2: + self.lokr_w2 *= ratio ** (1 / modules) + else: + if self.tucker: + self.lokr_t2 *= ratio ** (1 / modules) + self.lokr_w2_a *= ratio ** (1 / modules) + self.lokr_w2_b *= ratio ** (1 / modules) + + return scaled, orig_norm * ratio + + def bypass_forward_diff(self, h, scale=1): + is_conv = self.module_type.startswith("conv") + if self.use_w2: + ba = self.lokr_w2 + else: + a = self.lokr_w2_b + b = self.lokr_w2_a + + if self.tucker: + t = self.lokr_t2 + a = a.view(*a.shape, *[1] * (len(t.shape) - 2)) + b = b.view(*b.shape, *[1] * (len(t.shape) - 2)) + elif is_conv: + a = a.view(*a.shape, *self.shape[2:]) + b = b.view(*b.shape, *[1] * (len(self.shape) - 2)) + + if self.use_w1: + c = self.lokr_w1 + else: + c = self.lokr_w1_a @ self.lokr_w1_b + uq = c.size(1) + + if is_conv: + # (b, uq), vq, ... 
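+ # fold the uq group dimension into the batch so the second Kronecker factor acts within each channel group; w1 (c) then mixes across groups via F.linear, so the full Kronecker weight is never materialised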
+ b, _, *rest = h.shape + h_in_group = h.reshape(b * uq, -1, *rest) + else: + # b, ..., uq, vq + h_in_group = h.reshape(*h.shape[:-1], uq, -1) + + if self.use_w2: + hb = self.op(h_in_group, ba, **self.kw_dict) + else: + if is_conv: + if self.tucker: + ha = self.op(h_in_group, a) + ht = self.op(ha, t, **self.kw_dict) + hb = self.op(ht, b) + else: + ha = self.op(h_in_group, a, **self.kw_dict) + hb = self.op(ha, b) + else: + ha = self.op(h_in_group, a, **self.kw_dict) + hb = self.op(ha, b) + + if is_conv: + # (b, uq), vp, ..., f + # -> b, uq, vp, ..., f + # -> b, f, vp, ..., uq + hb = hb.view(b, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) + else: + # b, ..., uq, vq + # -> b, ..., vq, uq + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + if is_conv: + # b, f, vp, ..., up + # -> b, up, vp, ... ,f + # -> b, c, ..., f + hc = hc.transpose(1, -1) + h = hc.reshape(b, -1, *hc.shape[3:]) + else: + # b, ..., vp, up + # -> b, ..., up, vp + # -> b, ..., c + hc = hc.transpose(-1, -2) + h = hc.reshape(*hc.shape[:-2], -1) + + return self.drop(h * scale * self.scalar) + + def bypass_forward(self, x, scale=1): + return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale) + + def forward(self, x: torch.Tensor, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + if self.bypass_mode: + return self.bypass_forward(x, self.multiplier) + else: + diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar + weight = self.org_module[0].weight.data.to(self.dtype) + if self.wd: + weight = self.apply_weight_decompose( + weight + diff_weight, self.multiplier + ) + elif self.multiplier == 1: + weight = weight + diff_weight + else: + weight = weight + diff_weight * self.multiplier + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) + + +if __name__ == "__main__": + base = nn.Conv2d(128, 128, 3, 1, 1) + net = LokrModule( + "", + base, + multiplier=1, + lora_dim=4, + alpha=1, + weight_decompose=False, + use_tucker=False, + use_scalar=False, + decompose_both=True, + ) + net.apply_to() + sd = net.state_dict() + for key in sd: + if key != "alpha": + sd[key] = torch.randn_like(sd[key]) + net.load_state_dict(sd) + + test_input = torch.randn(1, 128, 16, 16) + test_output = net(test_input) + print(test_output.shape) + + net2 = LokrModule( + "", + base, + multiplier=1, + lora_dim=4, + alpha=1, + weight_decompose=False, + use_tucker=False, + use_scalar=False, + bypass_mode=True, + decompose_both=True, + ) + net2.apply_to() + net2.load_state_dict(sd) + print(net2) + + test_output2 = net(test_input) + print(F.mse_loss(test_output, test_output2)) diff --git a/lycoris/modules/norms.py b/lycoris/modules/norms.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab627a6573d7910a251997e258c358fd76f1a87 --- /dev/null +++ b/lycoris/modules/norms.py @@ -0,0 +1,161 @@ +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from ..logging import warning_once + + +class NormModule(LycorisBaseModule): + name = "norm" + support_module = { + "layernorm", + "groupnorm", + } + weight_list = ["w_norm", "b_norm"] + weight_list_det = ["w_norm"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + rank_dropout=0.0, + module_dropout=0.0, + rank_dropout_scale=False, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + 
super().__init__( + lora_name=lora_name, + org_module=org_module, + multiplier=multiplier, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + rank_dropout_scale=rank_dropout_scale, + **kwargs, + ) + if self.module_type == "unknown": + if not hasattr(org_module, "weight") or not hasattr(org_module, "_norm"): + warning_once(f"{type(org_module)} is not supported in Norm algo.") + self.not_supported = True + return + else: + self.dim = org_module.weight.numel() + self.not_supported = False + elif self.module_type not in self.support_module: + warning_once(f"{self.module_type} is not supported in Norm algo.") + self.not_supported = True + return + + self.w_norm = nn.Parameter(torch.zeros(self.dim)) + if hasattr(org_module, "bias"): + self.b_norm = nn.Parameter(torch.zeros(self.dim)) + if hasattr(org_module, "_norm"): + self.org_norm = org_module._norm + else: + self.org_norm = None + + @classmethod + def make_module_from_state_dict(cls, lora_name, orig_module, w_norm, b_norm): + module = cls( + lora_name, + orig_module, + 1, + ) + module.w_norm.copy_(w_norm) + if b_norm is not None: + module.b_norm.copy_(b_norm) + return module + + def make_weight(self, scale=1, device=None): + org_weight = self.org_module[0].weight.to(device, dtype=self.w_norm.dtype) + if hasattr(self.org_module[0], "bias"): + org_bias = self.org_module[0].bias.to(device, dtype=self.b_norm.dtype) + else: + org_bias = None + if self.rank_dropout and self.training: + drop = (torch.rand(self.dim, device=device) < self.rank_dropout).to( + self.w_norm.device + ) + if self.rank_dropout_scale: + drop /= drop.mean() + else: + drop = 1 + drop = ( + torch.rand(self.dim, device=device) < self.rank_dropout + if self.rank_dropout and self.training + else 1 + ) + weight = self.w_norm.to(device) * drop * scale + if org_bias is not None: + bias = self.b_norm.to(device) * drop * scale + return org_weight + weight, org_bias + bias if org_bias is not None else None + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + if self.not_supported: + return 0, 0 + w = self.w_norm * multiplier + if device is not None: + w = w.to(device) + if shape is not None: + w = w.view(shape) + if self.b_norm is not None: + b = self.b_norm * multiplier + if device is not None: + b = b.to(device) + if shape is not None: + b = b.view(shape) + else: + b = None + return w, b + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + if self.not_supported: + return None, None + diff_w, diff_b = self.get_diff_weight(multiplier, shape, device) + org_w = self.org_module[0].weight.to(device, dtype=self.w_norm.dtype) + weight = org_w + diff_w + if diff_b is not None: + org_b = self.org_module[0].bias.to(device, dtype=self.b_norm.dtype) + bias = org_b + diff_b + else: + bias = None + return weight, bias + + def forward(self, x): + if self.not_supported or ( + self.module_dropout + and self.training + and torch.rand(1) < self.module_dropout + ): + return self.org_forward(x) + scale = self.multiplier + + w, b = self.make_weight(scale, x.device) + if self.org_norm is not None: + normed = self.org_norm(x) + scaled = normed * w + if b is not None: + scaled += b + return scaled + + kw_dict = self.kw_dict | {"weight": w, "bias": b} + return self.op(x, **kw_dict) + + +if __name__ == "__main__": + base = nn.LayerNorm(128).cuda() + norm = NormModule("test", base, 1).cuda() + print(norm) + test_input = torch.randn(1, 128).cuda() + test_output = norm(test_input) + torch.sum(test_output).backward() + print(test_output.shape) + + base = 
nn.GroupNorm(4, 128).cuda() + norm = NormModule("test", base, 1).cuda() + print(norm) + test_input = torch.randn(1, 128, 3, 3).cuda() + test_output = norm(test_input) + torch.sum(test_output).backward() + print(test_output.shape) diff --git a/lycoris/utils/__init__.py b/lycoris/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1acf0e129e5168ba917a83707c66507898f76e3 --- /dev/null +++ b/lycoris/utils/__init__.py @@ -0,0 +1,483 @@ +import re +import hashlib +from io import BytesIO +from typing import Dict, Tuple, Union + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.linalg as linalg + +import safetensors.torch + +from tqdm import tqdm +from .general import * + + +def load_bytes_in_safetensors(tensors): + bytes = safetensors.torch.save(tensors) + b = BytesIO(bytes) + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + + return b.read() + + +def precalculate_safetensors_hashes(state_dict): + # calculate each tensor one by one to reduce memory usage + hash_sha256 = hashlib.sha256() + for tensor in state_dict.values(): + single_tensor_sd = {"tensor": tensor} + bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd) + hash_sha256.update(bytes_for_tensor) + + return f"0x{hash_sha256.hexdigest()}" + + +def str_bool(val): + return str(val).lower() != "false" + + +def default(val, d): + return val if val is not None else d + + +def make_sparse(t: torch.Tensor, sparsity=0.95): + abs_t = torch.abs(t) + np_array = abs_t.detach().cpu().numpy() + quan = float(np.quantile(np_array, sparsity)) + sparse_t = t.masked_fill(abs_t < quan, 0) + return sparse_t + + +def extract_conv( + weight: Union[torch.Tensor, nn.Parameter], + mode="fixed", + mode_param=0, + device="cpu", + is_cp=False, +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch, kernel_size, _ = weight.shape + + U, S, Vh = linalg.svd(weight.reshape(out_ch, -1)) + + if mode == "full": + return weight, "full" + elif mode == "fixed": + lora_rank = mode_param + elif mode == "threshold": + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param) + elif mode == "ratio": + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s) + elif mode == "quantile" or mode == "percentile": + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum) + else: + raise NotImplementedError( + 'Extract mode should be "fixed", "threshold", "ratio" or "quantile"' + ) + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2 and not is_cp: + return weight, "full" + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S).to(device) + Vh = Vh[:lora_rank, :] + + diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach() + extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), "low rank" + + +def extract_linear( + weight: Union[torch.Tensor, nn.Parameter], + mode="fixed", + mode_param=0, + device="cpu", +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch = weight.shape + + U, S, Vh = linalg.svd(weight) + + if mode == "full": + return weight, "full" + elif mode == "fixed": 
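+ # "fixed": keep exactly mode_param singular values (still clamped to [1, min(out_ch, in_ch)] below)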
+ lora_rank = mode_param + elif mode == "threshold": + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param) + elif mode == "ratio": + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s) + elif mode == "quantile" or mode == "percentile": + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum) + else: + raise NotImplementedError( + 'Extract mode should be "fixed", "threshold", "ratio" or "quantile"' + ) + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2: + return weight, "full" + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S).to(device) + Vh = Vh[:lora_rank, :] + + diff = (weight - U @ Vh).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch).detach() + extract_weight_B = U.reshape(out_ch, lora_rank).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), "low rank" + + +@torch.no_grad() +def extract_diff( + base_tes, + db_tes, + base_unet, + db_unet, + mode="fixed", + linear_mode_param=0, + conv_mode_param=0, + extract_device="cpu", + use_bias=False, + sparsity=0.98, + small_conv=True, +): + UNET_TARGET_REPLACE_MODULE = [ + "Linear", + "Conv2d", + "LayerNorm", + "GroupNorm", + "GroupNorm32", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = [ + "Embedding", + "Linear", + "Conv2d", + "LayerNorm", + "GroupNorm", + "GroupNorm32", + ] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + def make_state_dict( + prefix, + root_module: torch.nn.Module, + target_module: torch.nn.Module, + target_replace_modules, + ): + loras = {} + temp = {} + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + temp[name] = module + + for name, module in tqdm( + list((n, m) for n, m in target_module.named_modules() if n in temp) + ): + weights = temp[name] + lora_name = prefix + "." 
+ name + lora_name = lora_name.replace(".", "_") + layer = module.__class__.__name__ + + if layer in { + "Linear", + "Conv2d", + "LayerNorm", + "GroupNorm", + "GroupNorm32", + "Embedding", + }: + root_weight = module.weight + if torch.allclose(root_weight, weights.weight): + continue + else: + continue + module = module.to(extract_device) + weights = weights.to(extract_device) + + if mode == "full": + decompose_mode = "full" + elif layer == "Linear": + weight, decompose_mode = extract_linear( + (root_weight - weights.weight), + mode, + linear_mode_param, + device=extract_device, + ) + if decompose_mode == "low rank": + extract_a, extract_b, diff = weight + elif layer == "Conv2d": + is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1 + weight, decompose_mode = extract_conv( + (root_weight - weights.weight), + mode, + linear_mode_param if is_linear else conv_mode_param, + device=extract_device, + ) + if decompose_mode == "low rank": + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == "low rank": + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + "fixed", + dim, + extract_device, + True, + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f"{lora_name}.lora_mid.weight"] = ( + extract_c.detach().cpu().contiguous().half() + ) + diff = ( + ( + root_weight + - torch.einsum( + "i j k l, j r, p i -> p r k l", + extract_c, + extract_a.flatten(1, -1), + extract_b.flatten(1, -1), + ) + ) + .detach() + .cpu() + .contiguous() + ) + del extract_c + else: + module = module.to("cpu") + weights = weights.to("cpu") + continue + + if decompose_mode == "low rank": + loras[f"{lora_name}.lora_down.weight"] = ( + extract_a.detach().cpu().contiguous().half() + ) + loras[f"{lora_name}.lora_up.weight"] = ( + extract_b.detach().cpu().contiguous().half() + ) + loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f"{lora_name}.bias_indices"] = indices + loras[f"{lora_name}.bias_values"] = values + loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to( + torch.int16 + ) + del extract_a, extract_b, diff + elif decompose_mode == "full": + if "Norm" in layer: + w_key = "w_norm" + b_key = "b_norm" + else: + w_key = "diff" + b_key = "diff_b" + weight_diff = module.weight - weights.weight + loras[f"{lora_name}.{w_key}"] = ( + weight_diff.detach().cpu().contiguous().half() + ) + if getattr(weights, "bias", None) is not None: + bias_diff = module.bias - weights.bias + loras[f"{lora_name}.{b_key}"] = ( + bias_diff.detach().cpu().contiguous().half() + ) + else: + raise NotImplementedError + module = module.to("cpu") + weights = weights.to("cpu") + return loras + + all_loras = {} + + all_loras |= make_state_dict( + LORA_PREFIX_UNET, + base_unet, + db_unet, + UNET_TARGET_REPLACE_MODULE, + ) + del base_unet, db_unet + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + for idx, (te1, te2) in enumerate(zip(base_tes, db_tes)): + if len(base_tes) > 1: + prefix = f"{LORA_PREFIX_TEXT_ENCODER}{idx+1}" + else: + prefix = LORA_PREFIX_TEXT_ENCODER + all_loras |= make_state_dict( + prefix, + te1, + te2, + TEXT_ENCODER_TARGET_REPLACE_MODULE, + ) + del te1, te2 + + all_lora_name = set() + for k in all_loras: + lora_name, 
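When use_bias is enabled, make_state_dict above stores the leftover residual as a sparse tensor split across three keys (bias_indices, bias_values, bias_size). A hedged sketch of how such entries could be rebuilt; the helper name and dict layout here are mine, not part of the patch:

import torch

def rebuild_sparse_residual(loras: dict, lora_name: str) -> torch.Tensor:
    # indices are saved as int16 and values as fp16 above; sparse_coo_tensor
    # needs int64 indices, so cast on the way in
    idx = loras[f"{lora_name}.bias_indices"].to(torch.int64)
    val = loras[f"{lora_name}.bias_values"].to(torch.float32)
    size = loras[f"{lora_name}.bias_size"].to(torch.int64).tolist()
    return torch.sparse_coo_tensor(idx, val, size).to_dense()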
weight = k.rsplit(".", 1) + all_lora_name.add(lora_name) + print(len(all_lora_name)) + return all_loras + + +re_digits = re.compile(r"\d+") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "norm1": "in_layers_0", + "norm2": "out_layers_0", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + }, +} + + +def convert_diffusers_name_to_compvis(key): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_conv_in(.*)"): + return f"lora_unet_input_blocks_0_0{m[0]}" + + if match(m, r"lora_unet_conv_out(.*)"): + return f"lora_unet_out_2{m[0]}" + + if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): + return f"lora_unet_time_embed_{m[0] * 2 - 2}{m[1]}" + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"lora_unet_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return ( + f"lora_unet_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + ) + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"lora_unet_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"lora_unet_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"lora_unet_output_blocks_{2 + m[0] * 3}_2_conv" + return key + + +@torch.no_grad() +def merge(tes, unet, lyco_state_dict, scale: float = 1.0, device="cpu"): + from ..modules import make_module, get_module + + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + merged = 0 + + def merge_state_dict( + prefix, + root_module: torch.nn.Module, + lyco_state_dict: Dict[str, torch.Tensor], + ): + nonlocal merged + for child_name, child_module in tqdm( + list(root_module.named_modules()), desc=f"Merging {prefix}" + ): + lora_name = prefix + "." 
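Tracing convert_diffusers_name_to_compvis above through its regexes gives mappings like the following; the checks are worked out by hand from the patterns and are purely illustrative:

# Diffusers-style key                              -> CompVis/SAI-style key
assert convert_diffusers_name_to_compvis(
    "lora_unet_down_blocks_0_attentions_1_proj_in"
) == "lora_unet_input_blocks_2_1_proj_in"
assert convert_diffusers_name_to_compvis(
    "lora_unet_up_blocks_1_resnets_2_conv1"
) == "lora_unet_output_blocks_5_0_in_layers_2"     # conv1 maps via suffix_conversion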
+ child_name + lora_name = lora_name.replace(".", "_") + + lyco_type, params = get_module(lyco_state_dict, lora_name) + if lyco_type is None: + continue + module = make_module(lyco_type, params, lora_name, child_module) + if module is None: + continue + module.to(device) + module.merge_to(scale) + key_dict.pop(convert_diffusers_name_to_compvis(lora_name), None) + key_dict.pop(lora_name, None) + merged += 1 + + key_dict = {} + for k, v in tqdm(list(lyco_state_dict.items()), desc="Converting Dtype and Device"): + module, weight_key = k.split(".", 1) + convert_key = convert_diffusers_name_to_compvis(module) + if convert_key != module and len(tes) > 1: + # kohya's format for sdxl is as same as SGM, not diffusers + del lyco_state_dict[k] + key_dict[convert_key] = key_dict.get(convert_key, []) + [k] + k = f"{convert_key}.{weight_key}" + else: + key_dict[module] = key_dict.get(module, []) + [k] + lyco_state_dict[k] = v.float().cpu() + + for idx, te in enumerate(tes): + if len(tes) > 1: + prefix = LORA_PREFIX_TEXT_ENCODER + str(idx + 1) + else: + prefix = LORA_PREFIX_TEXT_ENCODER + merge_state_dict( + prefix, + te, + lyco_state_dict, + ) + torch.cuda.empty_cache() + merge_state_dict( + LORA_PREFIX_UNET, + unet, + lyco_state_dict, + ) + torch.cuda.empty_cache() + print(f"Unused state dict key: {key_dict}") + print(f"{merged} Modules been merged") diff --git a/lycoris/utils/general.py b/lycoris/utils/general.py new file mode 100644 index 0000000000000000000000000000000000000000..ccfaaa133f3c21b64637dedb4277a716f802a96e --- /dev/null +++ b/lycoris/utils/general.py @@ -0,0 +1,5 @@ +def product(xs: list[int | float]): + res = 1 + for x in xs: + res *= x + return res diff --git a/lycoris/utils/logger.py b/lycoris/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f51e575e982546c6d891b7511743514548b54e --- /dev/null +++ b/lycoris/utils/logger.py @@ -0,0 +1,35 @@ +import logging +import copy +import sys + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +# Create a new logger +logger = logging.getLogger("LyCORIS") +logger.propagate = False + +# Add handler if we don't have one. +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")) + logger.addHandler(handler) + +logger.setLevel(logging.DEBUG) +logger.debug("Logger initialized.") diff --git a/lycoris/utils/preset.py b/lycoris/utils/preset.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ec844a1a7063e51da236255c8807d7624b0e7a --- /dev/null +++ b/lycoris/utils/preset.py @@ -0,0 +1,9 @@ +import toml + + +def read_preset(preset): + try: + return toml.load(preset) + except Exception as e: + print("Error: cannot read preset file. 
", e) + return None diff --git a/lycoris/utils/quant.py b/lycoris/utils/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..aa4660e0a5d9a2b460f16c1beb6949e3734ae88d --- /dev/null +++ b/lycoris/utils/quant.py @@ -0,0 +1,88 @@ +from functools import cache + +SUPPORT_QUANT = False +try: + from bitsandbytes.nn import LinearNF4, Linear8bitLt, LinearFP4 + + SUPPORT_QUANT = True +except Exception: + import torch.nn as nn + + class LinearNF4(nn.Linear): + pass + + class Linear8bitLt(nn.Linear): + pass + + class LinearFP4(nn.Linear): + pass + + +try: + from quanto.nn import QLinear, QConv2d, QLayerNorm + + SUPPORT_QUANT = True +except Exception: + import torch.nn as nn + + class QLinear(nn.Linear): + pass + + class QConv2d(nn.Conv2d): + pass + + class QLayerNorm(nn.LayerNorm): + pass + + +try: + from optimum.quanto.nn import ( + QLinear as QLinearOpt, + QConv2d as QConv2dOpt, + QLayerNorm as QLayerNormOpt, + ) + + SUPPORT_QUANT = True +except Exception: + import torch.nn as nn + + class QLinearOpt(nn.Linear): + pass + + class QConv2dOpt(nn.Conv2d): + pass + + class QLayerNormOpt(nn.LayerNorm): + pass + + +from ..logging import logger + + +QuantLinears = ( + Linear8bitLt, + LinearFP4, + LinearNF4, + QLinear, + QConv2d, + QLayerNorm, + QLinearOpt, + QConv2dOpt, + QLayerNormOpt, +) + + +@cache +def log_bypass(): + return logger.warning( + "Using bnb/quanto/optimum-quanto with LyCORIS will enable force-bypass mode." + ) + + +@cache +def log_suspect(): + return logger.warning( + "Non-native Linear detected but bypass_mode is not set. " + "Automatically using force-bypass mode to avoid possible issues. " + "Please set bypass_mode=False explicitly if there are no quantized layers." + ) diff --git a/lycoris/utils/xformers_utils.py b/lycoris/utils/xformers_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..31f737ceafa7f6743100d3b64677cac08c1811fb --- /dev/null +++ b/lycoris/utils/xformers_utils.py @@ -0,0 +1,13 @@ +memory_efficient_attention = None +try: + import xformers +except Exception: + pass + +try: + from xformers.ops import memory_efficient_attention + + XFORMERS_AVAIL = True +except Exception: + memory_efficient_attention = None + XFORMERS_AVAIL = False diff --git a/lycoris/wrapper.py b/lycoris/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..3d53a11f3c045e89a1967d401ed352829ddf6655 --- /dev/null +++ b/lycoris/wrapper.py @@ -0,0 +1,640 @@ +# General LyCORIS wrapper based on kohya-ss/sd-scripts' style +import os +import fnmatch +import re +import logging + +from typing import Any, List + +import torch +import torch.nn as nn + +from .modules.locon import LoConModule +from .modules.loha import LohaModule +from .modules.lokr import LokrModule +from .modules.dylora import DyLoraModule +from .modules.glora import GLoRAModule +from .modules.norms import NormModule +from .modules.full import FullModule +from .modules.diag_oft import DiagOFTModule +from .modules.boft import ButterflyOFTModule +from .modules import get_module, make_module + +from .config import PRESET +from .utils.preset import read_preset +from .utils import str_bool +from .logging import logger + + +VALID_PRESET_KEYS = [ + "enable_conv", + "target_module", + "target_name", + "module_algo_map", + "name_algo_map", + "lora_prefix", + "use_fnmatch", + "unet_target_module", + "unet_target_name", + "text_encoder_target_module", + "text_encoder_target_name", + "exclude_name", +] + + +network_module_dict = { + "lora": LoConModule, + "locon": LoConModule, + "loha": 
LohaModule, + "lokr": LokrModule, + "dylora": DyLoraModule, + "glora": GLoRAModule, + "full": FullModule, + "diag-oft": DiagOFTModule, + "boft": ButterflyOFTModule, +} +deprecated_arg_dict = { + "disable_conv_cp": "use_tucker", + "use_cp": "use_tucker", + "use_conv_cp": "use_tucker", + "constrain": "constraint", +} + + +def create_lycoris(module, multiplier=1.0, linear_dim=4, linear_alpha=1, **kwargs): + for key, value in list(kwargs.items()): + if key in deprecated_arg_dict: + logger.warning( + f"{key} is deprecated. Please use {deprecated_arg_dict[key]} instead.", + stacklevel=2, + ) + kwargs[deprecated_arg_dict[key]] = value + if linear_dim is None: + linear_dim = 4 # default + conv_dim = int(kwargs.get("conv_dim", linear_dim) or linear_dim) + conv_alpha = float(kwargs.get("conv_alpha", linear_alpha) or linear_alpha) + dropout = float(kwargs.get("dropout", 0.0) or 0.0) + rank_dropout = float(kwargs.get("rank_dropout", 0.0) or 0.0) + module_dropout = float(kwargs.get("module_dropout", 0.0) or 0.0) + algo = (kwargs.get("algo", "lora") or "lora").lower() + use_tucker = str_bool( + not kwargs.get("disable_conv_cp", True) + or kwargs.get("use_conv_cp", False) + or kwargs.get("use_cp", False) + or kwargs.get("use_tucker", False) + ) + use_scalar = str_bool(kwargs.get("use_scalar", False)) + block_size = int(kwargs.get("block_size", 4) or 4) + train_norm = str_bool(kwargs.get("train_norm", False)) + constraint = float(kwargs.get("constraint", 0) or 0) + rescaled = str_bool(kwargs.get("rescaled", False)) + weight_decompose = str_bool(kwargs.get("dora_wd", False)) + wd_on_output = str_bool(kwargs.get("wd_on_output", False)) + full_matrix = str_bool(kwargs.get("full_matrix", False)) + bypass_mode = str_bool(kwargs.get("bypass_mode", None)) + unbalanced_factorization = str_bool(kwargs.get("unbalanced_factorization", False)) + + if unbalanced_factorization: + logger.info("Unbalanced factorization for LoKr is enabled") + + if bypass_mode: + logger.info("Bypass mode is enabled") + + if weight_decompose: + logger.info("Weight decomposition is enabled") + + if full_matrix: + logger.info("Full matrix mode for LoKr is enabled") + + preset = kwargs.get("preset", "full") + if preset not in PRESET: + preset = read_preset(preset) + else: + preset = PRESET[preset] + assert preset is not None + LycorisNetwork.apply_preset(preset) + + logger.info(f"Using rank adaptation algo: {algo}") + + network = LycorisNetwork( + module, + multiplier=multiplier, + lora_dim=linear_dim, + conv_lora_dim=conv_dim, + alpha=linear_alpha, + conv_alpha=conv_alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + use_tucker=use_tucker, + use_scalar=use_scalar, + network_module=algo, + train_norm=train_norm, + decompose_both=kwargs.get("decompose_both", False), + factor=kwargs.get("factor", -1), + block_size=block_size, + constraint=constraint, + rescaled=rescaled, + weight_decompose=weight_decompose, + wd_on_out=wd_on_output, + full_matrix=full_matrix, + bypass_mode=bypass_mode, + unbalanced_factorization=unbalanced_factorization, + ) + + return network + + +def create_lycoris_from_weights(multiplier, file, module, weights_sd=None, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + loras = {} + for key in weights_sd: + if "." 
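A hedged usage sketch for create_lycoris as defined above; the toy model, dims and algo choice are illustrative, and apply_to/prepare_optimizer_params come from the LycorisNetwork class later in this file:

import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 128))
net = create_lycoris(model, multiplier=1.0, linear_dim=8, linear_alpha=4.0, algo="lora")
net.apply_to()                                     # register the wrapped modules
param_groups = net.prepare_optimizer_params(lr=1e-4)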
not in key: + continue + + lora_name = key.split(".")[0] + loras[lora_name] = None + + for name, modules in module.named_modules(): + lora_name = f"{LycorisNetwork.LORA_PREFIX}_{name}".replace(".", "_") + if lora_name in loras: + loras[lora_name] = modules + + original_level = logger.level + logger.setLevel(logging.ERROR) + network = LycorisNetwork(module, init_only=True) + network.multiplier = multiplier + network.loras = [] + logger.setLevel(original_level) + + logger.info("Loading Modules from state dict...") + for lora_name, orig_modules in loras.items(): + if orig_modules is None: + continue + lyco_type, params = get_module(weights_sd, lora_name) + module = make_module(lyco_type, params, lora_name, orig_modules) + if module is not None: + network.loras.append(module) + network.algo_table[module.__class__.__name__] = ( + network.algo_table.get(module.__class__.__name__, 0) + 1 + ) + logger.info(f"{len(network.loras)} Modules Loaded") + + for lora in network.loras: + lora.multiplier = multiplier + + return network, weights_sd + + +class LycorisNetwork(torch.nn.Module): + ENABLE_CONV = True + TARGET_REPLACE_MODULE = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "GroupNorm", + "LayerNorm", + ] + TARGET_REPLACE_NAME = [] + LORA_PREFIX = "lycoris" + MODULE_ALGO_MAP = {} + NAME_ALGO_MAP = {} + USE_FNMATCH = False + TARGET_EXCLUDE_NAME = [] + + @classmethod + def apply_preset(cls, preset): + for preset_key in preset.keys(): + if preset_key not in VALID_PRESET_KEYS: + raise KeyError( + f'Unknown preset key "{preset_key}". Valid keys: {VALID_PRESET_KEYS}' + ) + + if "enable_conv" in preset: + cls.ENABLE_CONV = preset["enable_conv"] + if "target_module" in preset: + cls.TARGET_REPLACE_MODULE = preset["target_module"] + if "target_name" in preset: + cls.TARGET_REPLACE_NAME = preset["target_name"] + if "module_algo_map" in preset: + cls.MODULE_ALGO_MAP = preset["module_algo_map"] + if "name_algo_map" in preset: + cls.NAME_ALGO_MAP = preset["name_algo_map"] + if "lora_prefix" in preset: + cls.LORA_PREFIX = preset["lora_prefix"] + if "use_fnmatch" in preset: + cls.USE_FNMATCH = preset["use_fnmatch"] + if "exclude_name" in preset: + cls.TARGET_EXCLUDE_NAME = preset["exclude_name"] + return cls + + def __init__( + self, + module: nn.Module, + multiplier=1.0, + lora_dim=4, + conv_lora_dim=4, + alpha=1, + conv_alpha=1, + use_tucker=False, + dropout=0, + rank_dropout=0, + module_dropout=0, + network_module: str = "locon", + norm_modules=NormModule, + train_norm=False, + init_only=False, + **kwargs, + ) -> None: + super().__init__() + root_kwargs = kwargs + self.weights_sd = None + if init_only: + self.multiplier = 1 + self.lora_dim = 0 + self.alpha = 1 + self.conv_lora_dim = 0 + self.conv_alpha = 1 + self.dropout = 0 + self.rank_dropout = 0 + self.module_dropout = 0 + self.use_tucker = False + self.loras = [] + self.algo_table = {} + return + self.multiplier = multiplier + self.lora_dim = lora_dim + + if not self.ENABLE_CONV: + conv_lora_dim = 0 + + self.conv_lora_dim = int(conv_lora_dim) + if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim: + logger.info("Apply different lora dim for conv layer") + logger.info(f"Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}") + elif self.conv_lora_dim == 0: + logger.info("Disable conv layer") + + self.alpha = alpha + self.conv_alpha = float(conv_alpha) + if self.conv_lora_dim and self.alpha != self.conv_alpha: + logger.info("Apply different alpha value for conv layer") + logger.info(f"Conv alpha: {conv_alpha}, Linear alpha: {alpha}") + + if 1 >= 
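For reference, a preset (whether loaded from TOML via read_preset or built inline) is just a dict restricted to VALID_PRESET_KEYS; the class name and pattern below are made up for illustration:

example_preset = {
    "enable_conv": True,
    "target_module": ["Linear", "Conv2d"],
    "module_algo_map": {"CrossAttention": {"algo": "loha", "dim": 8}},
    "exclude_name": ["time_embed.*"],
}
LycorisNetwork.apply_preset(example_preset)        # unknown keys would raise KeyError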
dropout >= 0: + logger.info(f"Use Dropout value: {dropout}") + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.use_tucker = use_tucker + + def create_single_module( + lora_name: str, + module: torch.nn.Module, + algo_name, + dim=None, + alpha=None, + use_tucker=self.use_tucker, + **kwargs, + ): + for k, v in root_kwargs.items(): + if k in kwargs: + continue + kwargs[k] = v + + if train_norm and "Norm" in module.__class__.__name__: + return norm_modules( + lora_name, + module, + self.multiplier, + self.rank_dropout, + self.module_dropout, + **kwargs, + ) + lora = None + if isinstance(module, torch.nn.Linear) and lora_dim > 0: + dim = dim or lora_dim + alpha = alpha or self.alpha + elif isinstance( + module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d) + ): + k_size, *_ = module.kernel_size + if k_size == 1 and lora_dim > 0: + dim = dim or lora_dim + alpha = alpha or self.alpha + elif conv_lora_dim > 0 or dim: + dim = dim or conv_lora_dim + alpha = alpha or self.conv_alpha + else: + return None + else: + return None + lora = network_module_dict[algo_name]( + lora_name, + module, + self.multiplier, + dim, + alpha, + self.dropout, + self.rank_dropout, + self.module_dropout, + use_tucker, + **kwargs, + ) + return lora + + def create_modules_( + prefix: str, + root_module: torch.nn.Module, + algo, + current_lora_map: dict[str, Any], + configs={}, + ): + assert current_lora_map is not None, "No mapping supplied" + loras = current_lora_map + lora_names = [] + for name, module in root_module.named_modules(): + module_name = module.__class__.__name__ + if module_name in self.MODULE_ALGO_MAP and module is not root_module: + next_config = self.MODULE_ALGO_MAP[module_name] + next_algo = next_config.get("algo", algo) + new_loras, new_lora_names, new_lora_map = create_modules_( + f"{prefix}_{name}" if name else prefix, + module, + next_algo, + loras, + configs=next_config, + ) + loras = {**loras, **new_lora_map} + for lora_name, lora in zip(new_lora_names, new_loras): + if lora_name not in loras and lora_name not in current_lora_map: + loras[lora_name] = lora + if lora_name not in lora_names: + lora_names.append(lora_name) + continue + + if name: + lora_name = prefix + "." + name + else: + lora_name = prefix + + if f"{self.LORA_PREFIX}_." 
in lora_name: + lora_name = lora_name.replace( + f"{self.LORA_PREFIX}_.", + f"{self.LORA_PREFIX}.", + ) + + lora_name = lora_name.replace(".", "_") + if lora_name in loras: + continue + + lora = create_single_module(lora_name, module, algo, **configs) + if lora is not None: + loras[lora_name] = lora + lora_names.append(lora_name) + return [loras[lora_name] for lora_name in lora_names], lora_names, loras + + # create module instances + def create_modules( + prefix, + root_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[], + target_exclude_names=[], + ) -> List: + logger.info("Create LyCORIS Module") + loras = [] + lora_map = {} + next_config = {} + for name, module in root_module.named_modules(): + if name in target_exclude_names or any( + self.match_fn(t, name) for t in target_exclude_names + ): + continue + + module_name = module.__class__.__name__ + if module_name in target_replace_modules and not any( + self.match_fn(t, name) for t in target_replace_names + ): + if module_name in self.MODULE_ALGO_MAP: + next_config = self.MODULE_ALGO_MAP[module_name] + algo = next_config.get("algo", network_module) + else: + algo = network_module + + lora_lst, _, _lora_map = create_modules_( + f"{prefix}_{name}", + module, + algo, + lora_map, + configs=next_config, + ) + lora_map = {**lora_map, **_lora_map} + loras.extend(lora_lst) + next_config = {} + elif name in target_replace_names or any( + self.match_fn(t, name) for t in target_replace_names + ): + conf_from_name = self.find_conf_for_name(name) + if conf_from_name is not None: + next_config = conf_from_name + algo = next_config.get("algo", network_module) + elif module_name in self.MODULE_ALGO_MAP: + next_config = self.MODULE_ALGO_MAP[module_name] + algo = next_config.get("algo", network_module) + else: + algo = network_module + lora_name = prefix + "." + name + lora_name = lora_name.replace(".", "_") + + if lora_name in lora_map: + continue + + lora = create_single_module(lora_name, module, algo, **next_config) + next_config = {} + if lora is not None: + lora_map[lora.lora_name] = lora + loras.append(lora) + return loras + + self.loras = create_modules( + LycorisNetwork.LORA_PREFIX, + module, + list( + set( + [ + *LycorisNetwork.TARGET_REPLACE_MODULE, + *LycorisNetwork.MODULE_ALGO_MAP.keys(), + ] + ) + ), + list( + set( + [ + *LycorisNetwork.TARGET_REPLACE_NAME, + *LycorisNetwork.NAME_ALGO_MAP.keys(), + ] + ) + ), + target_exclude_names=LycorisNetwork.TARGET_EXCLUDE_NAME, + ) + logger.info(f"create LyCORIS: {len(self.loras)} modules.") + + algo_table = {} + for lora in self.loras: + algo_table[lora.__class__.__name__] = ( + algo_table.get(lora.__class__.__name__, 0) + 1 + ) + logger.info(f"module type table: {algo_table}") + + # Assertion to ensure we have not accidentally wrapped some layers + # multiple times. 
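Reading create_modules above: exclusion patterns are applied first, name-pattern targets take precedence over class-based targets, and match_fn (defined just below) matches with regular expressions unless USE_FNMATCH switches it to fnmatch. A compact restatement, as comments:

# For each (name, module) under the root module:
#   1. name matches TARGET_EXCLUDE_NAME                      -> skipped
#   2. class in TARGET_REPLACE_MODULE and name matches no
#      TARGET_REPLACE_NAME pattern                           -> wrapped; algo from MODULE_ALGO_MAP
#                                                               or the network-wide default
#   3. name matches TARGET_REPLACE_NAME / NAME_ALGO_MAP      -> wrapped; algo from NAME_ALGO_MAP,
#                                                               else MODULE_ALGO_MAP, else default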
+ names = set() + for lora in self.loras: + assert ( + lora.lora_name not in names + ), f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def match_fn(self, pattern: str, name: str) -> bool: + if self.USE_FNMATCH: + return fnmatch.fnmatch(name, pattern) + return bool(re.match(pattern, name)) + + def find_conf_for_name( + self, + name: str, + ) -> dict[str, Any]: + if name in self.NAME_ALGO_MAP.keys(): + return self.NAME_ALGO_MAP[name] + + for key, value in self.NAME_ALGO_MAP.items(): + if self.match_fn(key, name): + return value + + return None + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + self.weights_sd = load_file(file) + else: + self.weights_sd = torch.load(file, map_location="cpu") + missing, unexpected = self.load_state_dict(self.weights_sd, strict=False) + state = {} + if missing: + state["missing keys"] = missing + if unexpected: + state["unexpected keys"] = unexpected + return state + + def apply_to(self): + """ + Register to modules to the subclass so that torch sees them. + """ + for lora in self.loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + if self.weights_sd: + # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) + info = self.load_state_dict(self.weights_sd, False) + logger.info(f"weights are loaded: {info}") + + def is_mergeable(self): + return True + + def restore(self): + for lora in self.loras: + lora.restore() + + def merge_to(self, weight=1.0): + for lora in self.loras: + lora.merge_to(weight) + + def apply_max_norm_regularization(self, max_norm_value, device): + key_scaled = 0 + norms = [] + for module in self.loras: + scaled, norm = module.apply_max_norm(max_norm_value, device) + if scaled is None: + continue + norms.append(norm) + key_scaled += scaled + + if key_scaled == 0: + return key_scaled, 0, 0 + + return key_scaled, sum(norms) / len(norms), max(norms) + + def enable_gradient_checkpointing(self): + # not supported + def make_ckpt(module): + if isinstance(module, torch.nn.Module): + module.grad_ckpt = True + + self.apply(make_ckpt) + pass + + def prepare_optimizer_params(self, lr): + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + self.requires_grad_(True) + all_params = [] + + param_data = {"params": enumerate_params(self.loras)} + if lr is not None: + param_data["lr"] = lr + all_params.append(param_data) + return all_params + + def prepare_grad_etc(self, *args): + self.requires_grad_(True) + + def on_epoch_start(self, *args): + self.train() + + def get_trainable_params(self, *args): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) diff --git a/models.yaml b/models.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..81b9992c30cf38dfec566c98a839142b99036da9 --- /dev/null +++ b/models.yaml @@ -0,0 +1,27 @@ +# Add your own model here +#: +# repo: +# base: +# license: +# license_name: +# license_link: +# file: +flux-dev: + repo: cocktailpeanut/xulf-dev + base: black-forest-labs/FLUX.1-dev + license: other + license_name: flux-1-dev-non-commercial-license + license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md + file: flux1-dev.sft +flux-schnell: + repo: black-forest-labs/FLUX.1-schnell + base: black-forest-labs/FLUX.1-schnell + license: apache-2.0 + file: flux1-schnell.safetensors +bdsqlsz/flux1-dev2pro-single: + repo: bdsqlsz/flux1-dev2pro-single + base: black-forest-labs/FLUX.1-dev + license: other + license_name: flux-1-dev-non-commercial-license + license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md + file: flux1-dev2pro.safetensors diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..d432181baaa3f89b2679a5780e03be18d6be1801 --- /dev/null +++ b/networks/check_lora_weights.py @@ -0,0 +1,48 @@ +import argparse +import os +import torch +from safetensors.torch import load_file +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +def main(file): + logger.info(f"loading: {file}") + if os.path.splitext(file)[1] == ".safetensors": + sd = load_file(file) + else: + sd = torch.load(file, map_location="cpu") + + values = [] + + keys = list(sd.keys()) + for key in keys: + if "lora_up" in key or "lora_down" in key: + values.append((key, sd[key])) + print(f"number of LoRA modules: {len(values)}") + + if args.show_all_keys: + for key in [k for k in keys if k not in values]: + values.append((key, sd[key])) + print(f"number of all modules: {len(values)}") + + for key, value in values: + value = value.to(torch.float32) + print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") + parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + + main(args.file) diff --git a/networks/lora.py b/networks/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..2e6ff1ddfd388c8244cfa4d457f1464407f3b18d --- /dev/null +++ b/networks/lora.py @@ -0,0 +1,1403 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +from ..library.utils import setup_logging +from ..library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. 
+ """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + + # check regional or not by lora_name + self.text_encoder = False + if lora_name.startswith("lora_te_"): + self.regional = False + self.use_sub_prompt = True + self.text_encoder = True + elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name: + self.regional = 
False + self.use_sub_prompt = True + elif "time_emb" in lora_name: + self.regional = False + self.use_sub_prompt = False + else: + self.regional = True + self.use_sub_prompt = False + + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"].to(torch.float) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + + if self.network is None or self.network.sub_prompt_index is None: + return self.default_forward(x) + if not self.regional and not self.use_sub_prompt: + return self.default_forward(x) + + if self.regional: + return self.regional_forward(x) + else: + return self.sub_prompt_forward(x) + + def get_mask_for_x(self, x): + # calculate size from shape of x + if len(x.size()) == 4: + h, w = x.size()[2:4] + area = h * w + else: + area = x.size()[1] + + mask = self.network.mask_dic.get(area, None) + if mask is None or len(x.size()) == 2: + # emb_layers in SDXL doesn't have mask + # if "emb" not in self.lora_name: + # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") + mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) + return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts + if len(x.size()) == 3: + mask = torch.reshape(mask, (1, -1, 1)) + return mask + + def regional_forward(self, x): + if "attn2_to_out" in self.lora_name: + return 
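For the conv2d 3x3 case in merge_to/get_weight above, the up (rank -> out, 1x1) and down (in -> rank, 3x3) kernels compose into a single 3x3 delta kernel via the conv2d-with-permutes trick. A shape-level sketch with made-up channel counts:

import torch
import torch.nn.functional as F

out_ch, in_ch, rank = 16, 8, 4
down = torch.randn(rank, in_ch, 3, 3)              # lora_down.weight
up = torch.randn(out_ch, rank, 1, 1)               # lora_up.weight
delta = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
assert delta.shape == (out_ch, in_ch, 3, 3)        # same shape as the base kernel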
self.to_out_forward(x) + + if self.network.mask_dic is None: # sub_prompt_index >= 3 + return self.default_forward(x) + + # apply mask for LoRA result + lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + mask = self.get_mask_for_x(lx) + # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked) + # mask = mask.squeeze(-1) + lx = lx * mask + + x = self.org_forward(x) + x = x + lx + + if "attn2_to_q" in self.lora_name and self.network.is_last_network: + x = self.postp_to_q(x) + + return x + + def postp_to_q(self, x): + # repeat x to num_sub_prompts + has_real_uncond = x.size()[0] // self.network.batch_size == 3 + qc = self.network.batch_size # uncond + qc += self.network.batch_size * self.network.num_sub_prompts # cond + if has_real_uncond: + qc += self.network.batch_size # real_uncond + + query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype) + query[: self.network.batch_size] = x[: self.network.batch_size] + + for i in range(self.network.batch_size): + qi = self.network.batch_size + i * self.network.num_sub_prompts + query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i] + + if has_real_uncond: + query[-self.network.batch_size :] = x[-self.network.batch_size :] + + # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}") + return query + + def sub_prompt_forward(self, x): + if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA + return self.org_forward(x) + + emb_idx = self.network.sub_prompt_index + if not self.text_encoder: + emb_idx += self.network.batch_size + + # apply sub prompt of X + lx = x[emb_idx :: self.network.num_sub_prompts] + lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale + + # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}") + + x = self.org_forward(x) + x[emb_idx :: self.network.num_sub_prompts] += lx + + return x + + def to_out_forward(self, x): + # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}") + + if self.network.is_last_network: + masks = [None] * self.network.num_sub_prompts + self.network.shared[self.lora_name] = (None, masks) + else: + lx, masks = self.network.shared[self.lora_name] + + # call own LoRA + x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts] + lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale + + if self.network.is_last_network: + lx = torch.zeros( + (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype + ) + self.network.shared[self.lora_name] = (lx, masks) + + # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") + lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 + masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) + + # if not last network, return x and masks + x = self.org_forward(x) + if not self.network.is_last_network: + return x + + lx, masks = self.network.shared.pop(self.lora_name) + + # if last network, combine separated x with mask weighted sum + has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2 + + out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), 
device=x.device, dtype=x.dtype) + out[: self.network.batch_size] = x[: self.network.batch_size] # uncond + if has_real_uncond: + out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond + + # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") + # if num_sub_prompts > num of LoRAs, fill with zero + for i in range(len(masks)): + if masks[i] is None: + masks[i] = torch.zeros_like(masks[0]) + + mask = torch.cat(masks) + mask_sum = torch.sum(mask, dim=0) + 1e-4 + for i in range(self.network.batch_size): + # 1枚の画像ごとに処理する + lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts] + lx1 = lx1 * mask + lx1 = torch.sum(lx1, dim=0) + + xi = self.network.batch_size + i * self.network.num_sub_prompts + x1 = x[xi : xi + self.network.num_sub_prompts] + x1 = x1 * mask + x1 = torch.sum(x1, dim=0) + x1 = x1 / mask_sum + + x1 = x1 + lx1 + out[self.network.batch_size + i] = x1 + + # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}") + return out + + +def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]: + down_lr_weight = nw_kwargs.get("down_lr_weight", None) + mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) + up_lr_weight = nw_kwargs.get("up_lr_weight", None) + + # 以上のいずれにも設定がない場合は無効としてNoneを返す + if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: + return None + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")] + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + return get_block_lr_weight( + is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # block dim/alpha/lr + block_dims = kwargs.get("block_dims", None) + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) + + # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする + if block_dims is not None or block_lr_weight is not None: + block_alphas = kwargs.get("block_alphas", None) + conv_block_dims = kwargs.get("conv_block_dims", None) + conv_block_alphas = kwargs.get("conv_block_alphas", None) + + block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + ) + + # remove block dim/alpha without 
learning rate + block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight + ) + + else: + block_alphas = None + conv_block_dims = None + conv_block_alphas = None + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + block_dims=block_dims, + block_alphas=block_alphas, + conv_block_dims=conv_block_dims, + conv_block_alphas=conv_block_alphas, + varbose=True, + is_sdxl=is_sdxl, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) + + return network + + +# このメソッドは外部から呼び出される可能性を考慮しておく +# network_dim, network_alpha にはデフォルト値が入っている。 +# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている +# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている +def get_block_dims_and_alphas( + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha +): + if not is_sdxl: + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS + else: + # 1+9+3+9+1=23, no LoRA for emb_layers (0) + num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 + + def parse_ints(s): + return [int(i) for i in s.split(",")] + + def parse_floats(s): + return [float(i) for i in s.split(",")] + + # block_dimsとblock_alphasをパースする。必ず値が入る + if block_dims is not None: + block_dims = parse_ints(block_dims) + assert len(block_dims) == num_total_blocks, ( + f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given" + + f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})" + ) + else: + logger.warning( + f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) + block_dims = [network_dim] * num_total_blocks + + if block_alphas is not None: + block_alphas = parse_floats(block_alphas) + assert ( + len(block_alphas) == num_total_blocks + ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" + else: + logger.warning( + f"block_alphas is not specified. 
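get_block_dims_and_alphas above expects one entry per block: 25 for SD1.x (12 down + 1 mid + 12 up) and 23 for SDXL (1 emb + 9 input + 3 mid + 9 output + 1 out), per the NUM_OF_* constants defined on LoRANetwork further down. A tiny sanity check of the comma-separated form; the values are arbitrary:

block_dims = ",".join(["4"] * 25)          # uniform rank 4 for an SD1.x U-Net
assert len([int(x) for x in block_dims.split(",")]) == 25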
all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" + ) + block_alphas = [network_alpha] * num_total_blocks + + # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う + if conv_block_dims is not None: + conv_block_dims = parse_ints(conv_block_dims) + assert ( + len(conv_block_dims) == num_total_blocks + ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください" + + if conv_block_alphas is not None: + conv_block_alphas = parse_floats(conv_block_alphas) + assert ( + len(conv_block_alphas) == num_total_blocks + ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください" + else: + if conv_alpha is None: + conv_alpha = 1.0 + logger.warning( + f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" + ) + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + if conv_dim is not None: + logger.warning( + f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" + ) + conv_block_dims = [conv_dim] * num_total_blocks + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + conv_block_dims = None + conv_block_alphas = None + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく +# 戻り値は block ごとの倍率のリスト +def get_block_lr_weight( + is_sdxl, + down_lr_weight: Union[str, List[float]], + mid_lr_weight: List[float], + up_lr_weight: Union[str, List[float]], + zero_threshold: float, +) -> Optional[List[float]]: + # パラメータ未指定時は何もせず、今までと同じ動作とする + if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: + return None + + if not is_sdxl: + max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS + else: + max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + + def get_list(name_with_suffix) -> List[float]: + import math + + tokens = name_with_suffix.split("+") + name = tokens[0] + base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 + + if name == "cosine": + return [ + math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr + for i in reversed(range(max_len_for_down_or_up)) + ] + elif name == "sine": + return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)] + elif name == "linear": + return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)] + elif name == "reverse_linear": + return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))] + elif name == "zeros": + return [0.0 + base_lr] * max_len_for_down_or_up + else: + logger.error( + "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" + % (name) + ) + return None + + if type(down_lr_weight) == str: + down_lr_weight = get_list(down_lr_weight) + if type(up_lr_weight) == str: + up_lr_weight = get_list(up_lr_weight) + + if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." 
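The named schedules accepted by get_list above ("cosine", "sine", "linear", "reverse_linear", "zeros", optionally with a "+base" suffix) expand to one weight per down/up block. For example, "sine+1" over the 12 SD1.x blocks evaluates to the values below (rounded, computed from the formula above):

import math

max_len = 12                               # LoRANetwork.NUM_OF_BLOCKS for SD1.x
weights = [math.sin(math.pi * (i / (max_len - 1)) / 2) + 1.0 for i in range(max_len)]
print([round(w, 2) for w in weights])
# [1.0, 1.14, 1.28, 1.42, 1.54, 1.65, 1.76, 1.84, 1.91, 1.96, 1.99, 2.0]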
% max_len_for_down_or_up) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up) + up_lr_weight = up_lr_weight[:max_len_for_down_or_up] + down_lr_weight = down_lr_weight[:max_len_for_down_or_up] + + if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid: + logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid) + logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight[:max_len_for_mid] + + if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up) + logger.warning( + "down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up + ) + + if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up: + down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up: + up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight)) + + if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid: + logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid) + logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight)) + + if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): + logger.info("apply block learning rate / 階層別学習率を適用します。") + if down_lr_weight != None: + down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") + else: + down_lr_weight = [1.0] * max_len_for_down_or_up + logger.info("down_lr_weight: all 1.0, すべて1.0") + + if mid_lr_weight != None: + mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight] + logger.info(f"mid_lr_weight: {mid_lr_weight}") + else: + mid_lr_weight = [1.0] * max_len_for_mid + logger.info("mid_lr_weight: all 1.0, すべて1.0") + + if up_lr_weight != None: + up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") + else: + up_lr_weight = [1.0] * max_len_for_down_or_up + logger.info("up_lr_weight: all 1.0, すべて1.0") + + lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight + + if is_sdxl: + lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out + + assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or ( + is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 + ), f"lr_weight length is invalid: {len(lr_weight)}" + + return lr_weight + + +# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく +def remove_block_dims_and_alphas( + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]] +): + if block_lr_weight is not None: + for i, lr in enumerate(block_lr_weight): + if lr == 0: + block_dims[i] = 0 + if conv_block_dims is not None: + conv_block_dims[i] = 0 + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 外部から呼び出す可能性を考慮しておく +def get_block_index(lora_name: str, is_sdxl: bool = 
False) -> int: + block_idx = -1 # invalid lora name + if not is_sdxl: + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 + + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + else: + # copy from sdxl_train + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA + block_idx = 0 # 0 + elif name.startswith("input_blocks_"): # 1-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10-12 + block_idx = 10 + int(name.split("_")[2]) + elif name.startswith("output_blocks_"): # 13-21 + block_idx = 13 + int(name.split("_")[2]) + elif name.startswith("out_"): # 22, out, no LoRA + block_idx = 22 + + return block_idx + + +def convert_diffusers_to_sai_if_needed(weights_sd): + # only supports U-Net LoRA modules + + found_up_down_blocks = False + for k in list(weights_sd.keys()): + if "down_blocks" in k: + found_up_down_blocks = True + break + if "up_blocks" in k: + found_up_down_blocks = True + break + if not found_up_down_blocks: + return + + from ..library.sdxl_model_util import make_unet_conversion_map + + unet_conversion_map = make_unet_conversion_map() + unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map} + + # # add extra conversion + # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv" + + logger.info(f"Converting LoRA keys from Diffusers to SAI") + lora_unet_prefix = "lora_unet_" + for k in list(weights_sd.keys()): + if not k.startswith(lora_unet_prefix): + continue + + unet_module_name = k[len(lora_unet_prefix) :].split(".")[0] + + # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small + for hf_module_name, sd_module_name in unet_conversion_map.items(): + if hf_module_name in unet_module_name: + new_key = ( + lora_unet_prefix + + unet_module_name.replace(hf_module_name, sd_module_name) + + k[len(lora_unet_prefix) + len(unet_module_name) :] + ) + weights_sd[new_key] = weights_sd.pop(k) + found = True + break + + if not found: + logger.warning(f"Key {k} is not found in unet_conversion_map") + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # if keys are Diffusers based, convert to SAI based + convert_diffusers_to_sai_if_needed(weights_sd) + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." 
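# A standalone sketch of the dim/alpha scan above, used when a network is rebuilt from a
# weights file: the rank comes from the first axis of each "lora_down" weight and the alpha
# from the stored scalar. The file name is a placeholder.
from safetensors.torch import load_file

weights_sd = load_file("my_lora.safetensors")  # placeholder path

modules_dim, modules_alpha = {}, {}
for key, value in weights_sd.items():
    if "." not in key:
        continue
    lora_name = key.split(".")[0]
    if "alpha" in key:
        modules_alpha[lora_name] = value
    elif "lora_down" in key:
        modules_dim[lora_name] = value.size()[0]  # rank = rows of the down projection

# old files without alpha fall back to alpha == rank, i.e. scale 1.0
for name, dim in modules_dim.items():
    modules_alpha.setdefault(name, dim)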
not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha[key] = modules_dim[key] + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) + + # block lr + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) + + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + NUM_OF_MID_BLOCKS = 1 + SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23 + SDXL_NUM_OF_MID_BLOCKS = 3 + + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + block_dims: Optional[List[int]] = None, + block_alphas: Optional[List[float]] = None, + conv_block_dims: Optional[List[int]] = None, + conv_block_alphas: Optional[List[float]] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + is_sdxl: Optional[bool] = False, + ) -> None: + """ + LoRA network: すごく引数が多いが、パターンは以下の通り + 1. lora_dimとalphaを指定 + 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 + 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない + 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する + 5. modules_dimとmodules_alphaを指定 (推論用) + """ + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + elif block_dims is not None: + logger.info(f"create LoRA network from block_dims") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") + else: + logger.info(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + if self.conv_lora_dim is not None: + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) + + # create module instances + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_UNET + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif is_unet and block_dims is not None: + # U-Netでblock_dims指定あり + block_idx = get_block_index(lora_name, is_sdxl) + if is_linear or is_conv2d_1x1: + dim = block_dims[block_idx] + alpha = block_alphas[block_idx] + elif conv_block_dims is not None: + dim = conv_block_dims[block_idx] + alpha = conv_block_alphas[block_idx] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + logger.info(f"create LoRA for Text Encoder {index}:") + else: + index = None + logger.info(f"create LoRA for Text Encoder:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + logger.info(f"create LoRA for U-Net: 
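# A toy, self-contained sketch of the module scan that create_modules performs above: walk
# the model, match container classes by name, and collect their Linear/Conv2d children under
# a flattened "prefix_parent_child" LoRA name. The toy model and target list are assumptions.
import torch

def find_lora_targets(root: torch.nn.Module, target_classes, prefix="lora_unet"):
    found = []
    for name, module in root.named_modules():
        if module.__class__.__name__ in target_classes:
            for child_name, child in module.named_modules():
                if child.__class__.__name__ in ("Linear", "Conv2d"):
                    lora_name = f"{prefix}.{name}.{child_name}".replace(".", "_")
                    found.append((lora_name, child))
    return found

class Block(torch.nn.Module):  # stands in for e.g. Transformer2DModel
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

toy = torch.nn.Sequential(Block())
print([n for n, _ in find_lora_targets(toy, ["Block"])])  # ['lora_unet_0_proj']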
{len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + logger.warning( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + self.block_lr_weight = None + self.block_lr = False + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない + def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]): + self.block_lr = True + self.block_lr_weight = block_lr_weight + + def get_lr_weight(self, block_idx: int) -> float: + if not self.block_lr or self.block_lr_weight is None: + return 1.0 + return self.block_lr_weight[block_idx] + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # TODO warn if optimizer is 
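# A hedged sketch of the arithmetic merge_to applies above for a plain Linear layer:
# W' = W + multiplier * (up @ down) * (alpha / rank). Shapes and values are stand-ins.
import torch

rank, in_dim, out_dim = 4, 16, 32
multiplier, alpha = 1.0, 1.0
W    = torch.randn(out_dim, in_dim)
down = torch.randn(rank, in_dim) * 0.01
up   = torch.zeros(out_dim, rank)        # a freshly initialised LoRA has up == 0

merged = W + multiplier * (up @ down) * (alpha / rank)
assert torch.allclose(merged, W)         # zero up projection leaves the base weight unchanged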
not compatible with LoRA+ (but it will cause error so we don't need to check it here?) + # if ( + # self.loraplus_lr_ratio is not None + # or self.loraplus_text_encoder_lr_ratio is not None + # or self.loraplus_unet_lr_ratio is not None + # ): + # assert ( + # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() + # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + params, descriptions = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + if self.block_lr: + is_sdxl = False + for lora in self.unet_loras: + if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: + is_sdxl = True + break + + # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 + block_idx_to_lora = {} + for lora in self.unet_loras: + idx = get_block_index(lora.lora_name, is_sdxl) + if idx not in block_idx_to_lora: + block_idx_to_lora[idx] = [] + block_idx_to_lora[idx].append(lora) + + # blockごとにパラメータを設定する + for idx, block_loras in block_idx_to_lora.items(): + params, descriptions = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) + + else: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import 
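# A small sketch of the LoRA+ grouping that assemble_params builds above: "lora_up"
# parameters get lr * ratio, everything else gets the plain lr. The two Linear layers
# stand in for one LoRA module's down/up pair.
import torch

base_lr, loraplus_ratio = 1e-4, 16.0
lora_down = torch.nn.Linear(16, 4, bias=False)
lora_up   = torch.nn.Linear(4, 16, bias=False)

optimizer = torch.optim.AdamW([
    {"params": lora_down.parameters(), "lr": base_lr},
    {"params": lora_up.parameters(),   "lr": base_lr * loraplus_ratio},
])
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.0016]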
save_file + from ..library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + # mask is a tensor with values from 0 to 1 + def set_region(self, sub_prompt_index, is_last_network, mask): + if mask.max() == 0: + mask = torch.ones_like(mask) + + self.mask = mask + self.sub_prompt_index = sub_prompt_index + self.is_last_network = is_last_network + + for lora in self.text_encoder_loras + self.unet_loras: + lora.set_network(self) + + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None): + self.batch_size = batch_size + self.num_sub_prompts = num_sub_prompts + self.current_size = (height, width) + self.shared = shared + + # create masks + mask = self.mask + mask_dic = {} + mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w + ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight + dtype = ref_weight.dtype + device = ref_weight.device + + def resize_add(mh, mw): + # logger.info(mh, mw, mh * mw) + m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 + m = m.to(device, dtype=dtype) + mask_dic[mh * mw] = m + + h = height // 8 + w = width // 8 + for _ in range(4): + resize_add(h, w) + if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 + resize_add(h + h % 2, w + w % 2) + + # deep shrink + if ds_ratio is not None: + hd = int(h * ds_ratio) + wd = int(w * ds_ratio) + resize_add(hd, wd) + + h = (h + 1) // 2 + w = (w + 1) // 2 + + self.mask_dic = mask_dic + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = 
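# A simplified sketch of the mask pyramid that set_current_generation builds above: the
# region mask is resized to each latent resolution and stored keyed by h*w so attention
# layers can look it up by sequence length. The 512px/latent-64 sizes are illustrative,
# and the odd-size and deep-shrink variants are omitted.
import torch

mask = torch.ones(1, 1, 64, 64)   # b, c, h, w in latent space
mask_dic = {}

h, w = 64, 64
for _ in range(4):
    mask_dic[h * w] = torch.nn.functional.interpolate(mask, (h, w), mode="bilinear")
    h, w = (h + 1) // 2, (w + 1) // 2

print(sorted(mask_dic.keys()))  # [64, 256, 1024, 4096]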
state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/lora_flux.py b/networks/lora_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb863332f521af45ebd40eda4c19b619987131b --- /dev/null +++ b/networks/lora_flux.py @@ -0,0 +1,1032 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +#from ..library.utils import setup_logging + +#setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. 
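# A standalone sketch of apply_max_norm_regularization above for one Linear LoRA pair: the
# effective update up @ down * (alpha / rank) is norm-clipped to max_norm_value by rescaling
# both factors with sqrt(ratio). Tensors are random stand-ins.
import torch

max_norm_value = 1.0
down  = torch.randn(4, 16)
up    = torch.randn(32, 4)
alpha = torch.tensor(4.0)
scale = alpha / down.shape[0]               # alpha / rank

updown = (up @ down) * scale
norm = updown.norm().clamp(min=max_norm_value / 2)
ratio = torch.clamp(norm, max=max_norm_value) / norm   # <= 1.0 when the update is too large
if ratio != 1:
    up, down = up * ratio**0.5, down * ratio**0.5      # up@down shrinks by exactly `ratio`
print(float(((up @ down) * scale).norm()))              # ~= max_norm_value after clipping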
+ """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + split_dims: Optional[List[int]] = None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + 
else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + else: + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + 
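# A scaled-down sketch of the split_dims merge above (FLUX qkv LoRA): each split has its own
# down/up pair, and the up projection is zero-padded to the full output width so the per-split
# update can be added onto the combined qkv weight. [8, 8, 8] stands in for [3072] * 3.
import torch

split_dims = [8, 8, 8]
in_dim, rank = 16, 4
total = sum(split_dims)
weight = torch.zeros(total, in_dim)        # combined qkv weight of the original Linear

for i, split_dim in enumerate(split_dims):
    down = torch.randn(rank, in_dim)       # sd[f"lora_down.{i}.weight"]
    up   = torch.randn(split_dim, rank)    # sd[f"lora_up.{i}.weight"]

    padded_up = torch.zeros(total, rank)
    padded_up[sum(split_dims[:i]) : sum(split_dims[: i + 1])] = up

    weight = weight + padded_up @ down     # multiplier and alpha/rank scaling omitted for brevity

print(weight.shape)  # torch.Size([24, 16])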
down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + + only_if_contains = kwargs.get("only_if_contains", None) + if only_if_contains is not None: + only_if_contains = [word.strip() for word in only_if_contains.split(',')] + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + train_blocks=train_blocks, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + varbose=True, + only_if_contains=only_if_contains + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if 
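# A sketch of how the string-valued network_args handled above are coerced: every value
# arrives as text (e.g. from --network_args "split_qkv=True" "train_blocks=single"), so
# booleans are compared against "True" and lists are comma-split. Values are illustrative.
kwargs = {
    "split_qkv": "True",
    "train_blocks": "single",
    "only_if_contains": "single_blocks_20_linear2, single_blocks_21_linear2",
}

split_qkv = kwargs.get("split_qkv", False) == "True"
train_blocks = kwargs.get("train_blocks", None)
assert train_blocks in (None, "all", "single", "double")

only_if_contains = kwargs.get("only_if_contains", None)
if only_if_contains is not None:
    only_if_contains = [w.strip() for w in only_if_contains.split(",")]

print(split_qkv, train_blocks, only_if_contains)
# True single ['single_blocks_20_linear2', 'single_blocks_21_linear2']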
loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] 
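# A tiny sketch of the train_t5xxl detection above: a weights file is treated as training
# T5-XXL iff any of its LoRA module names carry the ComfyUI-compatible "lora_te3" prefix.
# The keys below are illustrative.
keys = [
    "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
    "lora_unet_double_blocks_0_img_attn_qkv.alpha",
]
train_t5xxl = any("lora_te3" in k.split(".")[0] for k in keys)
print(train_t5xxl)  # False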
= None, + modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + varbose: Optional[bool] = False, + only_if_contains: Optional[List[str]] = None, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + self.only_if_contains = only_if_contains + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + if train_t5xxl: + logger.info(f"train T5XXL as well") + + #self.only_if_contains = ["lora_unet_single_blocks_20_linear2"] + + # create module instances + def create_modules( + is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_FLUX + if is_flux + else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5) + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." 
+ child_name + lora_name = lora_name.replace(".", "_") + #lora_unet_single_blocks_20_linear2 + + if "unet" in lora_name and (self.only_if_contains is not None and not any(word in lora_name for word in self.only_if_contains)): + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = [3072] * 3 + elif "single" in lora_name and "linear1" in lora_name: + split_dims = [3072] * 3 + [12288] + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "single": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "double": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + #print(self.unet_loras) + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, 
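# A short sketch of the only_if_contains filter applied above: a FLUX (U-Net side) module is
# kept only if its flattened LoRA name contains one of the given substrings. Names are examples.
only_if_contains = ["single_blocks_20_linear2"]   # e.g. restrict training to one layer
candidates = [
    "lora_unet_single_blocks_19_linear2",
    "lora_unet_single_blocks_20_linear2",
    "lora_unet_double_blocks_0_img_attn_qkv",
]
kept = [n for n in candidates if any(word in n for word in only_if_contains)]
print(kept)  # ['lora_unet_single_blocks_20_linear2']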
map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, 
{up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) 
+ descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + 
sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + import hashlib + import safetensors.torch + from io import BytesIO + + # Retain only training metadata for hash calculation + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + return hash_sha256.hexdigest() + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash \ No newline at end of file diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py new file mode 100644 index 0000000000000000000000000000000000000000..c4513eb227859af49c56afe175be1514a5501a4f --- /dev/null +++ b/networks/lora_sd3.py @@ -0,0 +1,837 @@ +# temporary minimum implementation of LoRA +# SD3 doesn't have Conv2d, so we ignore it +# TODO commonize with the original/SD3/FLUX implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from transformers import CLIPTextModelWithProjection, T5EncoderModel +import torch +from ..library.utils import setup_logging + +setup_logging() +import logging + 
+logger = logging.getLogger(__name__) + +from .lora_flux import LoRAModule, LoRAInfModule +from ..library import sd3_models + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: sd3_models.SDVAE, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + mmdit, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + context_attn_dim = kwargs.get("context_attn_dim", None) + context_mlp_dim = kwargs.get("context_mlp_dim", None) + context_mod_dim = kwargs.get("context_mod_dim", None) + x_attn_dim = kwargs.get("x_attn_dim", None) + x_mlp_dim = kwargs.get("x_mlp_dim", None) + x_mod_dim = kwargs.get("x_mod_dim", None) + if context_attn_dim is not None: + context_attn_dim = int(context_attn_dim) + if context_mlp_dim is not None: + context_mlp_dim = int(context_mlp_dim) + if context_mod_dim is not None: + context_mod_dim = int(context_mod_dim) + if x_attn_dim is not None: + x_attn_dim = int(x_attn_dim) + if x_mlp_dim is not None: + x_mlp_dim = int(x_mlp_dim) + if x_mod_dim is not None: + x_mod_dim = int(x_mod_dim) + type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + emb_dims = kwargs.get("emb_dims", None) + if emb_dims is not None: + emb_dims = emb_dims.strip() + if emb_dims.startswith("[") and emb_dims.endswith("]"): + emb_dims = emb_dims[1:-1] + emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval? + assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. 
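+
+        Example: selection="0-3,7" with total_blocks=12 marks indices
+        0, 1, 2, 3 and 7 as True and every other index as False;
+        "all" selects every block, "none" or "" selects no blocks.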
+ """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_block_indices = kwargs.get("train_block_indices", None) + if train_block_indices is not None: + train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + emb_dims=emb_dims, + train_block_indices=train_block_indices, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." 
not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + unet: sd3_models.MMDiT, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + emb_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.emb_dims = emb_dims + self.train_block_indices = train_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.emb_dims = [0] * 6 # create emb_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + qkv_dim = 0 + if self.split_qkv: + logger.info(f"split qkv for LoRA") + qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0) + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_mmdit: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_SD3 + if is_mmdit + else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][ + text_encoder_idx + ] + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_mmdit and type_dims is not None: + # type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + identifier = [ + ("context_block", "attn"), + ("context_block", "mlp"), + ("context_block", "adaLN_modulation"), + ("x_block", "attn"), + ("x_block", "mlp"), + ("x_block", "adaLN_modulation"), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name: + # "lora_unet_joint_blocks_0_x_block_attn_proj..." 
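+ # the example name above splits on "_" into ["lora", "unet", "joint", "blocks", "0", "x", ...],
+ # so element 4 is the block index checked against train_block_indices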
+ block_index = int(lora_name.split("_")[4]) # bit dirty + if self.train_block_indices is not None and not self.train_block_indices[block_index]: + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_mmdit and split_qkv: + if "joint_blocks" in lora_name and "qkv" in lora_name: + split_dims = [qkv_dim // 3] * 3 + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE) + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + if self.emb_dims: + for filter, in_dim in zip( + [ + "context_embedder", + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" + "x_embedder", + "y_embedder", + "final_layer_adaLN_modulation", + "final_layer_linear", + ], + self.emb_dims, + ): + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, 
file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, 3, dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // 3 + i = 0 + split_dim = weight.shape[0] // 3 + for j in range(3): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank] + i += split_dim + del state_dict[key] + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 + up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(3): + up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j] + i += split_dim + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # 
マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if ( + key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) + ): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_SD3): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of three elements + # if float, use the same value for all three + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]] + elif len(text_encoder_lr) == 2: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + ] + 
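+ # te2 (CLIP-G) and te3 (T5XXL) modules are collected by their prefixes in the same way,
+ # so each text encoder gets its own learning rate from text_encoder_lr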
te2_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + ] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te2_loras) > 0: + logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from ..library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = 
org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/merge_lora.py b/networks/merge_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..f86aa4863c6b1e8d26bebc26bb3456feca1e766a --- /dev/null +++ b/networks/merge_lora.py @@ -0,0 +1,360 @@ +import math +import argparse +import os +import time +import torch +from safetensors.torch import load_file, save_file +from ..library import sai_model_spec, train_util +from ..library import model_util as model_util +import lora +from ..library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, model, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(model, file_name, metadata=metadata) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): + text_encoder.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + 
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) + + logger.info(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + logger.info(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + if len(weight.size()) == 2: + # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + v2 = None + base_model = None + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if v2 is None: + v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + 
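+ # every lora_down / lora_up tensor below is rescaled by sqrt(alpha / base_alpha) * ratio
+ # so the merged weights can share the base alpha recorded at the end; with concat=True
+ # the matrices are concatenated along the rank dimension instead of summed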
logger.info(f"merging...") + for key in lora_sd.keys(): + if "alpha" in key: + continue + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:,perm] + + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata, v2 == "True" + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + logger.info(f"loading SD model: {args.sd_model}") + + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + + if args.no_metadata: + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, + args.v2, + args.v2, + False, + False, + False, + time.time(), + title=title, + merged_from=merged_from, + is_stable_diffusion_ckpt=True, + ) + if args.v2: + # TODO read sai modelspec + logger.warning( + "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" + ) + + logger.info(f"saving SD model to: {args.save_to}") + 
model_util.save_stable_diffusion_checkpoint( + args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae + ) + else: + state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + logger.info(f"calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from + ) + if v2: + # TODO read sai modelspec + logger.warning( + "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" + ) + metadata.update(sai_metadata) + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--sd_model", + type=str, + default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + ) + parser.add_argument( + "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + + "LoRAの重みをシャッフルする", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/networks/resize_lora.py b/networks/resize_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..0409d5dfb9aabcc0c890c7b35364509c9593b3e5 --- /dev/null +++ b/networks/resize_lora.py @@ -0,0 +1,411 @@ +# Convert LoRA to different rank approximation (should only be used to go to lower rank) +# This code is based off the extract_lora_from_models.py file which is based on 
https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo + +import os +import argparse +import torch +from safetensors.torch import load_file, save_file, safe_open +from tqdm import tqdm +import numpy as np + +from ..library import train_util +from ..library import model_util +from ..library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +MIN_SV = 1e-6 + +# Model save and load functions + + +def load_state_dict(file_name, dtype): + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location="cpu") + metadata = None + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if model_util.is_safetensors(file_name): + save_file(state_dict, file_name, metadata) + else: + torch.save(state_dict, file_name) + + +# Indexing functions + + +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0) / original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S) - 1)) + + return index + + +def index_sv_fro(S, target): + S_squared = S.pow(2) + S_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S) - 1)) + + return index + + +def index_sv_ratio(S, target): + max_sv = S[0] + min_sv = max_sv / target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S) - 1)) + + return index + + +# Modified from Kohaku-blueleaf's extract/merge functions +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size, kernel_size, _ = weight.size() + U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() + del U, S, Vh, weight + return param_dict + + +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size = weight.size() + + U, S, Vh = torch.linalg.svd(weight.to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() + del U, S, Vh, weight + return param_dict + + +def merge_conv(lora_down, lora_up, device): + in_rank, in_size, kernel_size, k_ = lora_down.shape + out_size, out_rank, _, _ = lora_up.shape + assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" + + lora_down = lora_down.to(device) + lora_up = 
lora_up.to(device) + + merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) + weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) + del lora_up, lora_down + return weight + + +def merge_linear(lora_down, lora_up, device): + in_rank, in_size = lora_down.shape + out_size, out_rank = lora_up.shape + assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + weight = lora_up @ lora_down + del lora_up, lora_down + return weight + + +# Calculate new rank + + +def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): + param_dict = {} + + if dynamic_method == "sv_ratio": + # Calculate new dim and alpha based off ratio + new_rank = index_sv_ratio(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + + elif dynamic_method == "sv_cumulative": + # Calculate new dim and alpha based off cumulative sum + new_rank = index_sv_cumulative(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + + elif dynamic_method == "sv_fro": + # Calculate new dim and alpha based off sqrt sum of squares + new_rank = index_sv_fro(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + else: + new_rank = rank + new_alpha = float(scale * new_rank) + + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + new_rank = 1 + new_alpha = float(scale * new_rank) + elif new_rank > rank: # cap max rank at rank + new_rank = rank + new_alpha = float(scale * new_rank) + + # Calculate resize info + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + + S_squared = S.pow(2) + s_fro = torch.sqrt(torch.sum(S_squared)) + s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) + fro_percent = float(s_red_fro / s_fro) + + param_dict["new_rank"] = new_rank + param_dict["new_alpha"] = new_alpha + param_dict["sum_retained"] = (s_rank) / s_sum + param_dict["fro_retained"] = fro_percent + param_dict["max_ratio"] = S[0] / S[new_rank - 1] + + return param_dict + + +def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): + network_alpha = None + network_dim = None + verbose_str = "\n" + fro_list = [] + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and "alpha" in key: + network_alpha = value + if network_dim is None and "lora_down" in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim + + scale = network_alpha / network_dim + + if dynamic_method: + logger.info( + f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}" + ) + + lora_down_weight = None + lora_up_weight = None + + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None + + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + weight_name = None + if "lora_down" in key: + block_down_name = key.rsplit(".lora_down", 1)[0] + weight_name = key.rsplit(".", 1)[-1] + lora_down_weight = value + else: + continue + + # find corresponding lora_up and alpha + block_up_name = block_down_name + lora_up_weight = lora_sd.get(block_up_name + ".lora_up." 
+ weight_name, None) + lora_alpha = lora_sd.get(block_down_name + ".alpha", None) + + weights_loaded = lora_down_weight is not None and lora_up_weight is not None + + if weights_loaded: + + conv2d = len(lora_down_weight.size()) == 4 + if lora_alpha is None: + scale = 1.0 + else: + scale = lora_alpha / lora_down_weight.size()[0] + + if conv2d: + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + + if verbose: + max_ratio = param_dict["max_ratio"] + sum_retained = param_dict["sum_retained"] + fro_retained = param_dict["fro_retained"] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) + + verbose_str += f"{block_down_name:75} | " + verbose_str += ( + f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + ) + + if verbose and dynamic_method: + verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str += "\n" + + new_alpha = param_dict["new_alpha"] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + del param_dict + + if verbose: + print(verbose_str) + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + logger.info("resizing complete") + return o_lora_sd, network_dim, new_alpha + + +def resize(args): + if args.save_to is None or not ( + args.save_to.endswith(".ckpt") + or args.save_to.endswith(".pt") + or args.save_to.endswith(".pth") + or args.save_to.endswith(".safetensors") + ): + raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") + + args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + + merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + logger.info("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) + + logger.info("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model( + lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose + ) + + # update metadata + if metadata is None: + metadata = {} + + comment = metadata.get("ss_training_comment", "") + + if not args.dynamic_method: + conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})" + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; 
{comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = ( + f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + ) + metadata["ss_network_dim"] = "Dynamic" + metadata["ss_network_alpha"] = "Dynamic" + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat", + ) + parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument( + "--new_conv_rank", + type=int, + default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors", + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する" + ) + parser.add_argument( + "--dynamic_method", + type=str, + default=None, + choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank", + ) + parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + resize(args) diff --git a/nodes.py b/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..49a59721a984dc785b615d56693a283a71d591d6 --- /dev/null +++ b/nodes.py @@ -0,0 +1,1798 @@ +import os +import torch +from torchvision import transforms + +import folder_paths +import comfy.model_management as mm +import comfy.utils +import toml +import json +import time +import shutil +import shlex + +from pathlib import Path +script_directory = os.path.dirname(os.path.abspath(__file__)) + +from .flux_train_network_comfy import FluxNetworkTrainer +from .library import flux_train_utils as flux_train_utils +from .flux_train_comfy import FluxTrainer +from .flux_train_comfy import setup_parser as train_setup_parser +from .library.device_utils import init_ipex +init_ipex() + +from .library import train_util +from .train_network import setup_parser as train_network_setup_parser +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import io +from PIL import Image + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class FluxTrainModelSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "transformer": 
(folder_paths.get_filename_list("unet"), ), + "vae": (folder_paths.get_filename_list("vae"), ), + "clip_l": (folder_paths.get_filename_list("clip"), ), + "t5": (folder_paths.get_filename_list("clip"), ), + }, + "optional": { + "lora_path": ("STRING",{"multiline": True, "forceInput": True, "default": "", "tooltip": "pre-trained LoRA path to load (network_weights)"}), + } + } + + RETURN_TYPES = ("TRAIN_FLUX_MODELS",) + RETURN_NAMES = ("flux_models",) + FUNCTION = "loadmodel" + CATEGORY = "FluxTrainer" + + def loadmodel(self, transformer, vae, clip_l, t5, lora_path=""): + + transformer_path = folder_paths.get_full_path("unet", transformer) + vae_path = folder_paths.get_full_path("vae", vae) + clip_path = folder_paths.get_full_path("clip", clip_l) + t5_path = folder_paths.get_full_path("clip", t5) + + flux_models = { + "transformer": transformer_path, + "vae": vae_path, + "clip_l": clip_path, + "t5": t5_path, + "lora_path": lora_path + } + + return (flux_models,) + +class TrainDatasetGeneralConfig: + queue_counter = 0 + @classmethod + def IS_CHANGED(s, reset_on_queue=False, **kwargs): + if reset_on_queue: + s.queue_counter += 1 + print(f"queue_counter: {s.queue_counter}") + return s.queue_counter + @classmethod + def INPUT_TYPES(s): + return {"required": { + "color_aug": ("BOOLEAN",{"default": False, "tooltip": "enable weak color augmentation"}), + "flip_aug": ("BOOLEAN",{"default": False, "tooltip": "enable horizontal flip augmentation"}), + "shuffle_caption": ("BOOLEAN",{"default": False, "tooltip": "shuffle caption"}), + "caption_dropout_rate": ("FLOAT",{"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,"tooltip": "tag dropout rate"}), + "alpha_mask": ("BOOLEAN",{"default": False, "tooltip": "use alpha channel as mask for training"}), + }, + "optional": { + "reset_on_queue": ("BOOLEAN",{"default": False, "tooltip": "Force refresh of everything for cleaner queueing"}), + "caption_extension": ("STRING",{"default": ".txt", "tooltip": "extension for caption files"}), + } + } + + RETURN_TYPES = ("JSON",) + RETURN_NAMES = ("dataset_general",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, shuffle_caption, caption_dropout_rate, color_aug, flip_aug, alpha_mask, reset_on_queue=False, caption_extension=".txt"): + + dataset = { + "general": { + "shuffle_caption": shuffle_caption, + "caption_extension": caption_extension, + "keep_tokens_separator": "|||", + "caption_dropout_rate": caption_dropout_rate, + "color_aug": color_aug, + "flip_aug": flip_aug, + }, + "datasets": [] + } + dataset_json = json.dumps(dataset, indent=2) + #print(dataset_json) + dataset_config = { + "datasets": dataset_json, + "alpha_mask": alpha_mask + } + return (dataset_config,) + +class TrainDatasetRegularization: + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "dataset_path": ("STRING",{"multiline": True, "default": "", "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "class_tokens": ("STRING",{"multiline": True, "default": "", "tooltip": "aka trigger word, if specified, will be added to the start of each caption, if no captions exist, will be used on it's own"}), + "num_repeats": ("INT", {"default": 1, "min": 1, "tooltip": "number of times to repeat dataset for an epoch"}), + }, + } + + RETURN_TYPES = ("JSON",) + RETURN_NAMES = ("subset",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, dataset_path, class_tokens, num_repeats): + + reg_subset = { + "image_dir": 
dataset_path, + "class_tokens": class_tokens, + "num_repeats": num_repeats, + "is_reg": True + } + + return reg_subset, + +class TrainDatasetAdd: + def __init__(self): + self.previous_dataset_signature = None + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "dataset_config": ("JSON",), + "width": ("INT",{"min": 64, "default": 1024, "tooltip": "base resolution width"}), + "height": ("INT",{"min": 64, "default": 1024, "tooltip": "base resolution height"}), + "batch_size": ("INT",{"min": 1, "default": 2, "tooltip": "Higher batch size uses more memory and generalizes the training more"}), + "dataset_path": ("STRING",{"multiline": True, "default": "", "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "class_tokens": ("STRING",{"multiline": True, "default": "", "tooltip": "aka trigger word, if specified, will be added to the start of each caption, if no captions exist, will be used on it's own"}), + "enable_bucket": ("BOOLEAN",{"default": True, "tooltip": "enable buckets for multi aspect ratio training"}), + "bucket_no_upscale": ("BOOLEAN",{"default": False, "tooltip": "don't allow upscaling when bucketing"}), + "num_repeats": ("INT", {"default": 1, "min": 1, "tooltip": "number of times to repeat dataset for an epoch"}), + "min_bucket_reso": ("INT", {"default": 256, "min": 64, "max": 4096, "step": 8, "tooltip": "min bucket resolution"}), + "max_bucket_reso": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "max bucket resolution"}), + }, + "optional": { + "regularization": ("JSON", {"tooltip": "reg data dir"}), + } + } + + RETURN_TYPES = ("JSON",) + RETURN_NAMES = ("dataset",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, dataset_config, dataset_path, class_tokens, width, height, batch_size, num_repeats, enable_bucket, + bucket_no_upscale, min_bucket_reso, max_bucket_reso, regularization=None): + + new_dataset = { + "resolution": (width, height), + "batch_size": batch_size, + "enable_bucket": enable_bucket, + "bucket_no_upscale": bucket_no_upscale, + "min_bucket_reso": min_bucket_reso, + "max_bucket_reso": max_bucket_reso, + "subsets": [ + { + "image_dir": dataset_path, + "class_tokens": class_tokens, + "num_repeats": num_repeats + } + ] + } + if regularization is not None: + new_dataset["subsets"].append(regularization) + + # Generate a signature for the new dataset + new_dataset_signature = self.generate_signature(new_dataset) + + # Load the existing datasets + existing_datasets = json.loads(dataset_config["datasets"]) + + # Remove the previously added dataset if it exists + if self.previous_dataset_signature: + existing_datasets["datasets"] = [ + ds for ds in existing_datasets["datasets"] + if self.generate_signature(ds) != self.previous_dataset_signature + ] + + # Add the new dataset + existing_datasets["datasets"].append(new_dataset) + + # Store the new dataset signature for future runs + self.previous_dataset_signature = new_dataset_signature + + # Convert back to JSON and update dataset_config + updated_dataset_json = json.dumps(existing_datasets, indent=2) + dataset_config["datasets"] = updated_dataset_json + + return dataset_config, + + def generate_signature(self, dataset): + # Create a unique signature for the dataset based on its attributes + return json.dumps(dataset, sort_keys=True) + +class OptimizerConfig: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "optimizer_type": (["adamw8bit", "adamw","prodigy", "CAME", "Lion8bit", 
"Lion", "adamwschedulefree", "sgdschedulefree", "AdEMAMix8bit", "PagedAdEMAMix8bit", "ProdigyPlusScheduleFree"], {"default": "adamw8bit", "tooltip": "optimizer type"}), + "max_grad_norm": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "gradient clipping"}), + "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup"], {"default": "constant", "tooltip": "learning rate scheduler"}), + "lr_warmup_steps": ("INT",{"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}), + "lr_scheduler_num_cycles": ("INT",{"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}), + "lr_scheduler_power": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}), + "min_snr_gamma": ("FLOAT",{"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}), + "extra_optimizer_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional optimizer args"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("optimizer_settings",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, min_snr_gamma, extra_optimizer_args, **kwargs): + kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None + kwargs["optimizer_args"] = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()] + return (kwargs,) + +class OptimizerConfigAdafactor: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "max_grad_norm": ("FLOAT",{"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}), + "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup", "adafactor"], {"default": "constant_with_warmup", "tooltip": "learning rate scheduler"}), + "lr_warmup_steps": ("INT",{"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}), + "lr_scheduler_num_cycles": ("INT",{"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}), + "lr_scheduler_power": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}), + "relative_step": ("BOOLEAN",{"default": False, "tooltip": "relative step"}), + "scale_parameter": ("BOOLEAN",{"default": False, "tooltip": "scale parameter"}), + "warmup_init": ("BOOLEAN",{"default": False, "tooltip": "warmup init"}), + "clip_threshold": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "clip threshold"}), + "min_snr_gamma": ("FLOAT",{"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 
5 is recommended by the paper"}), + "extra_optimizer_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional optimizer args"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("optimizer_settings",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, relative_step, scale_parameter, warmup_init, clip_threshold, min_snr_gamma, extra_optimizer_args, **kwargs): + kwargs["optimizer_type"] = "adafactor" + extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()] + node_args = [ + f"relative_step={relative_step}", + f"scale_parameter={scale_parameter}", + f"warmup_init={warmup_init}", + f"clip_threshold={clip_threshold}" + ] + kwargs["optimizer_args"] = node_args + extra_args + kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None + + return (kwargs,) + +class FluxTrainerLossConfig: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "loss_type": (["l2", "huber","smooth_l1"], {"default": "huber", "tooltip": "The type of loss function to use"}), + "huber_schedule": (["snr", "exponential", "constant"], {"default": "exponential", "tooltip": "The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is snr"}), + "huber_c": ("FLOAT",{"default": 0.25, "min": 0.0, "step": 0.01, "tooltip": "The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1"}), + "huber_scale": ("FLOAT",{"default": 1.75, "min": 0.0, "step": 0.01, "tooltip": "The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("loss_args",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, **kwargs): + return (kwargs,) + +class OptimizerConfigProdigy: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "max_grad_norm": ("FLOAT",{"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}), + "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup", "adafactor"], {"default": "constant", "tooltip": "learning rate scheduler"}), + "lr_warmup_steps": ("INT",{"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}), + "lr_scheduler_num_cycles": ("INT",{"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}), + "lr_scheduler_power": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}), + "weight_decay": ("FLOAT",{"default": 0.0, "step": 0.0001, "tooltip": "weight decay (L2 penalty)"}), + "decouple": ("BOOLEAN",{"default": True, "tooltip": "use AdamW style weight decay"}), + "use_bias_correction": ("BOOLEAN",{"default": False, "tooltip": "turn on Adam's bias correction"}), + "min_snr_gamma": ("FLOAT",{"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 
5 is recommended by the paper"}), + "extra_optimizer_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional optimizer args"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("optimizer_settings",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, weight_decay, decouple, min_snr_gamma, use_bias_correction, extra_optimizer_args, **kwargs): + kwargs["optimizer_type"] = "prodigy" + extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()] + node_args = [ + f"weight_decay={weight_decay}", + f"decouple={decouple}", + f"use_bias_correction={use_bias_correction}" + ] + kwargs["optimizer_args"] = node_args + extra_args + kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None + + return (kwargs,) + +class TrainNetworkConfig: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_type": (["lora", "LyCORIS/LoKr", "LyCORIS/Locon", "LyCORIS/LoHa"], {"default": "lora", "tooltip": "network type"}), + "lycoris_preset": (["full", "full-lin", "attn-mlp", "attn-only"], {"default": "attn-mlp"}), + "factor": ("INT",{"default": -1, "min": -1, "max": 16, "step": 1, "tooltip": "LoKr factor"}), + "extra_network_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional network args"}), + }, + } + + RETURN_TYPES = ("NETWORK_CONFIG",) + RETURN_NAMES = ("network_config",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, network_type, extra_network_args, lycoris_preset, factor): + + extra_args = [arg.strip() for arg in extra_network_args.strip().split('|') if arg.strip()] + + if network_type == "lora": + network_module = ".networks.lora" + elif network_type == "LyCORIS/LoKr": + network_module = ".lycoris.kohya" + algo = "lokr" + elif network_type == "LyCORIS/Locon": + network_module = ".lycoris.kohya" + algo = "locon" + elif network_type == "LyCORIS/LoHa": + network_module = ".lycoris.kohya" + algo = "loha" + + network_args = [ + f"algo={algo}", + f"factor={factor}", + f"preset={lycoris_preset}" + ] + network_config = { + "network_module": network_module, + "network_args": network_args + extra_args + } + + return (network_config,) + +class OptimizerConfigProdigyPlusScheduleFree: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "lr": ("FLOAT",{"default": 1.0, "min": 0.0, "step": 1e-7, "tooltip": "Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate."}), + "max_grad_norm": ("FLOAT",{"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}), + "prodigy_steps": ("INT",{"default": 0, "min": 0, "tooltip": "Freeze Prodigy stepsize adjustments after a certain optimiser step."}), + "d0": ("FLOAT",{"default": 1e-6, "min": 0.0,"step": 1e-7, "tooltip": "initial learning rate"}), + "d_coeff": ("FLOAT",{"default": 1.0, "min": 0.0, "step": 1e-7, "tooltip": "Coefficient in the expression for the estimate of d (default 1.0). Values such as 0.5 and 2.0 typically work as well. Changing this parameter is the preferred way to tune the method."}), + "split_groups": ("BOOLEAN",{"default": True, "tooltip": "Track individual adaptation values for each parameter group."}), + #"beta3": ("FLOAT",{"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": " Coefficient for computing the Prodigy stepsize using running averages. 
If set to None, uses the value of square root of beta2 (default: None)."}), + #"beta4": ("FLOAT",{"default": 0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": "Coefficient for updating the learning rate from Prodigy's adaptive stepsize. Smooths out spikes in learning rate adjustments. If set to None, beta1 is used instead. (default 0, which disables smoothing and uses original Prodigy behaviour)."}), + "use_bias_correction": ("BOOLEAN",{"default": False, "tooltip": "Use the RAdam variant of schedule-free"}), + "min_snr_gamma": ("FLOAT",{"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}), + "use_stableadamw": ("BOOLEAN",{"default": True, "tooltip": "Scales parameter updates by the root-mean-square of the normalised gradient, in essence identical to Adafactor's gradient scaling. Set to False if the adaptive learning rate never improves."}), + "use_cautious" : ("BOOLEAN",{"default": False, "tooltip": "Experimental. Perform 'cautious' updates, as proposed in https://arxiv.org/pdf/2411.16085. Modifies the update to isolate and boost values that align with the current gradient."}), + "use_adopt": ("BOOLEAN",{"default": False, "tooltip": "Experimental. Performs a modified step where the second moment is updated after the parameter update, so as not to include the current gradient in the denominator. This is a partial implementation of ADOPT (https://arxiv.org/abs/2411.02853), as we don't have a first moment to use for the update."}), + "use_grams": ("BOOLEAN",{"default": False, "tooltip": "Perform 'grams' updates, as proposed in https://arxiv.org/abs/2412.17107. Modifies the update using sign operations that align with the current gradient. Note that we do not have access to a first moment, so this deviates from the paper (we apply the sign directly to the update). May have a limited effect."}), + "stochastic_rounding": ("BOOLEAN",{"default": True, "tooltip": "Use stochastic rounding for bfloat16 weights"}), + "use_orthograd": ("BOOLEAN",{"default": False, "tooltip": "Experimental. Updates weights using the component of the gradient that is orthogonal to the current weight direction, as described in (https://arxiv.org/pdf/2501.04697). Can help prevent overfitting and improve generalisation."}), + "use_focus ": ("BOOLEAN",{"default": False, "tooltip": "Experimental. Modifies the update step to better handle noise at large step sizes. (https://arxiv.org/abs/2501.12243). 
This method is incompatible with factorisation, Muon and Adam-atan2."}), + "extra_optimizer_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional optimizer args"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("optimizer_settings",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, min_snr_gamma, use_bias_correction, extra_optimizer_args, **kwargs): + kwargs["optimizer_type"] = "ProdigyPlusScheduleFree" + kwargs["lr_scheduler"] = "constant" + extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()] + node_args = [ + f"use_bias_correction={use_bias_correction}", + ] + kwargs["optimizer_args"] = node_args + extra_args + kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None + + return (kwargs,) + +class InitFluxLoRATraining: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "flux_models": ("TRAIN_FLUX_MODELS",), + "dataset": ("JSON",), + "optimizer_settings": ("ARGS",), + "output_name": ("STRING", {"default": "flux_lora", "multiline": False}), + "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "network_dim": ("INT", {"default": 4, "min": 1, "max": 100000, "step": 1, "tooltip": "network dim"}), + "network_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2048.0, "step": 0.01, "tooltip": "network alpha"}), + "learning_rate": ("FLOAT", {"default": 4e-4, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "learning rate"}), + "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}), + "apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "apply t5 attention mask"}), + "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "blocks_to_swap": ("INT", {"default": 0, "tooltip": "Previously known as split_mode, number of blocks to swap to save memory, default to enable is 18"}), + "weighting_scheme": (["logit_normal", "sigma_sqrt", "mode", "cosmap", "none"],), + "logit_mean": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "mean to use when using the logit_normal weighting scheme"}), + "logit_std": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01,"tooltip": "std to use when using the logit_normal weighting scheme"}), + "mode_scale": ("FLOAT", {"default": 1.29, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Scale of mode weighting scheme. 
Only effective when using the mode as the weighting_scheme"}), + "timestep_sampling": (["sigmoid", "uniform", "sigma", "shift", "flux_shift"], {"tooltip": "Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid (recommend value of 3.1582 for discrete_flow_shift)"}), + "sigmoid_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for sigmoid timestep sampling (only used when timestep-sampling is sigmoid"}), + "model_prediction_type": (["raw", "additive", "sigma_scaled"], {"tooltip": "How to interpret and process the model prediction: raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."}), + "guidance_scale": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 32.0, "step": 0.01, "tooltip": "guidance scale, for Flux training should be 1.0"}), + "discrete_flow_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "for the Euler Discrete Scheduler, default is 3.0"}), + "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}), + "fp8_base": ("BOOLEAN", {"default": True, "tooltip": "use fp8 for base model"}), + "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}), + "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + }, + "optional": { + "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}), + "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}), + "train_text_encoder": (['disabled', 'clip_l', 'clip_l_fp8', 'clip_l+T5', 'clip_l+T5_fp8'], {"default": 'disabled', "tooltip": "also train the selected text encoders using specified dtype, T5 can not be trained without clip_l"}), + "clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "T5_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}), + "gradient_checkpointing": (["enabled", "enabled_with_cpu_offloading", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}), + "loss_args": ("ARGS", {"default": "", "tooltip": "loss args"}), + "network_config": ("NETWORK_CONFIG", {"tooltip": "additional network config"}), + }, + "hidden": { + "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS",) + RETURN_NAMES = ("network_trainer", "epochs_count", "args",) + FUNCTION = "init_training" + CATEGORY = "FluxTrainer" + + def init_training(self, flux_models, dataset, optimizer_settings, sample_prompts, output_name, attention_mode, + gradient_dtype, save_dtype, additional_args=None, resume_args=None, train_text_encoder='disabled', + block_args=None, gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, T5_lr=0, loss_args=None, network_config=None, **kwargs): + mm.soft_empty_cache() + + output_dir = 
os.path.abspath(kwargs.get("output_dir"))
+        os.makedirs(output_dir, exist_ok=True)
+
+        total, used, free = shutil.disk_usage(output_dir)
+
+        required_free_space = 2 * (2**30)
+        if free <= required_free_space:
+            raise ValueError(f"Insufficient disk space. Required: {required_free_space/2**30}GB. Available: {free/2**30}GB")
+
+        dataset_config = dataset["datasets"]
+        dataset_toml = toml.dumps(json.loads(dataset_config))
+
+        parser = train_network_setup_parser()
+        flux_train_utils.add_flux_train_arguments(parser)
+
+        if additional_args is not None:
+            print(f"additional_args: {additional_args}")
+            args, _ = parser.parse_known_args(args=shlex.split(additional_args))
+        else:
+            args, _ = parser.parse_known_args()
+
+        if kwargs.get("cache_latents") == "memory":
+            kwargs["cache_latents"] = True
+            kwargs["cache_latents_to_disk"] = False
+        elif kwargs.get("cache_latents") == "disk":
+            kwargs["cache_latents"] = True
+            kwargs["cache_latents_to_disk"] = True
+            kwargs["caption_dropout_rate"] = 0.0
+            kwargs["shuffle_caption"] = False
+            kwargs["token_warmup_step"] = 0.0
+            kwargs["caption_tag_dropout_rate"] = 0.0
+        else:
+            kwargs["cache_latents"] = False
+            kwargs["cache_latents_to_disk"] = False
+
+        if kwargs.get("cache_text_encoder_outputs") == "memory":
+            kwargs["cache_text_encoder_outputs"] = True
+            kwargs["cache_text_encoder_outputs_to_disk"] = False
+        elif kwargs.get("cache_text_encoder_outputs") == "disk":
+            kwargs["cache_text_encoder_outputs"] = True
+            kwargs["cache_text_encoder_outputs_to_disk"] = True
+        else:
+            kwargs["cache_text_encoder_outputs"] = False
+            kwargs["cache_text_encoder_outputs_to_disk"] = False
+
+        if '|' in sample_prompts:
+            prompts = sample_prompts.split('|')
+        else:
+            prompts = [sample_prompts]
+
+        config_dict = {
+            "sample_prompts": prompts,
+            "save_precision": save_dtype,
+            "mixed_precision": "bf16",
+            "num_cpu_threads_per_process": 1,
+            "pretrained_model_name_or_path": flux_models["transformer"],
+            "clip_l": flux_models["clip_l"],
+            "t5xxl": flux_models["t5"],
+            "ae": flux_models["vae"],
+            "save_model_as": "safetensors",
+            "persistent_data_loader_workers": False,
+            "max_data_loader_n_workers": 0,
+            "seed": 42,
+            "network_module": ".networks.lora_flux" if network_config is None else network_config["network_module"],
+            "dataset_config": dataset_toml,
+            "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{save_dtype}",
+            "loss_type": "l2",
+            "t5xxl_max_token_length": 512,
+            "alpha_mask": dataset["alpha_mask"],
+            "network_train_unet_only": True if train_text_encoder == 'disabled' else False,
+            "fp8_base_unet": True if "fp8" in train_text_encoder else False,
+            "disable_mmap_load_safetensors": False,
+            "network_args": None if network_config is None else network_config["network_args"],
+        }
+        attention_settings = {
+            "sdpa": {"mem_eff_attn": True, "xformers": False, "spda": True},
+            "xformers": {"mem_eff_attn": True, "xformers": True, "spda": False}
+        }
+        config_dict.update(attention_settings.get(attention_mode, {}))
+
+        gradient_dtype_settings = {
+            "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"},
+            "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"}
+        }
+        config_dict.update(gradient_dtype_settings.get(gradient_dtype, {}))
+
+        if train_text_encoder != 'disabled':
+            # a single value sets the clip_l lr; a [clip_l, T5] list sets both text encoder lrs
+            if clip_l_lr != "NaN":
+                config_dict["text_encoder_lr"] = clip_l_lr
+            if T5_lr != "NaN":
+                config_dict["text_encoder_lr"] = [clip_l_lr, T5_lr]
+
+        if gradient_checkpointing == "disabled":
+            config_dict["gradient_checkpointing"] = False
+        elif gradient_checkpointing == 
"enabled_with_cpu_offloading": + config_dict["gradient_checkpointing"] = True + config_dict["cpu_offload_checkpointing"] = True + else: + config_dict["gradient_checkpointing"] = True + + if flux_models["lora_path"]: + config_dict["network_weights"] = flux_models["lora_path"] + + config_dict.update(kwargs) + config_dict.update(optimizer_settings) + + if loss_args: + config_dict.update(loss_args) + + if resume_args: + config_dict.update(resume_args) + + for key, value in config_dict.items(): + setattr(args, key, value) + + #network args + additional_network_args = [] + + if "T5" in train_text_encoder: + additional_network_args.append("train_t5xxl=True") + + if block_args: + additional_network_args.append(block_args["include"]) + + # Handle network_args in args Namespace + if hasattr(args, 'network_args') and isinstance(args.network_args, list): + args.network_args.extend(additional_network_args) + else: + setattr(args, 'network_args', additional_network_args) + + saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json") + with open(saved_args_file_path, 'w') as f: + json.dump(vars(args), f, indent=4) + + #workflow saving + metadata = {} + if extra_pnginfo is not None: + metadata.update(extra_pnginfo["workflow"]) + + saved_workflow_file_path = os.path.join(output_dir, f"{output_name}_workflow.json") + with open(saved_workflow_file_path, 'w') as f: + json.dump(metadata, f, indent=4) + + #pass args to kohya and initialize trainer + with torch.inference_mode(False): + network_trainer = FluxNetworkTrainer() + training_loop = network_trainer.init_train(args) + + epochs_count = network_trainer.num_train_epochs + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, args) + +class InitFluxTraining: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "flux_models": ("TRAIN_FLUX_MODELS",), + "dataset": ("JSON",), + "optimizer_settings": ("ARGS",), + "output_name": ("STRING", {"default": "flux", "multiline": False}), + "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "learning_rate": ("FLOAT", {"default": 4e-6, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "learning rate"}), + "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}), + "apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "apply t5 attention mask"}), + "t5xxl_max_token_length": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "dev and LibreFlux uses 512, schnell 256"}), + "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "weighting_scheme": (["logit_normal", "sigma_sqrt", "mode", "cosmap", "none"],), + "logit_mean": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "mean to use when using the logit_normal weighting scheme"}), + "logit_std": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01,"tooltip": "std to use when using the logit_normal weighting scheme"}), + "mode_scale": ("FLOAT", {"default": 1.29, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Scale of mode weighting scheme. 
Only effective when using the mode as the weighting_scheme"}), + "loss_type": (["l1", "l2", "huber", "smooth_l1"], {"default": "l2", "tooltip": "loss type"}), + "timestep_sampling": (["sigmoid", "uniform", "sigma", "shift", "flux_shift"], {"tooltip": "Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid (recommend value of 3.1582 for discrete_flow_shift)"}), + "sigmoid_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for sigmoid timestep sampling (only used when timestep-sampling is sigmoid"}), + "model_prediction_type": (["raw", "additive", "sigma_scaled"], {"tooltip": "How to interpret and process the model prediction: raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)"}), + "cpu_offload_checkpointing": ("BOOLEAN", {"default": True, "tooltip": "offload the gradient checkpointing to CPU. This reduces VRAM usage for about 2GB"}), + "optimizer_fusing": (['fused_backward_pass', 'blockwise_fused_optimizers'], {"tooltip": "reduces memory use"}), + "blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "Sets the number of blocks (~640MB) to swap during the forward and backward passes, increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."}), + "guidance_scale": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 32.0, "step": 0.01, "tooltip": "guidance scale"}), + "discrete_flow_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "for the Euler Discrete Scheduler, default is 3.0"}), + "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}), + "fp8_base": ("BOOLEAN", {"default": False, "tooltip": "use fp8 for base model"}), + "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "bf16", "tooltip": "to use the full fp16/bf16 training"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}), + "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + }, + "optional": { + "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}), + "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS") + RETURN_NAMES = ("network_trainer", "epochs_count", "args") + FUNCTION = "init_training" + CATEGORY = "FluxTrainer" + + def init_training(self, flux_models, optimizer_settings, dataset, sample_prompts, output_name, + attention_mode, gradient_dtype, save_dtype, optimizer_fusing, additional_args=None, resume_args=None, **kwargs,): + mm.soft_empty_cache() + + output_dir = os.path.abspath(kwargs.get("output_dir")) + os.makedirs(output_dir, exist_ok=True) + + total, used, free = shutil.disk_usage(output_dir) + required_free_space = 25 * (2**30) + if free <= required_free_space: + raise ValueError(f"Most likely insufficient disk space to complete training. Required: {required_free_space/2**30}GB. 
Available: {free/2**30}GB") + + dataset_config = dataset["datasets"] + dataset_toml = toml.dumps(json.loads(dataset_config)) + + parser = train_setup_parser() + flux_train_utils.add_flux_train_arguments(parser) + + if additional_args is not None: + print(f"additional_args: {additional_args}") + args, _ = parser.parse_known_args(args=shlex.split(additional_args)) + else: + args, _ = parser.parse_known_args() + + if kwargs.get("cache_latents") == "memory": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = False + elif kwargs.get("cache_latents") == "disk": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = True + kwargs["caption_dropout_rate"] = 0.0 + kwargs["shuffle_caption"] = False + kwargs["token_warmup_step"] = 0.0 + kwargs["caption_tag_dropout_rate"] = 0.0 + else: + kwargs["cache_latents"] = False + kwargs["cache_latents_to_disk"] = False + + if kwargs.get("cache_text_encoder_outputs") == "memory": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = False + elif kwargs.get("cache_text_encoder_outputs") == "disk": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = True + else: + kwargs["cache_text_encoder_outputs"] = False + kwargs["cache_text_encoder_outputs_to_disk"] = False + + if '|' in sample_prompts: + prompts = sample_prompts.split('|') + else: + prompts = [sample_prompts] + + config_dict = { + "sample_prompts": prompts, + "save_precision": save_dtype, + "mixed_precision": "bf16", + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": flux_models["transformer"], + "clip_l": flux_models["clip_l"], + "t5xxl": flux_models["t5"], + "ae": flux_models["vae"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "gradient_checkpointing": True, + "dataset_config": dataset_toml, + "output_name": f"{output_name}_{save_dtype}", + "mem_eff_save": True, + "disable_mmap_load_safetensors": True, + + } + optimizer_fusing_settings = { + "fused_backward_pass": {"fused_backward_pass": True, "blockwise_fused_optimizers": False}, + "blockwise_fused_optimizers": {"fused_backward_pass": False, "blockwise_fused_optimizers": True} + } + config_dict.update(optimizer_fusing_settings.get(optimizer_fusing, {})) + + attention_settings = { + "sdpa": {"mem_eff_attn": True, "xformers": False, "spda": True}, + "xformers": {"mem_eff_attn": True, "xformers": True, "spda": False} + } + config_dict.update(attention_settings.get(attention_mode, {})) + + gradient_dtype_settings = { + "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"}, + "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"} + } + config_dict.update(gradient_dtype_settings.get(gradient_dtype, {})) + + config_dict.update(kwargs) + config_dict.update(optimizer_settings) + + if resume_args: + config_dict.update(resume_args) + + for key, value in config_dict.items(): + setattr(args, key, value) + + with torch.inference_mode(False): + network_trainer = FluxTrainer() + training_loop = network_trainer.init_train(args) + + epochs_count = network_trainer.num_train_epochs + + + saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json") + with open(saved_args_file_path, 'w') as f: + json.dump(vars(args), f, indent=4) + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, args) + +class InitFluxTrainingFromPreset: + 
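+    # Builds a full fine-tune run from an existing KOHYA_ARGS preset: the preset values
+    # are applied first, then the model paths, dataset TOML, output naming and the base
+    # resolution (read back from the first dataset entry in the TOML) override them.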
@classmethod + def INPUT_TYPES(s): + return {"required": { + "flux_models": ("TRAIN_FLUX_MODELS",), + "dataset_settings": ("TOML_DATASET",), + "preset_args": ("KOHYA_ARGS",), + "output_name": ("STRING", {"default": "flux", "multiline": False}), + "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "output directory, root is ComfyUI folder"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "STRING", "KOHYA_ARGS") + RETURN_NAMES = ("network_trainer", "epochs_count", "output_path", "args") + FUNCTION = "init_training" + CATEGORY = "FluxTrainer" + + def init_training(self, flux_models, dataset_settings, sample_prompts, output_name, preset_args, **kwargs,): + mm.soft_empty_cache() + + dataset = dataset_settings["dataset"] + dataset_repeats = dataset_settings["repeats"] + + parser = train_setup_parser() + args, _ = parser.parse_known_args() + for key, value in vars(preset_args).items(): + setattr(args, key, value) + + output_dir = os.path.join(script_directory, "output") + if '|' in sample_prompts: + prompts = sample_prompts.split('|') + else: + prompts = [sample_prompts] + + width, height = toml.loads(dataset)["datasets"][0]["resolution"] + config_dict = { + "sample_prompts": prompts, + "dataset_repeats": dataset_repeats, + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": flux_models["transformer"], + "clip_l": flux_models["clip_l"], + "t5xxl": flux_models["t5"], + "ae": flux_models["vae"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "gradient_checkpointing": True, + "dataset_config": dataset, + "output_dir": output_dir, + "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{args.save_precision}", + "width" : int(width), + "height" : int(height), + + } + + config_dict.update(kwargs) + + for key, value in config_dict.items(): + setattr(args, key, value) + + with torch.inference_mode(False): + network_trainer = FluxNetworkTrainer() + training_loop = network_trainer.init_train(args) + + final_output_path = os.path.join(output_dir, output_name) + + epochs_count = network_trainer.num_train_epochs + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, final_output_path, args) + +class FluxTrainLoop: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer" + + def train(self, network_trainer, steps): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + initial_global_step = network_trainer.global_step + + target_global_step = network_trainer.global_step + steps + comfy_pbar = comfy.utils.ProgressBar(steps) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + steps_done = training_loop( + break_at_steps = target_global_step, + epoch = network_trainer.current_epoch.value, + ) + 
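+                # training_loop is expected to return when it reaches break_at_steps or the
+                # end of the current epoch, so keep looping until the requested number of
+                # steps has actually been consumed.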
#pbar.update(steps_done) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + +class FluxTrainAndValidateLoop: + @classmethod + def INPUT_TYPES(cls): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "validate_at_steps": ("INT", {"default": 250, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + "save_at_steps": ("INT", {"default": 250, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + }, + "optional": { + "validation_settings": ("VALSETTINGS",), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer" + + def train(self, network_trainer, validate_at_steps, save_at_steps, validation_settings=None): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + target_global_step = network_trainer.args.max_train_steps + comfy_pbar = comfy.utils.ProgressBar(target_global_step) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + next_validate_step = ((network_trainer.global_step // validate_at_steps) + 1) * validate_at_steps + next_save_step = ((network_trainer.global_step // save_at_steps) + 1) * save_at_steps + + steps_done = training_loop( + break_at_steps=min(next_validate_step, next_save_step), + epoch=network_trainer.current_epoch.value, + ) + + # Check if we need to validate + if network_trainer.global_step % validate_at_steps == 0: + self.validate(network_trainer, validation_settings) + + # Check if we need to save + if network_trainer.global_step % save_at_steps == 0: + self.save(network_trainer) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + + def validate(self, network_trainer, validation_settings=None): + params = ( + network_trainer.current_epoch.value, + network_trainer.global_step, + validation_settings + ) + network_trainer.optimizer_eval_fn() + image_tensors = network_trainer.sample_images(*params) + network_trainer.optimizer_train_fn() + print("Validating at step:", network_trainer.global_step) + + def save(self, network_trainer): + ckpt_name = train_util.get_step_ckpt_name(network_trainer.args, "." 
+ network_trainer.args.save_model_as, network_trainer.global_step) + network_trainer.optimizer_eval_fn() + network_trainer.save_model(ckpt_name, network_trainer.accelerator.unwrap_model(network_trainer.network), network_trainer.global_step, network_trainer.current_epoch.value + 1) + network_trainer.optimizer_train_fn() + print("Saving at step:", network_trainer.global_step) + +class FluxTrainSave: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": False, "tooltip": "save the whole model state as well"}), + "copy_to_comfy_lora_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",) + RETURN_NAMES = ("network_trainer","lora_path", "steps",) + FUNCTION = "save" + CATEGORY = "FluxTrainer" + + def save(self, network_trainer, save_state, copy_to_comfy_lora_folder): + import shutil + with torch.inference_mode(False): + trainer = network_trainer["network_trainer"] + global_step = trainer.global_step + + ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, global_step) + trainer.save_model(ckpt_name, trainer.accelerator.unwrap_model(trainer.network), global_step, trainer.current_epoch.value + 1) + + remove_step_no = train_util.get_remove_step_no(trainer.args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, remove_step_no) + trainer.remove_model(remove_ckpt_name) + + if save_state: + train_util.save_and_remove_state_stepwise(trainer.args, trainer.accelerator, global_step) + + lora_path = os.path.join(trainer.args.output_dir, ckpt_name) + if copy_to_comfy_lora_folder: + destination_dir = os.path.join(folder_paths.models_dir, "loras", "flux_trainer") + os.makedirs(destination_dir, exist_ok=True) + shutil.copy(lora_path, os.path.join(destination_dir, ckpt_name)) + + + return (network_trainer, lora_path, global_step) + +class FluxTrainSaveModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "copy_to_comfy_model_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}), + "end_training": ("BOOLEAN", {"default": False, "tooltip": "end the training"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",) + RETURN_NAMES = ("network_trainer","model_path", "steps",) + FUNCTION = "save" + CATEGORY = "FluxTrainer" + + def save(self, network_trainer, copy_to_comfy_model_folder, end_training): + import shutil + with torch.inference_mode(False): + trainer = network_trainer["network_trainer"] + global_step = trainer.global_step + + trainer.optimizer_eval_fn() + + ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." 
+ trainer.args.save_model_as, global_step) + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + trainer.args, + False, + trainer.accelerator, + trainer.save_dtype, + trainer.current_epoch.value, + trainer.num_train_epochs, + global_step, + trainer.accelerator.unwrap_model(trainer.unet) + ) + + model_path = os.path.join(trainer.args.output_dir, ckpt_name) + if copy_to_comfy_model_folder: + shutil.copy(model_path, os.path.join(folder_paths.models_dir, "diffusion_models", "flux_trainer", ckpt_name)) + model_path = os.path.join(folder_paths.models_dir, "diffusion_models", "flux_trainer", ckpt_name) + if end_training: + trainer.accelerator.end_training() + + return (network_trainer, model_path, global_step) + +class FluxTrainEnd: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("STRING", "STRING", "STRING",) + RETURN_NAMES = ("lora_name", "metadata", "lora_path",) + FUNCTION = "endtrain" + CATEGORY = "FluxTrainer" + OUTPUT_NODE = True + + def endtrain(self, network_trainer, save_state): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + network_trainer.metadata["ss_epoch"] = str(network_trainer.num_train_epochs) + network_trainer.metadata["ss_training_finished_at"] = str(time.time()) + + network = network_trainer.accelerator.unwrap_model(network_trainer.network) + + network_trainer.accelerator.end_training() + network_trainer.optimizer_eval_fn() + + if save_state: + train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator) + + ckpt_name = train_util.get_last_ckpt_name(network_trainer.args, "." + network_trainer.args.save_model_as) + network_trainer.save_model(ckpt_name, network, network_trainer.global_step, network_trainer.num_train_epochs, force_sync_upload=True) + logger.info("model saved.") + + final_lora_name = str(network_trainer.args.output_name) + final_lora_path = os.path.join(network_trainer.args.output_dir, ckpt_name) + + # metadata + metadata = json.dumps(network_trainer.metadata, indent=2) + + training_loop = None + network_trainer = None + mm.soft_empty_cache() + + return (final_lora_name, metadata, final_lora_path) + +class FluxTrainResume: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "load_state_path": ("STRING", {"default": "", "multiline": True, "tooltip": "path to load state from"}), + "skip_until_initial_step" : ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("ARGS", ) + RETURN_NAMES = ("resume_args", ) + FUNCTION = "resume" + CATEGORY = "FluxTrainer" + + def resume(self, load_state_path, skip_until_initial_step): + resume_args ={ + "resume": load_state_path, + "skip_until_initial_step": skip_until_initial_step + } + + return (resume_args, ) + +class FluxTrainBlockSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "include": ("STRING", {"default": "lora_unet_single_blocks_20_linear2", "multiline": True, "tooltip": "blocks to include in the LoRA network, to select multiple blocks either input them as "}), + }, + } + + RETURN_TYPES = ("ARGS", ) + RETURN_NAMES = ("block_args", ) + FUNCTION = "block_select" + CATEGORY = "FluxTrainer" + + def block_select(self, include): + import re + + # Split the input string by commas to handle multiple ranges/blocks + elements = include.split(',') + + # Initialize a list to collect block names + blocks = [] + + # Pattern to find 
ranges like (10-20) + pattern = re.compile(r'\((\d+)-(\d+)\)') + + # Extract the prefix and suffix from the first element + prefix_suffix_pattern = re.compile(r'(.*)_blocks_(.*)') + + for element in elements: + element = element.strip() + match = prefix_suffix_pattern.match(element) + if match: + prefix = match.group(1) + "_blocks_" + suffix = match.group(2) + matches = pattern.findall(suffix) + if matches: + for start, end in matches: + # Generate block names for the range and add them to the list + blocks.extend([f"{prefix}{i}{suffix.replace(f'({start}-{end})', '', 1)}" for i in range(int(start), int(end) + 1)]) + else: + # If no range is found, add the block name directly + blocks.append(element) + else: + blocks.append(element) + + # Construct the `include` string + include_string = ','.join(blocks) + + block_args = { + "include": f"only_if_contains={include_string}", + } + + return (block_args, ) + +class FluxTrainValidationSettings: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), + "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), + "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), + "guidance_scale": ("FLOAT", {"default": 3.5, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), + "seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + "shift": ("BOOLEAN", {"default": True, "tooltip": "shift the schedule to favor high timesteps for higher signal images"}), + "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}), + "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 10.0, "step": 0.01}), + }, + } + + RETURN_TYPES = ("VALSETTINGS", ) + RETURN_NAMES = ("validation_settings", ) + FUNCTION = "set" + CATEGORY = "FluxTrainer" + + def set(self, **kwargs): + validation_settings = kwargs + print(validation_settings) + + return (validation_settings,) + +class FluxTrainValidate: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "network_trainer": ("NETWORKTRAINER",), + }, + "optional": { + "validation_settings": ("VALSETTINGS",), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "IMAGE",) + RETURN_NAMES = ("network_trainer", "validation_images",) + FUNCTION = "validate" + CATEGORY = "FluxTrainer" + + def validate(self, network_trainer, validation_settings=None): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + params = ( + network_trainer.current_epoch.value, + network_trainer.global_step, + validation_settings + ) + network_trainer.optimizer_eval_fn() + with torch.inference_mode(False): + image_tensors = network_trainer.sample_images(*params) + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, (0.5 * (image_tensors + 1.0)).cpu().float(),) + +class VisualizeLoss: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "plot_style": (plt.style.available,{"default": 'default', "tooltip": "matplotlib plot style"}), + "window_size": ("INT", {"default": 100, "min": 0, "max": 10000, "step": 1, "tooltip": "the window size of the moving average"}), + "normalize_y": ("BOOLEAN", {"default": True, "tooltip": "normalize the y-axis to 0"}), + "width": ("INT", {"default": 768, "min": 256, "max": 4096, "step": 2, "tooltip": "width of the plot in 
pixels"}), + "height": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 2, "tooltip": "height of the plot in pixels"}), + "log_scale": ("BOOLEAN", {"default": False, "tooltip": "use log scale on the y-axis"}), + }, + } + + RETURN_TYPES = ("IMAGE", "FLOAT",) + RETURN_NAMES = ("plot", "loss_list",) + FUNCTION = "draw" + CATEGORY = "FluxTrainer" + + def draw(self, network_trainer, window_size, plot_style, normalize_y, width, height, log_scale): + import numpy as np + loss_values = network_trainer["network_trainer"].loss_recorder.global_loss_list + + # Apply moving average + def moving_average(values, window_size): + return np.convolve(values, np.ones(window_size) / window_size, mode='valid') + if window_size > 0: + loss_values = moving_average(loss_values, window_size) + + plt.style.use(plot_style) + + # Convert pixels to inches (assuming 100 pixels per inch) + width_inches = width / 100 + height_inches = height / 100 + + # Create a plot + fig, ax = plt.subplots(figsize=(width_inches, height_inches)) + ax.plot(loss_values, label='Training Loss') + ax.set_xlabel('Step') + ax.set_ylabel('Loss') + if normalize_y: + plt.ylim(bottom=0) + if log_scale: + ax.set_yscale('log') + ax.set_title('Training Loss Over Time') + ax.legend() + ax.grid(True) + + buf = io.BytesIO() + plt.savefig(buf, format='png') + plt.close(fig) + buf.seek(0) + + image = Image.open(buf).convert('RGB') + + image_tensor = transforms.ToTensor()(image) + image_tensor = image_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() + + return image_tensor, loss_values, + +class FluxKohyaInferenceSampler: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "flux_models": ("TRAIN_FLUX_MODELS",), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "lora_method": (["apply", "merge"], {"tooltip": "whether to apply or merge the lora weights"}), + "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), + "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), + "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), + "guidance_scale": ("FLOAT", {"default": 3.5, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), + "seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + "use_fp8": ("BOOLEAN", {"default": True, "tooltip": "use fp8 weights"}), + "apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "use t5 attention mask"}), + "prompt": ("STRING", {"multiline": True, "default": "illustration of a kitten", "tooltip": "prompt"}), + + }, + } + + RETURN_TYPES = ("IMAGE", ) + RETURN_NAMES = ("image", ) + FUNCTION = "sample" + CATEGORY = "FluxTrainer" + + def sample(self, flux_models, lora_name, steps, width, height, guidance_scale, seed, prompt, use_fp8, lora_method, apply_t5_attn_mask): + + from .library import flux_utils as flux_utils + from .library import strategy_flux as strategy_flux + from .networks import lora_flux as lora_flux + from typing import List, Optional, Callable + from tqdm import tqdm + import einops + import math + import accelerate + import gc + + device = "cuda" + + + if use_fp8: + accelerator = accelerate.Accelerator(mixed_precision="bf16") + dtype = torch.float8_e4m3fn + else: + dtype = torch.float16 + accelerator = None + loading_device = "cpu" + ae_dtype = torch.bfloat16 + + pretrained_model_name_or_path = flux_models["transformer"] + clip_l = flux_models["clip_l"] + t5xxl = 
flux_models["t5"] + ae = flux_models["vae"] + lora_path = folder_paths.get_full_path("loras", lora_name) + + # load clip_l + logger.info(f"Loading clip_l from {clip_l}...") + clip_l = flux_utils.load_clip_l(clip_l, None, loading_device) + clip_l.eval() + + logger.info(f"Loading t5xxl from {t5xxl}...") + t5xxl = flux_utils.load_t5xxl(t5xxl, None, loading_device) + t5xxl.eval() + + if use_fp8: + clip_l = accelerator.prepare(clip_l) + t5xxl = accelerator.prepare(t5xxl) + + t5xxl_max_length = 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + + # DiT + model = flux_utils.load_flow_model("dev", pretrained_model_name_or_path, dtype, loading_device) + model.eval() + logger.info(f"Casting model to {dtype}") + model.to(dtype) # make sure model is dtype + if use_fp8: + model = accelerator.prepare(model) + + # AE + ae = flux_utils.load_ae("dev", ae, ae_dtype, loading_device) + ae.eval() + + + # LoRA + lora_models: List[lora_flux.LoRANetwork] = [] + multiplier = 1.0 + + lora_model, weights_sd = lora_flux.create_network_from_weights( + multiplier, lora_path, ae, [clip_l, t5xxl], model, None, True + ) + if lora_method == "merge": + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + elif lora_method == "apply": + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {lora_name}: {info}") + lora_model.eval() + lora_model.to(device) + lora_models.append(lora_model) + + + packed_latent_height, packed_latent_width = math.ceil(height / 16), math.ceil(width / 16) + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=device, + dtype=ae_dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + + # prepare embeddings + logger.info("Encoding prompts...") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + clip_l = clip_l.to(device) + t5xxl = t5xxl.to(device) + with torch.no_grad(): + if use_fp8: + clip_l.to(ae_dtype) + t5xxl.to(ae_dtype) + with accelerator.autocast(): + l_pooled, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + with torch.autocast(device_type=device.type, dtype=dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, apply_t5_attn_mask + ) + # NaN check + if torch.isnan(l_pooled).any(): + raise ValueError("NaN in l_pooled") + + if torch.isnan(t5_out).any(): + raise ValueError("NaN in t5_out") + + + clip_l = clip_l.cpu() + t5xxl = t5xxl.cpu() + + gc.collect() + torch.cuda.empty_cache() + + # generate image + logger.info("Generating image...") + model = model.to(device) + print("MODEL DTYPE: ", model.dtype) + + img_ids = img_ids.to(device) + t5_attn_mask = t5_attn_mask.to(device) if apply_t5_attn_mask else None + def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + + def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m 
* x + b + + + def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, + ) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + + def denoise( + model, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, + ): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + comfy_pbar = comfy.utils.ProgressBar(total=len(timesteps)) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + img = img + (t_prev - t_curr) * pred + comfy_pbar.update(1) + + return img + def do_sample( + accelerator: Optional[accelerate.Accelerator], + model, + img: torch.Tensor, + img_ids: torch.Tensor, + l_pooled: torch.Tensor, + t5_out: torch.Tensor, + txt_ids: torch.Tensor, + num_steps: int, + guidance: float, + t5_attn_mask: Optional[torch.Tensor], + is_schnell: bool, + device: torch.device, + flux_dtype: torch.dtype, + ): + timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + print(timesteps) + + # denoise initial noise + if accelerator: + with accelerator.autocast(), torch.no_grad(): + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad(): + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) + + return x + + x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, t5_attn_mask, False, device, dtype) + + model = model.cpu() + gc.collect() + torch.cuda.empty_cache() + + # unpack + x = x.float() + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + + # decode + logger.info("Decoding image...") + ae = ae.to(device) + with torch.no_grad(): + if use_fp8: + with accelerator.autocast(): + x = ae.decode(x) + else: + with torch.autocast(device_type=device.type, dtype=ae_dtype): + x = ae.decode(x) + + ae = ae.cpu() + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + + return ((0.5 * (x + 1.0)).cpu().float(),) + +class UploadToHuggingFace: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "network_trainer": ("NETWORKTRAINER",), + "source_path": ("STRING", {"default": ""}), + "repo_id": ("STRING",{"default": ""}), + "revision": ("STRING", {"default": ""}), + "private": ("BOOLEAN", {"default": True, "tooltip": "If creating a new repo, leave it private"}), + }, + "optional": { + 
"token": ("STRING", {"default": "","tooltip":"DO NOT LEAVE IN THE NODE or it might save in metadata, can also use the hf_token.json"}), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING",) + RETURN_NAMES = ("network_trainer","status",) + FUNCTION = "upload" + CATEGORY = "FluxTrainer" + + def upload(self, source_path, network_trainer, repo_id, private, revision, token=""): + with torch.inference_mode(False): + from huggingface_hub import HfApi + + if not token: + with open(os.path.join(script_directory, "hf_token.json"), "r") as file: + token_data = json.load(file) + token = token_data["hf_token"] + + # Save metadata to a JSON file + directory_path = os.path.dirname(os.path.dirname(source_path)) + file_name = os.path.basename(source_path) + + metadata = network_trainer["network_trainer"].metadata + metadata_file_path = os.path.join(directory_path, "metadata.json") + with open(metadata_file_path, 'w') as f: + json.dump(metadata, f, indent=4) + + repo_type = None + api = HfApi(token=token) + + try: + api.repo_info( + repo_id=repo_id, + revision=revision if revision != "" else None, + repo_type=repo_type) + repo_exists = True + logger.info(f"Repository {repo_id} exists.") + except Exception as e: # Catching a more specific exception would be better if you know what to expect + repo_exists = False + logger.error(f"Repository {repo_id} does not exist. Exception: {e}") + + if not repo_exists: + try: + api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + except Exception as e: # Checked for RepositoryNotFoundError, but other exceptions could be problematic + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo: {e}") + logger.error("===========================================") + + is_folder = (type(source_path) == str and os.path.isdir(source_path)) or (isinstance(source_path, Path) and source_path.is_dir()) + print(source_path, is_folder) + + try: + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=source_path, + path_in_repo=file_name, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=source_path, + path_in_repo=file_name, + ) + # Upload the metadata file separately if it's not a folder upload + if not is_folder: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=str(metadata_file_path), + path_in_repo='metadata.json', + ) + status = "Uploaded to HuggingFace successfully" + except Exception as e: # RuntimeError has been seen here, but catch broadly in case other exceptions occur + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace: {e}") + logger.error("===========================================") + status = f"Failed to upload to HuggingFace: {e}" + + return (network_trainer, status,) + +class ExtractFluxLoRA: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "original_model": (folder_paths.get_filename_list("unet"), ), + "finetuned_model": (folder_paths.get_filename_list("unet"), ), + "output_path": ("STRING", {"default": f"{str(os.path.join(folder_paths.models_dir, 'loras', 'Flux'))}"}), + "dim": ("INT", {"default": 4, "min": 2, "max": 1024, "step": 2, "tooltip": "LoRA rank"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save the LoRA as"}), + "load_device": (["cpu", "cuda"], {"default": "cuda", "tooltip": "the device to load the model to"}), + "store_device": (["cpu", 
"cuda"], {"default": "cpu", "tooltip": "the device to store the LoRA as"}), + "clamp_quantile": ("FLOAT", {"default": 0.99, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "clamp quantile"}), + "metadata": ("BOOLEAN", {"default": True, "tooltip": "build metadata"}), + "mem_eff_safe_open": ("BOOLEAN", {"default": False, "tooltip": "memory efficient loading"}), + }, + } + + RETURN_TYPES = ("STRING", ) + RETURN_NAMES = ("output_path",) + FUNCTION = "extract" + CATEGORY = "FluxTrainer" + + def extract(self, original_model, finetuned_model, output_path, dim, save_dtype, load_device, store_device, clamp_quantile, metadata, mem_eff_safe_open): + from .flux_extract_lora import svd + transformer_path = folder_paths.get_full_path("unet", original_model) + finetuned_model_path = folder_paths.get_full_path("unet", finetuned_model) + outpath = svd( + model_org = transformer_path, + model_tuned = finetuned_model_path, + save_to = os.path.join(output_path, f"{finetuned_model.replace('.safetensors', '')}_extracted_lora_rank_{dim}-{save_dtype}.safetensors"), + dim = dim, + device = load_device, + store_device = store_device, + save_precision = save_dtype, + clamp_quantile = clamp_quantile, + no_metadata = not metadata, + mem_eff_safe_open = mem_eff_safe_open + ) + + return (outpath,) + +NODE_CLASS_MAPPINGS = { + "InitFluxLoRATraining": InitFluxLoRATraining, + "InitFluxTraining": InitFluxTraining, + "FluxTrainModelSelect": FluxTrainModelSelect, + "TrainDatasetGeneralConfig": TrainDatasetGeneralConfig, + "TrainDatasetAdd": TrainDatasetAdd, + "FluxTrainLoop": FluxTrainLoop, + "VisualizeLoss": VisualizeLoss, + "FluxTrainValidate": FluxTrainValidate, + "FluxTrainValidationSettings": FluxTrainValidationSettings, + "FluxTrainEnd": FluxTrainEnd, + "FluxTrainSave": FluxTrainSave, + "FluxKohyaInferenceSampler": FluxKohyaInferenceSampler, + "UploadToHuggingFace": UploadToHuggingFace, + "OptimizerConfig": OptimizerConfig, + "OptimizerConfigAdafactor": OptimizerConfigAdafactor, + "FluxTrainSaveModel": FluxTrainSaveModel, + "ExtractFluxLoRA": ExtractFluxLoRA, + "OptimizerConfigProdigy": OptimizerConfigProdigy, + "FluxTrainResume": FluxTrainResume, + "FluxTrainBlockSelect": FluxTrainBlockSelect, + "TrainDatasetRegularization": TrainDatasetRegularization, + "FluxTrainAndValidateLoop": FluxTrainAndValidateLoop, + "OptimizerConfigProdigyPlusScheduleFree": OptimizerConfigProdigyPlusScheduleFree, + "FluxTrainerLossConfig": FluxTrainerLossConfig, + "TrainNetworkConfig": TrainNetworkConfig, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "InitFluxLoRATraining": "Init Flux LoRA Training", + "InitFluxTraining": "Init Flux Training", + "FluxTrainModelSelect": "FluxTrain ModelSelect", + "TrainDatasetGeneralConfig": "TrainDatasetGeneralConfig", + "TrainDatasetAdd": "TrainDatasetAdd", + "FluxTrainLoop": "Flux Train Loop", + "VisualizeLoss": "Visualize Loss", + "FluxTrainValidate": "Flux Train Validate", + "FluxTrainValidationSettings": "Flux Train Validation Settings", + "FluxTrainEnd": "Flux LoRA Train End", + "FluxTrainSave": "Flux Train Save LoRA", + "FluxKohyaInferenceSampler": "Flux Kohya Inference Sampler", + "UploadToHuggingFace": "Upload To HuggingFace", + "OptimizerConfig": "Optimizer Config", + "OptimizerConfigAdafactor": "Optimizer Config Adafactor", + "FluxTrainSaveModel": "Flux Train Save Model", + "ExtractFluxLoRA": "Extract Flux LoRA", + "OptimizerConfigProdigy": "Optimizer Config Prodigy", + "FluxTrainResume": "Flux Train Resume", + "FluxTrainBlockSelect": "Flux Train Block Select", + "TrainDatasetRegularization": "Train 
Dataset Regularization", + "FluxTrainAndValidateLoop": "Flux Train And Validate Loop", + "OptimizerConfigProdigyPlusScheduleFree": "Optimizer Config ProdigyPlusScheduleFree", + "FluxTrainerLossConfig": "Flux Trainer Loss Config", + "TrainNetworkConfig": "Train Network Config", +} diff --git a/nodes_sd3.py b/nodes_sd3.py new file mode 100644 index 0000000000000000000000000000000000000000..512d0af11f4b48febacfd1cc932cfd32ba057635 --- /dev/null +++ b/nodes_sd3.py @@ -0,0 +1,467 @@ +import os +import torch + +import folder_paths +import comfy.model_management as mm +import comfy.utils +import toml +import json +import time +import shutil +import shlex + +script_directory = os.path.dirname(os.path.abspath(__file__)) + +from .sd3_train_network import Sd3NetworkTrainer +from .library import sd3_train_utils as sd3_train_utils +from .library.device_utils import init_ipex +init_ipex() + +from .library import train_util +from .train_network import setup_parser as train_network_setup_parser + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class SD3ModelSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "transformer": (folder_paths.get_filename_list("checkpoints"), ), + "clip_l": (folder_paths.get_filename_list("clip"), ), + "clip_g": (folder_paths.get_filename_list("clip"), ), + "t5": (folder_paths.get_filename_list("clip"), ), + }, + "optional": { + "lora_path": ("STRING",{"multiline": True, "forceInput": True, "default": "", "tooltip": "pre-trained LoRA path to load (network_weights)"}), + } + } + + RETURN_TYPES = ("TRAIN_SD3_MODELS",) + RETURN_NAMES = ("sd3_models",) + FUNCTION = "loadmodel" + CATEGORY = "FluxTrainer/SD3" + + def loadmodel(self, transformer, clip_l, clip_g, t5, lora_path=""): + + transformer_path = folder_paths.get_full_path("checkpoints", transformer) + clip_l_path = folder_paths.get_full_path("clip", clip_l) + clip_g_path = folder_paths.get_full_path("clip", clip_g) + t5_path = folder_paths.get_full_path("clip", t5) + + sd3_models = { + "transformer": transformer_path, + "clip_l": clip_l_path, + "clip_g": clip_g_path, + "t5": t5_path, + "lora_path": lora_path + } + + return (sd3_models,) + +class InitSD3LoRATraining: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "sd3_models": ("TRAIN_SD3_MODELS",), + "dataset": ("JSON",), + "optimizer_settings": ("ARGS",), + "output_name": ("STRING", {"default": "sd35_lora", "multiline": False}), + "output_dir": ("STRING", {"default": "sd35_trainer_output", "multiline": False, "tooltip": "path to save the training output, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "network_dim": ("INT", {"default": 16, "min": 1, "max": 2048, "step": 1, "tooltip": "network dim"}), + "network_alpha": ("FLOAT", {"default": 16, "min": 0.0, "max": 2048.0, "step": 0.01, "tooltip": "network alpha"}), + "learning_rate": ("FLOAT", {"default": 1e-4, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "learning rate"}), + "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}), + "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches VAE latents"}), + "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "training_shift": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "shift value for the training distribution of timesteps"}), + 
"highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}), + "blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "option for memory use reduction. The maximum number of blocks that can be swapped is 36 for SD3.5L and 22 for SD3.5M"}), + "fp8_base": ("BOOLEAN", {"default": False, "tooltip": "use fp8 for base model"}), + "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}), + "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}), + "train_text_encoder": (['disabled', 'clip_l', 'clip_l_fp8', 'clip_l+T5', 'clip_l+T5_fp8'], {"default": 'disabled', "tooltip": "also train the selected text encoders using specified dtype, T5 can not be trained without clip_l"}), + "clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "clip_g_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "T5_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + "gradient_checkpointing": (["enabled", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}), + }, + "optional": { + "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}), + "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}), + "block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}), + "loss_args": ("ARGS", {"default": "", "tooltip": "loss args"}), + }, + "hidden": { + "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS",) + RETURN_NAMES = ("network_trainer", "epochs_count", "args",) + FUNCTION = "init_training" + CATEGORY = "FluxTrainer/SD3" + + def init_training(self, sd3_models, dataset, optimizer_settings, sample_prompts, output_name, attention_mode, + gradient_dtype, save_dtype, additional_args=None, resume_args=None, train_text_encoder='disabled', + block_args=None, gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, clip_g_lr=0, T5_lr=0, loss_args=None, **kwargs): + mm.soft_empty_cache() + + output_dir = os.path.abspath(kwargs.get("output_dir")) + os.makedirs(output_dir, exist_ok=True) + + total, used, free = shutil.disk_usage(output_dir) + + required_free_space = 2 * (2**30) + if free <= required_free_space: + raise ValueError(f"Insufficient disk space. Required: {required_free_space/2**30}GB. 
Available: {free/2**30}GB") + + dataset_config = dataset["datasets"] + dataset_toml = toml.dumps(json.loads(dataset_config)) + + parser = train_network_setup_parser() + sd3_train_utils.add_sd3_training_arguments(parser) + if additional_args is not None: + print(f"additional_args: {additional_args}") + args, _ = parser.parse_known_args(args=shlex.split(additional_args)) + else: + args, _ = parser.parse_known_args() + + if kwargs.get("cache_latents") == "memory": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = False + elif kwargs.get("cache_latents") == "disk": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = True + kwargs["caption_dropout_rate"] = 0.0 + kwargs["shuffle_caption"] = False + kwargs["token_warmup_step"] = 0.0 + kwargs["caption_tag_dropout_rate"] = 0.0 + else: + kwargs["cache_latents"] = False + kwargs["cache_latents_to_disk"] = False + + if kwargs.get("cache_text_encoder_outputs") == "memory": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = False + elif kwargs.get("cache_text_encoder_outputs") == "disk": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = True + else: + kwargs["cache_text_encoder_outputs"] = False + kwargs["cache_text_encoder_outputs_to_disk"] = False + + if '|' in sample_prompts: + prompts = sample_prompts.split('|') + else: + prompts = [sample_prompts] + + config_dict = { + "sample_prompts": prompts, + "save_precision": save_dtype, + "mixed_precision": "bf16", + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": sd3_models["transformer"], + "clip_l": sd3_models["clip_l"], + "clip_g": sd3_models["clip_g"], + "t5xxl": sd3_models["t5"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "network_module": ".networks.lora_sd3", + "dataset_config": dataset_toml, + "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{save_dtype}", + "loss_type": "l2", + "t5xxl_max_token_length": 512, + "alpha_mask": dataset["alpha_mask"], + "network_train_unet_only": True if train_text_encoder == 'disabled' else False, + "fp8_base_unet": True if "fp8" in train_text_encoder else False, + "disable_mmap_load_safetensors": False, + } + attention_settings = { + "sdpa": {"mem_eff_attn": True, "xformers": False, "spda": True}, + "xformers": {"mem_eff_attn": True, "xformers": True, "spda": False} + } + config_dict.update(attention_settings.get(attention_mode, {})) + + gradient_dtype_settings = { + "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"}, + "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"} + } + config_dict.update(gradient_dtype_settings.get(gradient_dtype, {})) + + if train_text_encoder != 'disabled': + config_dict["text_encoder_lr"] = [clip_l_lr, clip_g_lr, T5_lr] + + #network args + additional_network_args = [] + + if "T5" in train_text_encoder: + additional_network_args.append("train_t5xxl=True") + + if block_args: + additional_network_args.append(block_args["include"]) + + # Handle network_args in args Namespace + if hasattr(args, 'network_args') and isinstance(args.network_args, list): + args.network_args.extend(additional_network_args) + else: + setattr(args, 'network_args', additional_network_args) + + if gradient_checkpointing == "disabled": + config_dict["gradient_checkpointing"] = False + elif gradient_checkpointing == "enabled_with_cpu_offloading": + config_dict["gradient_checkpointing"] = 
True + config_dict["cpu_offload_checkpointing"] = True + else: + config_dict["gradient_checkpointing"] = True + + if sd3_models["lora_path"]: + config_dict["network_weights"] = sd3_models["lora_path"] + + config_dict.update(kwargs) + config_dict.update(optimizer_settings) + + if loss_args: + config_dict.update(loss_args) + + if resume_args: + config_dict.update(resume_args) + + for key, value in config_dict.items(): + setattr(args, key, value) + + saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json") + with open(saved_args_file_path, 'w') as f: + json.dump(vars(args), f, indent=4) + + #workflow saving + metadata = {} + if extra_pnginfo is not None: + metadata.update(extra_pnginfo["workflow"]) + + saved_workflow_file_path = os.path.join(output_dir, f"{output_name}_workflow.json") + with open(saved_workflow_file_path, 'w') as f: + json.dump(metadata, f, indent=4) + + #pass args to kohya and initialize trainer + with torch.inference_mode(False): + network_trainer = Sd3NetworkTrainer() + training_loop = network_trainer.init_train(args) + + epochs_count = network_trainer.num_train_epochs + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, args) + + +class SD3TrainLoop: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer" + + def train(self, network_trainer, steps): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + initial_global_step = network_trainer.global_step + + target_global_step = network_trainer.global_step + steps + comfy_pbar = comfy.utils.ProgressBar(steps) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + steps_done = training_loop( + break_at_steps = target_global_step, + epoch = network_trainer.current_epoch.value, + ) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + + +class SD3TrainLoRASave: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": False, "tooltip": "save the whole model state as well"}), + "copy_to_comfy_lora_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",) + RETURN_NAMES = ("network_trainer","lora_path", "steps",) + FUNCTION = "save" + CATEGORY = "FluxTrainer" + + def save(self, network_trainer, save_state, copy_to_comfy_lora_folder): + import shutil + with torch.inference_mode(False): + trainer = network_trainer["network_trainer"] + global_step = trainer.global_step + + ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." 
+ trainer.args.save_model_as, global_step) + trainer.save_model(ckpt_name, trainer.accelerator.unwrap_model(trainer.network), global_step, trainer.current_epoch.value + 1) + + remove_step_no = train_util.get_remove_step_no(trainer.args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, remove_step_no) + trainer.remove_model(remove_ckpt_name) + + if save_state: + train_util.save_and_remove_state_stepwise(trainer.args, trainer.accelerator, global_step) + + lora_path = os.path.join(trainer.args.output_dir, ckpt_name) + if copy_to_comfy_lora_folder: + destination_dir = os.path.join(folder_paths.models_dir, "loras", "flux_trainer") + os.makedirs(destination_dir, exist_ok=True) + shutil.copy(lora_path, os.path.join(destination_dir, ckpt_name)) + + + return (network_trainer, lora_path, global_step) + + + +class SD3TrainEnd: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("STRING", "STRING", "STRING",) + RETURN_NAMES = ("lora_name", "metadata", "lora_path",) + FUNCTION = "endtrain" + CATEGORY = "FluxTrainer" + OUTPUT_NODE = True + + def endtrain(self, network_trainer, save_state): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + network_trainer.metadata["ss_epoch"] = str(network_trainer.num_train_epochs) + network_trainer.metadata["ss_training_finished_at"] = str(time.time()) + + network = network_trainer.accelerator.unwrap_model(network_trainer.network) + + network_trainer.accelerator.end_training() + network_trainer.optimizer_eval_fn() + + if save_state: + train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator) + + ckpt_name = train_util.get_last_ckpt_name(network_trainer.args, "." 
+ network_trainer.args.save_model_as) + network_trainer.save_model(ckpt_name, network, network_trainer.global_step, network_trainer.num_train_epochs, force_sync_upload=True) + logger.info("model saved.") + + final_lora_name = str(network_trainer.args.output_name) + final_lora_path = os.path.join(network_trainer.args.output_dir, ckpt_name) + + # metadata + metadata = json.dumps(network_trainer.metadata, indent=2) + + training_loop = None + network_trainer = None + mm.soft_empty_cache() + + return (final_lora_name, metadata, final_lora_path) + +class SD3TrainValidationSettings: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), + "width": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), + "height": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), + "guidance_scale": ("FLOAT", {"default": 4, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), + "seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("VALSETTINGS", ) + RETURN_NAMES = ("validation_settings", ) + FUNCTION = "set" + CATEGORY = "FluxTrainer" + + def set(self, **kwargs): + validation_settings = kwargs + print(validation_settings) + + return (validation_settings,) + +class SD3TrainValidate: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "network_trainer": ("NETWORKTRAINER",), + }, + "optional": { + "validation_settings": ("VALSETTINGS",), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "IMAGE",) + RETURN_NAMES = ("network_trainer", "validation_images",) + FUNCTION = "validate" + CATEGORY = "FluxTrainer" + + def validate(self, network_trainer, validation_settings=None): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + params = ( + network_trainer.current_epoch.value, + network_trainer.global_step, + validation_settings + ) + network_trainer.optimizer_eval_fn() + with torch.inference_mode(False): + image_tensors = network_trainer.sample_images(*params) + + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, (0.5 * (image_tensors + 1.0)).cpu().float(),) + +NODE_CLASS_MAPPINGS = { + "SD3ModelSelect": SD3ModelSelect, + "InitSD3LoRATraining": InitSD3LoRATraining, + "SD3TrainValidationSettings": SD3TrainValidationSettings, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "SD3ModelSelect": "SD3 Model Select", + "InitSD3LoRATraining": "Init SD3 LoRA Training", + "SD3TrainValidationSettings": "SD3 Train Validation Settings", +} diff --git a/nodes_sdxl.py b/nodes_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..753fc9f62c9de15f8064deea4271eb47124472ed --- /dev/null +++ b/nodes_sdxl.py @@ -0,0 +1,465 @@ +import os +import torch + +import folder_paths +import comfy.model_management as mm +import comfy.utils +import toml +import json +import time +import shutil +import shlex + +script_directory = os.path.dirname(os.path.abspath(__file__)) + +from .sdxl_train_network import SdxlNetworkTrainer +from .library import sdxl_train_util +from .library.device_utils import init_ipex +init_ipex() + +from .library import train_util +from .train_network import setup_parser as train_network_setup_parser + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + 
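+# The SDXL nodes below mirror the Flux and SD3 node pattern: a ModelSelect node packs checkpoint paths into a dict, +# InitSDXLLoRATraining builds the kohya args and returns a trainer dict, SDXLTrainLoop advances training by a number of +# steps, and SDXLTrainLoRASave / SDXLTrainEnd write the LoRA to disk. A rough sketch of the call order (illustrative +# only; in practice the nodes are wired through the ComfyUI graph, and the checkpoint name here is just an example): +# +#   (sdxl_models,) = SDXLModelSelect().loadmodel("sd_xl_base_1.0.safetensors") +#   trainer, epochs_count, args = InitSDXLLoRATraining().init_training(sdxl_models, dataset, optimizer_settings, ...) +#   trainer, step = SDXLTrainLoop().train(trainer, steps=250) +#   trainer, lora_path, step = SDXLTrainLoRASave().save(trainer, save_state=False, copy_to_comfy_lora_folder=True)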
+class SDXLModelSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "checkpoint": (folder_paths.get_filename_list("checkpoints"), ), + }, + "optional": { + "lora_path": ("STRING",{"multiline": True, "forceInput": True, "default": "", "tooltip": "pre-trained LoRA path to load (network_weights)"}), + } + } + + RETURN_TYPES = ("TRAIN_SDXL_MODELS",) + RETURN_NAMES = ("sdxl_models",) + FUNCTION = "loadmodel" + CATEGORY = "FluxTrainer/SDXL" + + def loadmodel(self, checkpoint, lora_path=""): + + checkpoint_path = folder_paths.get_full_path("checkpoints", checkpoint) + + SDXL_models = { + "checkpoint": checkpoint_path, + "lora_path": lora_path + } + + return (SDXL_models,) + +class InitSDXLLoRATraining: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "SDXL_models": ("TRAIN_SDXL_MODELS",), + "dataset": ("JSON",), + "optimizer_settings": ("ARGS",), + "output_name": ("STRING", {"default": "SDXL_lora", "multiline": False}), + "output_dir": ("STRING", {"default": "SDXL_trainer_output", "multiline": False, "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "network_dim": ("INT", {"default": 16, "min": 1, "max": 100000, "step": 1, "tooltip": "network dim"}), + "network_alpha": ("FLOAT", {"default": 16, "min": 0.0, "max": 2048.0, "step": 0.01, "tooltip": "network alpha"}), + "learning_rate": ("FLOAT", {"default": 1e-6, "min": 0.0, "max": 10.0, "step": 0.0000001, "tooltip": "learning rate"}), + "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}), + "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}), + "blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "option for memory use reduction. 
"}), + "fp8_base": ("BOOLEAN", {"default": False, "tooltip": "use fp8 for base model"}), + "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "fp16", "tooltip": "the dtype to save checkpoints as"}), + "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}), + "train_text_encoder": (['disabled', 'clip_l',], {"default": 'disabled', "tooltip": "also train the selected text encoder (clip_l) using the specified dtype"}), + "clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "clip_g_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "sample_prompts_pos": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + "sample_prompts_neg": ("STRING", {"multiline": True, "default": "", "tooltip": "validation negative sample prompts, for multiple prompts, separate by `|`"}), + "gradient_checkpointing": (["enabled", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}), + }, + "optional": { + "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}), + "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}), + "block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}), + "loss_args": ("ARGS", {"default": "", "tooltip": "loss args"}), + "network_config": ("NETWORK_CONFIG", {"tooltip": "additional network config"}), + }, + "hidden": { + "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS",) + RETURN_NAMES = ("network_trainer", "epochs_count", "args",) + FUNCTION = "init_training" + CATEGORY = "FluxTrainer/SDXL" + + def init_training(self, SDXL_models, dataset, optimizer_settings, sample_prompts_pos, sample_prompts_neg, output_name, attention_mode, + gradient_dtype, save_dtype, additional_args=None, resume_args=None, train_text_encoder='disabled', + gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, clip_g_lr=0, loss_args=None, network_config=None, **kwargs): + mm.soft_empty_cache() + + output_dir = os.path.abspath(kwargs.get("output_dir")) + os.makedirs(output_dir, exist_ok=True) + + total, used, free = shutil.disk_usage(output_dir) + + required_free_space = 2 * (2**30) + if free <= required_free_space: + raise ValueError(f"Insufficient disk space. Required: {required_free_space/2**30}GB. 
Available: {free/2**30}GB") + + dataset_config = dataset["datasets"] + dataset_toml = toml.dumps(json.loads(dataset_config)) + + parser = train_network_setup_parser() + #sdxl_train_util.add_sdxl_training_arguments(parser) + if additional_args is not None: + print(f"additional_args: {additional_args}") + args, _ = parser.parse_known_args(args=shlex.split(additional_args)) + else: + args, _ = parser.parse_known_args() + + if kwargs.get("cache_latents") == "memory": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = False + elif kwargs.get("cache_latents") == "disk": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = True + kwargs["caption_dropout_rate"] = 0.0 + kwargs["shuffle_caption"] = False + kwargs["token_warmup_step"] = 0.0 + kwargs["caption_tag_dropout_rate"] = 0.0 + else: + kwargs["cache_latents"] = False + kwargs["cache_latents_to_disk"] = False + + if kwargs.get("cache_text_encoder_outputs") == "memory": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = False + elif kwargs.get("cache_text_encoder_outputs") == "disk": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = True + else: + kwargs["cache_text_encoder_outputs"] = False + kwargs["cache_text_encoder_outputs_to_disk"] = False + + if '|' in sample_prompts_pos: + positive_prompts = sample_prompts_pos.split('|') + else: + positive_prompts = [sample_prompts_pos] + + if '|' in sample_prompts_neg: + negative_prompts = sample_prompts_neg.split('|') + else: + negative_prompts = [sample_prompts_neg] + + config_dict = { + "sample_prompts": positive_prompts, + "negative_prompts": negative_prompts, + "save_precision": save_dtype, + "mixed_precision": "bf16", + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": SDXL_models["checkpoint"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "network_module": ".networks.lora" if network_config is None else network_config["network_module"], + "dataset_config": dataset_toml, + "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{save_dtype}", + "loss_type": "l2", + "alpha_mask": dataset["alpha_mask"], + "network_train_unet_only": True if train_text_encoder == 'disabled' else False, + "disable_mmap_load_safetensors": False, + "network_args": None if network_config is None else network_config["network_args"], + } + attention_settings = { + "sdpa": {"mem_eff_attn": True, "xformers": False, "spda": True}, + "xformers": {"mem_eff_attn": True, "xformers": True, "spda": False} + } + config_dict.update(attention_settings.get(attention_mode, {})) + + gradient_dtype_settings = { + "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"}, + "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"} + } + config_dict.update(gradient_dtype_settings.get(gradient_dtype, {})) + + if train_text_encoder != 'disabled': + config_dict["text_encoder_lr"] = [clip_l_lr, clip_g_lr] + + #network args + additional_network_args = [] + + # Handle network_args in args Namespace + if hasattr(args, 'network_args') and isinstance(args.network_args, list): + args.network_args.extend(additional_network_args) + else: + setattr(args, 'network_args', additional_network_args) + + if gradient_checkpointing == "disabled": + config_dict["gradient_checkpointing"] = False + elif gradient_checkpointing == "enabled_with_cpu_offloading": + config_dict["gradient_checkpointing"] = True 
+ config_dict["cpu_offload_checkpointing"] = True + else: + config_dict["gradient_checkpointing"] = True + + if SDXL_models["lora_path"]: + config_dict["network_weights"] = SDXL_models["lora_path"] + + config_dict.update(kwargs) + config_dict.update(optimizer_settings) + + if loss_args: + config_dict.update(loss_args) + + if resume_args: + config_dict.update(resume_args) + + for key, value in config_dict.items(): + setattr(args, key, value) + + saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json") + with open(saved_args_file_path, 'w') as f: + json.dump(vars(args), f, indent=4) + + #workflow saving + metadata = {} + if extra_pnginfo is not None: + metadata.update(extra_pnginfo["workflow"]) + + saved_workflow_file_path = os.path.join(output_dir, f"{output_name}_workflow.json") + with open(saved_workflow_file_path, 'w') as f: + json.dump(metadata, f, indent=4) + + #pass args to kohya and initialize trainer + with torch.inference_mode(False): + network_trainer = SdxlNetworkTrainer() + training_loop = network_trainer.init_train(args) + + epochs_count = network_trainer.num_train_epochs + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, args) + + +class SDXLTrainLoop: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer/SDXL" + + def train(self, network_trainer, steps): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + initial_global_step = network_trainer.global_step + + target_global_step = network_trainer.global_step + steps + comfy_pbar = comfy.utils.ProgressBar(steps) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + steps_done = training_loop( + break_at_steps = target_global_step, + epoch = network_trainer.current_epoch.value, + ) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + + +class SDXLTrainLoRASave: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": False, "tooltip": "save the whole model state as well"}), + "copy_to_comfy_lora_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",) + RETURN_NAMES = ("network_trainer","lora_path", "steps",) + FUNCTION = "save" + CATEGORY = "FluxTrainer/SDXL" + + def save(self, network_trainer, save_state, copy_to_comfy_lora_folder): + import shutil + with torch.inference_mode(False): + trainer = network_trainer["network_trainer"] + global_step = trainer.global_step + + ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." 
+ trainer.args.save_model_as, global_step) + trainer.save_model(ckpt_name, trainer.accelerator.unwrap_model(trainer.network), global_step, trainer.current_epoch.value + 1) + + remove_step_no = train_util.get_remove_step_no(trainer.args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, remove_step_no) + trainer.remove_model(remove_ckpt_name) + + if save_state: + train_util.save_and_remove_state_stepwise(trainer.args, trainer.accelerator, global_step) + + lora_path = os.path.join(trainer.args.output_dir, ckpt_name) + if copy_to_comfy_lora_folder: + destination_dir = os.path.join(folder_paths.models_dir, "loras", "flux_trainer") + os.makedirs(destination_dir, exist_ok=True) + shutil.copy(lora_path, os.path.join(destination_dir, ckpt_name)) + + + return (network_trainer, lora_path, global_step) + + + +class SDXLTrainEnd: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("STRING", "STRING", "STRING",) + RETURN_NAMES = ("lora_name", "metadata", "lora_path",) + FUNCTION = "endtrain" + CATEGORY = "FluxTrainer/SDXL" + OUTPUT_NODE = True + + def endtrain(self, network_trainer, save_state): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + network_trainer.metadata["ss_epoch"] = str(network_trainer.num_train_epochs) + network_trainer.metadata["ss_training_finished_at"] = str(time.time()) + + network = network_trainer.accelerator.unwrap_model(network_trainer.network) + + network_trainer.accelerator.end_training() + network_trainer.optimizer_eval_fn() + + if save_state: + train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator) + + ckpt_name = train_util.get_last_ckpt_name(network_trainer.args, "." 
+ network_trainer.args.save_model_as) + network_trainer.save_model(ckpt_name, network, network_trainer.global_step, network_trainer.num_train_epochs, force_sync_upload=True) + logger.info("model saved.") + + final_lora_name = str(network_trainer.args.output_name) + final_lora_path = os.path.join(network_trainer.args.output_dir, ckpt_name) + + # metadata + metadata = json.dumps(network_trainer.metadata, indent=2) + + training_loop = None + network_trainer = None + mm.soft_empty_cache() + + return (final_lora_name, metadata, final_lora_path) + +class SDXLTrainValidationSettings: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), + "width": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), + "height": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), + "guidance_scale": ("FLOAT", {"default": 7.5, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), + "sampler": (["ddim", "ddpm", "pndm", "lms", "euler", "euler_a", "dpmsolver", "dpmsingle", "heun", "dpm_2", "dpm_2_a",], {"default": "dpm_2", "tooltip": "sampler"}), + "seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("VALSETTINGS", ) + RETURN_NAMES = ("validation_settings", ) + FUNCTION = "set" + CATEGORY = "FluxTrainer/SDXL" + + def set(self, **kwargs): + validation_settings = kwargs + print(validation_settings) + + return (validation_settings,) + +class SDXLTrainValidate: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "network_trainer": ("NETWORKTRAINER",), + }, + "optional": { + "validation_settings": ("VALSETTINGS",), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "IMAGE",) + RETURN_NAMES = ("network_trainer", "validation_images",) + FUNCTION = "validate" + CATEGORY = "FluxTrainer/SDXL" + + def validate(self, network_trainer, validation_settings=None): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + params = ( + network_trainer.accelerator, + network_trainer.args, + network_trainer.current_epoch.value, + network_trainer.global_step, + network_trainer.accelerator.device, + network_trainer.vae, + network_trainer.tokenizers, + network_trainer.text_encoder, + network_trainer.unet, + validation_settings, + ) + network_trainer.optimizer_eval_fn() + with torch.inference_mode(False): + image_tensors = network_trainer.sample_images(*params) + + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, (0.5 * (image_tensors + 1.0)).cpu().float(),) + +NODE_CLASS_MAPPINGS = { + "SDXLModelSelect": SDXLModelSelect, + "InitSDXLLoRATraining": InitSDXLLoRATraining, + "SDXLTrainValidationSettings": SDXLTrainValidationSettings, + "SDXLTrainValidate": SDXLTrainValidate, + +} +NODE_DISPLAY_NAME_MAPPINGS = { + "SDXLModelSelect": "SDXL Model Select", + "InitSDXLLoRATraining": "Init SDXL LoRA Training", + "SDXLTrainValidationSettings": "SDXL Train Validation Settings", + "SDXLTrainValidate": "SDXL Train Validate", +} diff --git a/pinokio.js b/pinokio.js new file mode 100644 index 0000000000000000000000000000000000000000..9562448e3955e7f6d3c369a46e9042fac5adb26c --- /dev/null +++ b/pinokio.js @@ -0,0 +1,95 @@ +const path = require('path') +module.exports = { + version: "3.2", + title: "fluxgym", + description: "[NVIDIA Only] Dead simple web UI for training FLUX LoRA with LOW 
VRAM support (From 12GB)", + icon: "icon.png", + menu: async (kernel, info) => { + let installed = info.exists("env") + let running = { + install: info.running("install.js"), + start: info.running("start.js"), + update: info.running("update.js"), + reset: info.running("reset.js") + } + if (running.install) { + return [{ + default: true, + icon: "fa-solid fa-plug", + text: "Installing", + href: "install.js", + }] + } else if (installed) { + if (running.start) { + let local = info.local("start.js") + if (local && local.url) { + return [{ + default: true, + icon: "fa-solid fa-rocket", + text: "Open Web UI", + href: local.url, + }, { + icon: 'fa-solid fa-terminal', + text: "Terminal", + href: "start.js", + }, { + icon: "fa-solid fa-flask", + text: "Outputs", + href: "outputs?fs" + }] + } else { + return [{ + default: true, + icon: 'fa-solid fa-terminal', + text: "Terminal", + href: "start.js", + }] + } + } else if (running.update) { + return [{ + default: true, + icon: 'fa-solid fa-terminal', + text: "Updating", + href: "update.js", + }] + } else if (running.reset) { + return [{ + default: true, + icon: 'fa-solid fa-terminal', + text: "Resetting", + href: "reset.js", + }] + } else { + return [{ + default: true, + icon: "fa-solid fa-power-off", + text: "Start", + href: "start.js", + }, { + icon: "fa-solid fa-flask", + text: "Outputs", + href: "sd-scripts/fluxgym/outputs?fs" + }, { + icon: "fa-solid fa-plug", + text: "Update", + href: "update.js", + }, { + icon: "fa-solid fa-plug", + text: "Install", + href: "install.js", + }, { + icon: "fa-regular fa-circle-xmark", + text: "Reset", + href: "reset.js", + }] + } + } else { + return [{ + default: true, + icon: "fa-solid fa-plug", + text: "Install", + href: "install.js", + }] + } + } +} diff --git a/pinokio_meta.json b/pinokio_meta.json new file mode 100644 index 0000000000000000000000000000000000000000..c5d4580458be87c97c311020241b5c49ef93c870 --- /dev/null +++ b/pinokio_meta.json @@ -0,0 +1,39 @@ +{ + "posts": [ + "https://x.com/cocktailpeanut/status/1851721405408166064", + "https://x.com/cocktailpeanut/status/1835719701172756592", + "https://x.com/LikeToasters/status/1834258975384092858", + "https://x.com/cocktailpeanut/status/1834245329627009295", + "https://x.com/jkch0205/status/1834003420132614450", + "https://x.com/huwhitememes/status/1834074992209699132", + "https://x.com/GorillaRogueGam/status/1834148656791888139", + "https://x.com/cocktailpeanut/status/1833964839519068303", + "https://x.com/cocktailpeanut/status/1833935061907079521", + "https://x.com/cocktailpeanut/status/1833940728881242135", + "https://x.com/cocktailpeanut/status/1833881392482066638", + "https://x.com/Alone1Moon/status/1833348850662445369", + "https://x.com/_f_ai_9/status/1833485349995397167", + "https://x.com/intocryptoast/status/1833061082862412186", + "https://x.com/cocktailpeanut/status/1833888423716827321", + "https://x.com/cocktailpeanut/status/1833884852992516596", + "https://x.com/cocktailpeanut/status/1833885335077417046", + "https://x.com/NiwonArt/status/1833565746624139650", + "https://x.com/cocktailpeanut/status/1833884361986380117", + "https://x.com/NiwonArt/status/1833599399764889685", + "https://x.com/LikeToasters/status/1832934391217045913", + "https://x.com/cocktailpeanut/status/1832924887456817415", + "https://x.com/cocktailpeanut/status/1832927154536902897", + "https://x.com/YabaiHamster/status/1832697724690386992", + "https://x.com/cocktailpeanut/status/1832747889497366706", + "https://x.com/PhotogenicWeekE/status/1832720544959185202", + 
"https://x.com/zuzaritt/status/1832748542164652390", + "https://x.com/foxyy4i/status/1832764883710185880", + "https://x.com/waynedahlberg/status/1832226132999213095", + "https://x.com/PhotoGarrido/status/1832214644515041770", + "https://x.com/cocktailpeanut/status/1832787205774786710", + "https://x.com/cocktailpeanut/status/1832151307198541961", + "https://x.com/cocktailpeanut/status/1832145996014612735", + "https://x.com/cocktailpeanut/status/1832084951115972653", + "https://x.com/cocktailpeanut/status/1832091112086843684" + ] +} diff --git a/publish_to_hf.png b/publish_to_hf.png new file mode 100644 index 0000000000000000000000000000000000000000..25c572624bcbbcb5f0a5d626ce99daaf05da40ad --- /dev/null +++ b/publish_to_hf.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cac2aa25db8911b38ed7e084bbbafb226252e26935dbb107ee66b8cc626a95e6 +size 418251 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..291426c178063a47091e4d82760b7f702ebeb9bd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,43 @@ +safetensors +git+https://github.com/huggingface/diffusers.git +gradio_logsview@https://huggingface.co/spaces/cocktailpeanut/gradio_logsview/resolve/main/gradio_logsview-0.0.17-py3-none-any.whl +transformers +lycoris-lora==1.8.3 +flatten_json +pyyaml +oyaml +tensorboard +kornia +invisible-watermark +einops +accelerate +toml +albumentations +pydantic +omegaconf +k-diffusion +open_clip_torch +timm +prodigyopt +controlnet_aux==0.0.7 +python-dotenv +bitsandbytes +hf_transfer +lpips +pytorch_fid +optimum-quanto +sentencepiece +huggingface_hub +peft +gradio +python-slugify +imagesize +pydantic==2.9.2 +slugify +easygui +argparse +pytorch-lightning==1.9.0 +# triton +comfy +gradio +voluptuous diff --git a/reset.js b/reset.js new file mode 100644 index 0000000000000000000000000000000000000000..0278948c73f96521352a9358dc55b62010efae3e --- /dev/null +++ b/reset.js @@ -0,0 +1,13 @@ +module.exports = { + run: [{ + method: "fs.rm", + params: { + path: "sd-scripts" + } + }, { + method: "fs.rm", + params: { + path: "env" + } + }] +} diff --git a/sample.png b/sample.png new file mode 100644 index 0000000000000000000000000000000000000000..a7f29b47eabfec81b53f82b6cc336bbd2863c468 --- /dev/null +++ b/sample.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a1670e3ce2a35d0cffec798ea04f4216b7d4d766e1e785ef23e94f6d2d22ff1 +size 1293751 diff --git a/sample_fields.png b/sample_fields.png new file mode 100644 index 0000000000000000000000000000000000000000..bc50bee365e03c2e25779299dc854712838f61c3 Binary files /dev/null and b/sample_fields.png differ diff --git a/screenshot.png b/screenshot.png new file mode 100644 index 0000000000000000000000000000000000000000..4a4fd700582986ab500e118e5764b04bc4c14f85 --- /dev/null +++ b/screenshot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cde964e4a233bf3ad7219ac91058f805ae0cbc0f853b7f6aa552af5b6f8c5c8a +size 242712 diff --git a/sd3_train_network.py b/sd3_train_network.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4273a5ac74b82cc8c9b60f5285012d6e87b62f --- /dev/null +++ b/sd3_train_network.py @@ -0,0 +1,486 @@ +import argparse +import math + +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from .library import sd3_models, strategy_sd3, utils +from .library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .library import flux_models, flux_utils, 
sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util +from . import train_network +from .library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class Sd3NetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + + def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + # super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/CLIP-G/T5XXL training flags + self.train_clip = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + # enumerate resolutions from dataset for positional embeddings + self.resolutions = train_dataset_group.get_resolutions() + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype + ) + mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") + self.model_type = mmdit.model_type + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates + logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + + if args.fp8_base: + # check dtype of model + if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") + elif mmdit.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 SD3 model") + else: + 
logger.info( + "Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + ) + mmdit.to(torch.float8_e4m3fn) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = sd3_utils.load_clip_l( + args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_l.eval() + clip_g = sd3_utils.load_clip_g( + args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_g.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = sd3_utils.load_t5xxl( + args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + vae = sd3_utils.load_vae( + args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + + return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit + + def get_tokenize_strategy(self, args): + logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}") + return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + args.clip_l_dropout_rate, + args.clip_g_dropout_rate, + args.t5_dropout_rate, + ) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip and not self.train_t5xxl: + return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip, 
self.train_clip, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip or self.train_t5xxl, + apply_lg_attn_mask=args.apply_lg_attn_mask, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[2].to(accelerator.device) # may be fp8 + + if text_encoders[2].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[2].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. 
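+ # Illustrative example (assumed prompt syntax, not part of this diff): a sample-prompt line such as
+ #   "a photo of sks dog --n lowres --w 512 --h 512 --s 20 --l 4.0 --d 42"
+ # would typically come back from line_to_prompt_dict as something like
+ #   {"prompt": "a photo of sks dog", "negative_prompt": "lowres", "width": 512, "height": 512, ...}
+ # (the exact keys are whatever train_util.line_to_prompt_dict produces).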
+ prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move CLIP-G back to cpu") + text_encoders[1].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[2].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[2].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, epoch, global_step, validation_settings): + text_encoders = self.get_models_for_text_encoding(self.args, self.accelerator, self.text_encoder) + image_tensors = sd3_train_utils.sample_images( + self.accelerator, self.args, epoch, global_step, self.unet, self.vae, text_encoders, self.sample_prompts_te_outputs, validation_settings + ) + + return image_tensors.permute(0, 2, 3, 1) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + # this scheduler is not used in training, but used to get num_train_timesteps etc. 
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return sd3_models.SDVAE.process_in(latents) + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, latents, noise, accelerator.device, weight_dtype + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Predict the noise residual + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + if not args.apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + # call model + with accelerator.autocast(): + # TODO support attention mask + model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = unet( + noisy_model_input[diff_output_pr_indices], + timesteps[diff_output_pr_indices], + context=context[diff_output_pr_indices], + y=lg_pooled[diff_output_pr_indices], + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices] + + # weighting for differential output preservation is not needed because it is already applied + + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type) + + def update_metadata(self, metadata, args): + metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = 
args.weighting_scheme
+        metadata["ss_logit_mean"] = args.logit_mean
+        metadata["ss_logit_std"] = args.logit_std
+        metadata["ss_mode_scale"] = args.mode_scale
+
+    def is_text_encoder_not_needed_for_training(self, args):
+        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
+
+    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+        if index == 0 or index == 1:  # CLIP-L/CLIP-G
+            return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
+        else:  # T5XXL
+            text_encoder.encoder.embed_tokens.requires_grad_(True)
+
+    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
+        if index == 0 or index == 1:  # CLIP-L/CLIP-G
+            clip_type = "CLIP-L" if index == 0 else "CLIP-G"
+            logger.info(f"prepare {clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
+            text_encoder.to(te_weight_dtype)  # fp8
+            text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+        else:  # T5XXL
+
+            def prepare_fp8(text_encoder, target_dtype):
+                def forward_hook(module):
+                    def forward(hidden_states):
+                        hidden_gelu = module.act(module.wi_0(hidden_states))
+                        hidden_linear = module.wi_1(hidden_states)
+                        hidden_states = hidden_gelu * hidden_linear
+                        hidden_states = module.dropout(hidden_states)
+
+                        hidden_states = module.wo(hidden_states)
+                        return hidden_states
+
+                    return forward
+
+                for module in text_encoder.modules():
+                    if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
+                        # print("set", module.__class__.__name__, "to", target_dtype)
+                        module.to(target_dtype)
+                    if module.__class__.__name__ in ["T5DenseGatedActDense"]:
+                        # print("set", module.__class__.__name__, "hooks")
+                        module.forward = forward_hook(module)
+
+            if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
+                logger.info("T5XXL already prepared for fp8")
+            else:
+                logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
+                text_encoder.to(te_weight_dtype)  # fp8
+                prepare_fp8(text_encoder, weight_dtype)
+
+    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        # drop cached text encoder outputs
+        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+        if text_encoder_outputs_list is not None:
+            text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
+            text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
+            batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+
+    def prepare_unet_with_accelerator(
+        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
+    ) -> torch.nn.Module:
+        if not self.is_swapping_blocks:
+            return super().prepare_unet_with_accelerator(args, accelerator, unet)
+
+        # when swapping blocks, let accelerate skip device placement and move only the non-swapped blocks to the device
+        mmdit: sd3_models.MMDiT = unet
+        mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks])
+        accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+        accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward()
+
+        return mmdit
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = train_network.setup_parser()
+    train_util.add_dit_training_arguments(parser)
+    sd3_train_utils.add_sd3_training_arguments(parser)
+    return parser
diff --git a/sdxl_train_network.py 
b/sdxl_train_network.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc3c3ac3db947a17e1c6b7fc95f017c6f9402c1 --- /dev/null +++ b/sdxl_train_network.py @@ -0,0 +1,228 @@ +import argparse +from typing import List, Optional + +import torch +from accelerate import Accelerator +from .library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util +from . import train_network +from .library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class SdxlNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR + self.is_sdxl = True + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + sdxl_train_util.verify_sdxl_training_args(args) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + assert ( + args.network_train_unet_only or not args.cache_text_encoder_outputs + ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + + train_dataset_group.verify_bucket_reso_steps(32) + + def load_target_model(self, args, weight_dtype, accelerator): + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + self.load_stable_diffusion_format = load_stable_diffusion_format + self.logit_scale = logit_scale + self.ckpt_info = ckpt_info + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet + + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + return None + + def 
cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) + accelerator.wait_for_everyone() + + text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU + text_encoders[1].to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.enable_grad(): + # Get the text embedding for conditioning + # TODO support weighted captions + # if args.weighted_captions: + # encoder_hidden_states = get_weighted_text_embeddings( + # tokenizer, + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + # else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, + ) + else: + encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # # verify that the text encoder outputs are correct + # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( + # args.max_token_length, + # batch["input_ids"].to(text_encoders[0].device), + # batch["input_ids2"].to(text_encoders[0].device), + # tokenizers[0], + # tokenizers[1], + # text_encoders[0], + # text_encoders[1], + # None if not args.full_fp16 else weight_dtype, + # ) + # b_size = encoder_hidden_states1.shape[0] + # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # logger.info("text encoder outputs verified") + + return encoder_hidden_states1, encoder_hidden_states2, pool2 + + def 
call_unet( + self, + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_conds, + batch, + weight_dtype, + indices: Optional[List[int]] = None, + ): + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + if indices is not None and len(indices) > 0: + noisy_latents = noisy_latents[indices] + timesteps = timesteps[indices] + text_embedding = text_embedding[indices] + vector_embedding = vector_embedding[indices] + + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, validation_settings=None): + image_tensors = sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, validation_settings) + return image_tensors + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + sdxl_train_util.add_sdxl_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = SdxlNetworkTrainer() + trainer.train(args) diff --git a/seed.gif b/seed.gif new file mode 100644 index 0000000000000000000000000000000000000000..c1d209a98db2408230a2d2c67f749fb0bee8793b --- /dev/null +++ b/seed.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:271dbf11ef0c709558bb570c4c2b7765001356eefcbcc9cf0f0713262a91937f +size 3622293 diff --git a/start.js b/start.js new file mode 100644 index 0000000000000000000000000000000000000000..494cb260c56aa96749753b67790aa4af4530e34d --- /dev/null +++ b/start.js @@ -0,0 +1,37 @@ +module.exports = { + daemon: true, + run: [ + { + method: "shell.run", + params: { + venv: "env", // Edit this to customize the venv folder path + env: { + LOG_LEVEL: "DEBUG", + CUDA_VISIBLE_DEVICES: "0" + }, // Edit this to customize environment variables (see documentation) + message: [ + "python app.py", // Edit with your custom commands + ], + on: [{ + // The regular expression pattern to monitor. + // When this pattern occurs in the shell terminal, the shell will return, + // and the script will go onto the next step. + "event": "/http:\/\/\\S+/", + + // "done": true will move to the next step while keeping the shell alive. + // "kill": true will move to the next step after killing the shell. + "done": true + }] + } + }, + { + // This step sets the local variable 'url'. + // This local variable will be used in pinokio.js to display the "Open WebUI" tab when the value is set. 
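+      // For example, Gradio typically prints a line like "Running on local URL:  http://127.0.0.1:7860";
+      // the regular expression in the previous step captures that address, and it is stored below
+      // as the local variable `url`.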
+ method: "local.set", + params: { + // the input.event is the regular expression match object from the previous step + url: "{{input.event[0]}}" + } + } + ] +} diff --git a/torch.js b/torch.js new file mode 100644 index 0000000000000000000000000000000000000000..4d6e4d1ad3377b2a7300e0da9e1883399743fd70 --- /dev/null +++ b/torch.js @@ -0,0 +1,75 @@ +module.exports = { + run: [ + // windows nvidia + { + "when": "{{platform === 'win32' && gpu === 'nvidia'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --force-reinstall" + + } + }, + // windows amd + { + "when": "{{platform === 'win32' && gpu === 'amd'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install torch-directml torchaudio torchvision" + } + }, + // windows cpu + { + "when": "{{platform === 'win32' && (gpu !== 'nvidia' && gpu !== 'amd')}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --force-reinstall" + } + }, + // mac + { + "when": "{{platform === 'darwin'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --force-reinstall" + } + }, + // linux nvidia + { + "when": "{{platform === 'linux' && gpu === 'nvidia'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --force-reinstall" + } + }, + // linux rocm (amd) + { + "when": "{{platform === 'linux' && gpu === 'amd'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1 --force-reinstall" + } + }, + // linux cpu + { + "when": "{{platform === 'linux' && (gpu !== 'amd' && gpu !=='nvidia')}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? 
args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --force-reinstall" + } + } + ] +} diff --git a/train_db.py b/train_db.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8d7b205284c09cd867afda259f8028f589958d --- /dev/null +++ b/train_db.py @@ -0,0 +1,558 @@ +# DreamBooth training +# XXX dropped option: fine_tune + +import argparse +import itertools +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library import deepspeed_utils, strategy_base +from library.device_utils import init_ipex, clean_memory_on_device + + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, + apply_masked_loss, +) +from .utils import setup_logging, add_logging_arguments +import library.strategy_sd as strategy_sd + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# perlin_noise, + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + if args.no_token_padding: + train_dataset_group.disable_token_padding() + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + + if args.gradient_accumulation_steps > 1: + logger.warning( + f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. 
accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" + ) + logger.warning( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" + ) + + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # 学習を準備する:モデルを適切な状態にする + train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 + unet.requires_grad_(True) # 念のため追加 + text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + accelerator.print("Text Encoder is not trained.") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + if train_text_encoder: + if args.learning_rate_te is None: + # wightout list, adamw8bit is crashed + trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters())) + else: + trainable_params = [ + {"params": list(unet.parameters()), "lr": args.learning_rate}, + {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te}, + ] + else: + trainable_params = unet.parameters() + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + + # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if args.deepspeed: + if args.train_text_encoder: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + training_models = [unet, text_encoder] + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + training_models = [unet] + + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size 
per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "dreambooth" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # For --sample_at_first + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 + unet.train() + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: + text_encoder.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + # 指定したステップ数でText Encoderの学習を止める + if global_step == args.stop_text_encoder_training: + accelerator.print(f"stop text encoder training at step {global_step}") + if not args.gradient_checkpointing: + text_encoder.train(False) + text_encoder.requires_grad_(False) + if len(training_models) == 2: + training_models = training_models[0] # remove text_encoder from training_models + + with accelerator.accumulate(*training_models): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings( + tokenize_strategy.tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = 
train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + if train_text_encoder: + params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), + vae, + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + # checking for saving is in util + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, 
+ global_step, + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), + vae, + ) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = accelerator.unwrap_model(unet) + text_encoder = accelerator.unwrap_model(text_encoder) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False, True) + train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--learning_rate_te", + type=float, + default=None, + help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", + ) + parser.add_argument( + "--no_token_padding", + action="store_true", + help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)", + ) + parser.add_argument( + "--stop_text_encoder_training", + type=int, + default=None, + help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_network.py b/train_network.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8a3476f1a1ae253d9b8d155dfb147c7c82459c --- /dev/null +++ b/train_network.py @@ -0,0 +1,1500 @@ +import importlib +import argparse +import math +import os +import sys +import random +import time +import json +from multiprocessing import Value +from typing import Any, List +import toml + +from tqdm import tqdm +# from comfy.utils import ProgressBar + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from accelerate import Accelerator +from diffusers import DDPMScheduler +from library import deepspeed_utils, model_util, strategy_base, strategy_sd + +from library import train_util as train_util +from library.train_util import DreamBoothDataset +from library import config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library import huggingface_util as huggingface_util +from 
library import custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, + apply_debiased_estimation, + apply_masked_loss, +) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class NetworkTrainer: + def __init__(self): + self.vae_scale_factor = 0.18215 + self.is_sdxl = False + + # TODO 他のスクリプトと共通化する + def generate_step_logs( + self, + args: argparse.Namespace, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer=None, + keys_scaled=None, + mean_norm=None, + maximum_norm=None, + ): + logs = {"loss/current": current_loss, "loss/average": avr_loss} + + if keys_scaled is not None: + logs["max_norm/keys_scaled"] = keys_scaled + logs["max_norm/average_key_norm"] = mean_norm + logs["max_norm/max_key_norm"] = maximum_norm + + lrs = lr_scheduler.get_last_lr() + for i, lr in enumerate(lrs): + if lr_descriptions is not None: + lr_desc = lr_descriptions[i] + else: + idx = i - (0 if args.network_train_unet_only else -1) + if idx == -1: + lr_desc = "textencoder" + else: + if len(lrs) > 2: + lr_desc = f"group{idx}" + else: + lr_desc = "unet" + + logs[f"lr/{lr_desc}"] = lr + + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + # tracking d*lr value + logs[f"lr/d*lr/{lr_desc}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): # tracking d*lr value of unet. 
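+ # Note: D-Adaptation / Prodigy-style optimizers expose a learned multiplier `d` in each param group,
+ # so the effective learning rate is d * lr; logging d*lr makes the adapted rate visible in the tracker logs.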
+ logs["lr/d*lr"] = ( + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + ) + else: + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): + logs[f"lr/d*lr/group{i}"] = ( + optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + ) + + return logs + + def assert_extra_args(self, args, train_dataset_group): + train_dataset_group.verify_bucket_reso_steps(64) + + def load_target_model(self, args, weight_dtype, accelerator): + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + # Incorporate xformers or memory efficient attention into the model + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # If you have xformers compatible with PyTorch 2.0.0 or higher, you can use the following + vae.set_use_memory_efficient_attention_xformers(args.xformers) + + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_text_encoder_outputs_caching_strategy(self, args): + return None + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached). 
+ """ + return text_encoders + + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) + + def is_train_text_encoder(self, args): + return not args.network_train_unet_only + + def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): + for t_enc in text_encoders: + t_enc.to(accelerator.device, dtype=weight_dtype) + + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs): + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample + return noise_pred + + def all_reduce_network(self, accelerator, network): + for param in network.parameters(): + if param.grad is not None: + param.grad = accelerator.reduce(param.grad, reduction="mean") + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) + + # region SD/SDXL + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents * self.vae_scale_factor + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + ): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, 
+ weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + + return noise_pred, target, timesteps, None + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + + def update_metadata(self, metadata, args): + pass + + def is_text_encoder_not_needed_for_training(self, args): + return False # use for sample images + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # set top parameter requires_grad = True for gradient checkpointing works + text_encoder.text_model.embeddings.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + return accelerator.prepare(unet) + + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + pass + + # endregion + + def init_train(self, args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. 
+ latents_caching_strategy = self.get_latents_caching_strategy(args)
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
+
+ # pbar = ProgressBar(5)
+
+ # Prepare the dataset
+ if args.dataset_class is None:
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
+ if use_user_config:
+ logger.info(f"Loading dataset config from {args.dataset_config}")
+ user_config = config_util.load_user_config(args.dataset_config)
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+ if any(getattr(args, attr) is not None for attr in ignored):
+ logger.warning(
+ "ignoring the following options because config file is found: {0}".format(
+ ", ".join(ignored)
+ )
+ )
+ else:
+ if use_dreambooth_method:
+ logger.info("Using DreamBooth method.")
+ user_config = {
+ "datasets": [
+ {
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+ args.train_data_dir, args.reg_data_dir
+ )
+ }
+ ]
+ }
+ else:
+ logger.info("Training with captions.")
+ user_config = {
+ "datasets": [
+ {
+ "subsets": [
+ {
+ "image_dir": args.train_data_dir,
+ "metadata_file": args.in_json,
+ }
+ ]
+ }
+ ]
+ }
+
+ blueprint = blueprint_generator.generate(user_config, args)
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+ else:
+ # use arbitrary dataset class
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
+
+ current_epoch = Value("i", 0)
+ current_step = Value("i", 0)
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
+
+ if args.debug_dataset:
+ train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
+ train_util.debug_dataset(train_dataset_group)
+ return
+ if len(train_dataset_group) == 0:
+ logger.error(
+ "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
+ )
+ return
+
+ if cache_latents:
+ assert (
+ train_dataset_group.is_latent_cacheable()
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
+ self.assert_extra_args(args, train_dataset_group) # may change some args
+
+ # prepare accelerator
+ logger.info("preparing accelerator")
+ accelerator = train_util.prepare_accelerator(args)
+
+
+ # Prepare dtypes that support mixed precision and cast models as appropriate.
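+ # (weight_dtype is used for training-time casts below; save_dtype only affects the saved checkpoint)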
+ weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # Load the model + model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) + + # text_encoder is List[CLIPTextModel] or CLIPTextModel + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # pbar.update(1) + + # Load the model for incremental learning + #sys.path.append(os.path.dirname(__file__)) + accelerator.print("import network module:", args.network_module) + package = __name__.split('.')[0] + network_module = importlib.import_module(args.network_module, package=package) + + if args.base_weights is not None: + # base_weights が指定されている場合は、指定された重みを読み込みマージする + for i, weight_path in enumerate(args.base_weights): + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 + else: + multiplier = args.base_weights_multiplier[i] + + accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") + + module, weights_sd = network_module.create_network_from_weights( + multiplier, weight_path, vae, text_encoder, unet, for_inference=True + ) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") + + accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + + + # cache latents + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args) + if text_encoder_outputs_caching_strategy is not None: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + + # pbar.update(1) + + # prepare network + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + if args.dim_from_weights: + network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) + else: + if "dropout" not in net_kwargs: + # workaround for LyCORIS (;^ω^) + net_kwargs["dropout"] = args.network_dropout + + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + vae, + text_encoder, + unet, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + if network is None: + return + network_has_multiplier = hasattr(network, "set_multiplier") + + if hasattr(network, "prepare_network"): + network.prepare_network(args) + if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): + logger.warning( + "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" + ) + args.scale_weight_norms = False + + self.post_process_network(args, accelerator, network, text_encoders, unet) + + # apply network to unet and text_encoder + 
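+ # train_unet / train_text_encoder decide which parts of the model receive LoRA modules and gradients;
+ # both default to True unless --network_train_unet_only or --network_train_text_encoder_only narrows the scope
+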
train_unet = not args.network_train_text_encoder_only + train_text_encoder = self.is_train_text_encoder(args) + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.network_weights is not None: + # FIXME consider alpha of weights: this assumes that the alpha is not changed + info = network.load_weights(args.network_weights) + accelerator.print(f"load network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + if args.cpu_offload_checkpointing: + unet.enable_gradient_checkpointing(cpu_offload=True) + else: + unet.enable_gradient_checkpointing() + + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() + del t_enc + network.enable_gradient_checkpointing() # may have no effect + + # Prepare classes necessary for learning + accelerator.print("prepare optimizer, data loader etc.") + + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + # toml backward compatibility + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] + try: + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) + if type(results) is tuple: + trainable_params = results[0] + lr_descriptions = results[1] + else: + trainable_params = results + lr_descriptions = None + except TypeError as e: + trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) + lr_descriptions = None + + # if len(trainable_params) == 0: + # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした") + # for params in trainable_params: + # for k, v in params.items(): + # if type(v) == float: + # pass + # else: + # v = len(v) + # accelerator.print(f"trainable_params: {k} = {v}") + + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + self.optimizer_train_fn, self.optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # Send learning steps to the dataset side as well + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr scheduler init + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # Experimental function: performs fp16/bf16 learning including gradients, sets the entire model to fp16/bf16 + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + network.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + network.to(weight_dtype) + + unet_weight_dtype = te_weight_dtype = weight_dtype + # Experimental Feature: Put base model into fp8 to save vram + if args.fp8_base or args.fp8_base_unet: + assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0" + assert ( + args.mixed_precision != "no" + ), "fp8_base requires mixed precision='fp16' or 'bf16'" + accelerator.print("enable fp8 training for U-Net.") + unet_weight_dtype = torch.float8_e4m3fn if args.fp8_dtype == "e4m3" else torch.float8_e5m2 + accelerator.print(f"unet_weight_dtype: {unet_weight_dtype}") + + if not args.fp8_base_unet and not args.network_train_unet_only: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = torch.float8_e4m3fn if args.fp8_dtype == "e4m3" else torch.float8_e5m2 + + # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM + # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + + # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") + # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") + unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator + + unet.requires_grad_(False) + unet.to(dtype=unet_weight_dtype) + for i, t_enc in enumerate(text_encoders): + t_enc.requires_grad_(False) + + # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 + if t_enc.device.type != "cpu": + t_enc.to(dtype=te_weight_dtype) + + # nn.Embedding not support FP8 + if te_weight_dtype != weight_dtype: + self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good + if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + unet=unet if train_unet else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, + network=network, + ) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_model = ds_model + else: + if train_unet: + # default implementation is: unet = accelerator.prepare(unet) + unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here + else: + unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device 
+ # because unet is not prepared by accelerator
+ if train_text_encoder:
+ text_encoders = [
+ (accelerator.prepare(t_enc) if flag else t_enc)
+ for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
+ ]
+ if len(text_encoders) > 1:
+ text_encoder = text_encoders
+ else:
+ text_encoder = text_encoders[0]
+ else:
+ pass # if text_encoder is not trained, there is no need to prepare it; device and dtype are already set
+
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ network, optimizer, train_dataloader, lr_scheduler
+ )
+ training_model = network
+
+ if args.gradient_checkpointing:
+ # according to TI example in Diffusers, train is required
+ unet.train()
+ for i, (t_enc, flag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
+ t_enc.train()
+
+ # set top parameter requires_grad = True so that gradient checkpointing works
+ if flag:
+ self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
+
+ else:
+ unet.eval()
+ for t_enc in text_encoders:
+ t_enc.eval()
+
+ del t_enc
+
+ accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
+
+ if not cache_latents: # if latents are not cached, the VAE will be used during training, so prepare it here
+ vae.requires_grad_(False)
+ vae.eval()
+ vae.to(accelerator.device, dtype=vae_dtype)
+
+ # Experimental feature: perform fp16 training including gradients. Apply a patch to PyTorch to enable grad scale in fp16
+ if args.full_fp16:
+ train_util.patch_accelerator_for_fp16_training(accelerator)
+
+ # pbar.update(1)
+
+ # before resuming, register hooks for saving/loading so that only the network weights are saved/loaded
+ def save_model_hook(models, weights, output_dir):
+ # pop weights of models other than the network so that only the network weights are saved
+ # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
+ #if args.deepspeed:
+ remove_indices = []
+ for i, model in enumerate(models):
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
+ remove_indices.append(i)
+ for i in reversed(remove_indices):
+ if len(weights) > i:
+ weights.pop(i)
+ # print(f"save model hook: {len(weights)} weights will be saved")
+
+ # save current epoch and step
+ train_state_file = os.path.join(output_dir, "train_state.json")
+ # +1 is needed because the state is saved before current_step is set from global_step
+ logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
+ with open(train_state_file, "w", encoding="utf-8") as f:
+ json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
+
+ steps_from_state = None
+
+ def load_model_hook(models, input_dir):
+ # remove models except network
+ remove_indices = []
+ for i, model in enumerate(models):
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
+ remove_indices.append(i)
+ for i in reversed(remove_indices):
+ models.pop(i)
+ # print(f"load model hook: {len(models)} models will be loaded")
+
+ # load current epoch and step
+ nonlocal steps_from_state
+ train_state_file = os.path.join(input_dir, "train_state.json")
+ if os.path.exists(train_state_file):
+ with open(train_state_file, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ steps_from_state = data["current_step"]
+ logger.info(f"load train state from {train_state_file}: {data}")
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # resume from local or huggingface
+
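+ # (the pre-hooks registered above make accelerator.save_state/load_state handle only the network weights
+ # and the train_state.json step counter, not the frozen base model)
+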
train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # pbar.update(1) + + # Calculate the number of epochs + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("running training") + accelerator.print(f" num train images * repeats: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch: {len(train_dataloader)}") + accelerator.print(f" num epochs: {num_train_epochs}") + accelerator.print( + f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + accelerator.print(f" gradient accumulation steps: {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps: {args.max_train_steps}") + + # TODO refactor metadata creation and move to util + metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_gradient_checkpointing": args.gradient_checkpointing, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not have alpha + "ss_network_dropout": args.network_dropout, # some networks may not have dropout + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_base_model_version": model_version, + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_cache_latents": bool(args.cache_latents), + "ss_seed": args.seed, + "ss_lowram": args.lowram, + "ss_noise_offset": args.noise_offset, + "ss_multires_noise_iterations": args.multires_noise_iterations, + "ss_multires_noise_discount": args.multires_noise_discount, + "ss_adaptive_noise_scale": args.adaptive_noise_scale, + "ss_zero_terminal_snr": args.zero_terminal_snr, + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": 
args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, + "ss_min_snr_gamma": args.min_snr_gamma, + "ss_scale_weight_norms": args.scale_weight_norms, + "ss_ip_noise_gamma": args.ip_noise_gamma, + "ss_debiased_estimation": bool(args.debiased_estimation_loss), + "ss_noise_offset_random_strength": args.noise_offset_random_strength, + "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, + "ss_loss_type": args.loss_type, + "ss_huber_schedule": args.huber_schedule, + "ss_huber_scale": args.huber_scale, + "ss_huber_c": args.huber_c, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), + } + + self.update_metadata(metadata, args) # architecture specific metadata + + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor + + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } + + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + "keep_tokens_separator": subset.keep_tokens_separator, + "secondary_separator": subset.secondary_separator, + "enable_wildcard": bool(subset.enable_wildcard), + "caption_prefix": subset.caption_prefix, + "caption_suffix": subset.caption_suffix, + } + + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir + + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + dataset_dirs_info[image_dir_or_metadata_file] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + dataset_metadata["subsets"] = subsets_metadata + 
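+ # each dataset's metadata (including its subsets) is collected here and serialized into ss_datasets as JSON below
+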
datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # If a directory is used by multiple datasets, count only once + # Since the number of repetitions is originally specified, the number of times a tag appears in the caption does not match the number of times it is used in training. + # Therefore, it is not very meaningful to add up the number of times for multiple datasets here. + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert ( + len(train_dataset_group.datasets) == 1 + ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug." + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + metadata.update( + { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + ) + + # add extra args + if args.network_args: + metadata["ss_network_args"] = json.dumps(net_kwargs) + + # model name and hash + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in metadata.items()} + + # make minimum metadata for filtering + minimum_metadata = {} + for key in train_util.SS_METADATA_MINIMUM_KEYS: + if key in metadata: + minimum_metadata[key] = metadata[key] + + # calculate steps to skip when resuming or starting from a specific step + initial_step = 0 + if args.initial_epoch is not None or args.initial_step is not None: + # if 
initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming + if steps_from_state is not None: + logger.warning( + "steps from the state is ignored because initial_step is specified" + ) + if args.initial_step is not None: + initial_step = args.initial_step + else: + # num steps per epoch is calculated by num_processes and gradient_accumulation_steps + initial_step = (args.initial_epoch - 1) * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + else: + # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming + if steps_from_state is not None: + initial_step = steps_from_state + steps_from_state = None + + if initial_step > 0: + assert ( + args.max_train_steps > initial_step + ), f"max_train_steps should be greater than initial step: {args.max_train_steps} vs {initial_step}" + + epoch_to_start = 0 + if initial_step > 0: + if args.skip_until_initial_step: + # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used + if not args.resume: + logger.info( + f"initial_step is specified but not resuming. lr scheduler will be started from the beginning" + ) + logger.info(f"skipping {initial_step} steps") + initial_step *= args.gradient_accumulation_steps + + # set epoch to start to make initial_step less than len(train_dataloader) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + else: + # if not, only epoch no is skipped for informative purpose + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + initial_step = 0 # do not skip + + + + noise_scheduler = self.get_noise_scheduler(args, accelerator.device) + + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + self.loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # pbar.update(1) + + # callback for step start + if hasattr(accelerator.unwrap_model(network), "on_step_start"): + on_step_start_for_network = accelerator.unwrap_model(network).on_step_start + else: + on_step_start_for_network = lambda *args, **kwargs: None + + # function for saving/removing + def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + metadata["ss_training_finished_at"] = str(time.time()) + metadata["ss_steps"] = str(steps) + metadata["ss_epoch"] = str(epoch_no) + + metadata_to_save = minimum_metadata if args.no_metadata else metadata + sai_metadata = self.get_sai_model_spec(args) + metadata_to_save.update(sai_metadata) + + unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + if 
self.is_text_encoder_not_needed_for_training(args): + logger.info("text_encoder is not needed for training. deleting to save memory.") + for t_enc in text_encoders: + del t_enc + text_encoders = [] + text_encoder = None + + # For --sample_at_first + #self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + + self.global_step = 0 + # training loop + if initial_step > 0: # only if skip_until_initial_step is specified + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + + # log device and dtype for each model + logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") + for i, t_enc in enumerate(text_encoders): + params_itr = t_enc.parameters() + params_itr.__next__() # skip the first parameter + params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings + param_3rd = params_itr.__next__() + logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") + + clean_memory_on_device(accelerator.device) + + self.epoch_to_start = epoch_to_start + self.num_train_epochs = num_train_epochs + self.accelerator = accelerator + self.network = network + self.text_encoder = text_encoder + self.unet = unet + self.vae = vae + self.tokenizers = tokenizers + self.args = args + self.train_dataloader = train_dataloader + self.initial_step = initial_step + self.current_epoch = current_epoch + self.metadata = metadata + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.save_model = save_model + self.remove_model = remove_model + # self.comfy_pbar = None + + progress_bar = tqdm(range(args.max_train_steps - initial_step), smoothing=0, disable=False, desc="steps") + + def training_loop(break_at_steps, epoch): + steps_done = 0 + + #accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + progress_bar.set_description(f"Epoch {epoch + 1}/{num_train_epochs} - steps") + + current_epoch.value = epoch + 1 + + metadata["ss_epoch"] = str(epoch + 1) + + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + + skipped_dataloader = None + if self.initial_step > 0: + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, self.initial_step - 1) + self.initial_step = 1 + + for step, batch in enumerate(skipped_dataloader or train_dataloader): + current_step.value = self.global_step + if self.initial_step > 0: + self.initial_step -= 1 + continue + + with accelerator.accumulate(training_model): + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # encode latents + latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype)) + latents = latents.to(dtype=weight_dtype) + + # NaN check + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + latents = self.shift_scale_latents(args, latents) + + # get multiplier for each sample + if network_has_multiplier: + multipliers = batch["network_multipliers"] + # if all multipliers are same, use single multiplier + if torch.all(multipliers == 
multipliers[0]): + multipliers = multipliers[0].item() + else: + raise NotImplementedError("multipliers for each sample is not supported yet") + # print(f"set multiplier: {multipliers}") + accelerator.unwrap_model(network).set_multiplier(multipliers) + + text_encoder_conds = [] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + # sample noise, call unet, get target + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + ) + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # weight for each sample + loss = loss * loss_weights + + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. 
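+ # (these weightings are applied per sample before the batch mean is taken below)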
+ loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # No need to divide by batch_size since it's an average + + accelerator.backward(loss) + if accelerator.sync_gradients: + self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(network).get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if args.scale_weight_norms: + keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( + args.scale_weight_norms, accelerator.device + ) + max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} + else: + keys_scaled, mean_norm, maximum_norm = None, None, None + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + self.global_step += 1 + + current_loss = loss.detach().item() + self.loss_recorder.add(epoch=epoch, step=step, global_step=self.global_step, loss=current_loss) + avr_loss: float = self.loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.scale_weight_norms: + progress_bar.set_postfix(**{**max_mean_logs, **logs}) + + if len(accelerator.trackers) > 0: + logs = self.generate_step_logs( + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + ) + accelerator.log(logs, step=self.global_step) + + if self.global_step >= break_at_steps: + break + steps_done += 1 + # self.comfy_pbar.update(1) + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": self.loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + return steps_done + + + + return training_loop + + # metadata["ss_epoch"] = str(num_train_epochs) + # metadata["ss_training_finished_at"] = str(time.time()) + + # network = accelerator.unwrap_model(network) + + # accelerator.end_training() + + # if (args.save_state or args.save_state_on_train_end): + # train_util.save_state_on_train_end(args, accelerator) + + # ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) + # save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) + + # logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported" + " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)", + ) + parser.add_argument( + "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" + ) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) + + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) + parser.add_argument( + "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" + ) + parser.add_argument( + "--network_dim", + type=int, + default=None, + help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", + ) + parser.add_argument( + "--network_alpha", + type=float, + default=1, + help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)", + ) + parser.add_argument( + "--network_dropout", + type=float, + default=None, + help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" + ) + parser.add_argument( + "--network_train_text_encoder_only", + action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する", + ) + parser.add_argument( + "--training_comment", + type=str, + default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", + ) + parser.add_argument( + "--dim_from_weights", + action="store_true", + help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", + ) + parser.add_argument( + 
"--scale_weight_norms", + type=float, + default=None, + help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)", + ) + parser.add_argument( + "--base_weights", + type=str, + default=None, + nargs="*", + help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル", + ) + parser.add_argument( + "--base_weights_multiplier", + type=float, + default=None, + nargs="*", + help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--skip_until_initial_step", + action="store_true", + help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする", + ) + parser.add_argument( + "--initial_epoch", + type=int, + default=None, + help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`." + + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる", + ) + parser.add_argument( + "--initial_step", + type=int, + default=None, + help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", + ) + # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") + # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = NetworkTrainer() + trainer.train(args) diff --git a/update.js b/update.js new file mode 100644 index 0000000000000000000000000000000000000000..f98498574cadd258d4e00726d831a8469d6cd8d8 --- /dev/null +++ b/update.js @@ -0,0 +1,46 @@ +module.exports = { + run: [{ + method: "shell.run", + params: { + message: "git pull" + } + }, { + method: "shell.run", + params: { + path: "sd-scripts", + message: "git pull" + } + }, { + method: "shell.run", + params: { + path: "sd-scripts", + venv: "../env", + message: [ + "uv pip install -r requirements.txt", + ] + } + }, { + method: "shell.run", + params: { + venv: "env", + message: [ + "pip uninstall -y diffusers[torch] torch torchaudio torchvision", + "uv pip install -r requirements.txt", + ] + } + }, { + method: "script.start", + params: { + uri: "torch.js", + params: { + venv: "env", + // xformers: true // uncomment this line if your project requires xformers + } + } + }, { + method: "fs.link", + params: { + venv: "env" + } + }] +}